aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAldo Cortesi <aldo@nullcube.com>2012-06-21 10:56:30 +1200
committerAldo Cortesi <aldo@nullcube.com>2012-06-21 10:56:30 +1200
commit1089a52f3d16c4fef504586cae18a5d324e8d75c (patch)
tree3616be49d90c63d3366254fa8b6894844c6484c9
parentde00497b4098fece043a1550a94e112829bcbceb (diff)
downloadmitmproxy-1089a52f3d16c4fef504586cae18a5d324e8d75c.tar.gz
mitmproxy-1089a52f3d16c4fef504586cae18a5d324e8d75c.tar.bz2
mitmproxy-1089a52f3d16c4fef504586cae18a5d324e8d75c.zip
Disconnect, rest refactoring.
-rw-r--r--libpathod/pathod.py36
-rw-r--r--libpathod/rparse.py13
-rw-r--r--libpathod/utils.py25
-rw-r--r--test/test_pathod.py10
-rw-r--r--test/test_rparse.py18
5 files changed, 39 insertions, 63 deletions
diff --git a/libpathod/pathod.py b/libpathod/pathod.py
index 8a29b9cb..e0a0764f 100644
--- a/libpathod/pathod.py
+++ b/libpathod/pathod.py
@@ -18,6 +18,10 @@ class PathodHandler(tcp.BaseHandler):
return None
method, path, httpversion = protocol.parse_init_http(line)
+ headers = odict.ODictCaseless(protocol.read_headers(self.rfile))
+ content = protocol.read_http_body_request(
+ self.rfile, self.wfile, headers, httpversion, None
+ )
if path.startswith(self.server.prefix):
spec = urllib.unquote(path)[len(self.server.prefix):]
try:
@@ -27,24 +31,20 @@ class PathodHandler(tcp.BaseHandler):
800,
"Error parsing response spec: %s\n"%v.msg + v.marked()
)
- presp.serve(self.wfile)
- self.finish()
- return
-
- headers = odict.ODictCaseless(protocol.read_headers(self.rfile))
- content = protocol.read_http_body_request(
- self.rfile, self.wfile, headers, httpversion, None
- )
- cc = wsgi.ClientConn(self.client_address)
- req = wsgi.Request(cc, "http", method, path, headers, content)
- sn = self.connection.getsockname()
- app = wsgi.WSGIAdaptor(
- self.server.app,
- sn[0],
- self.server.port,
- version.NAMEVERSION
- )
- app.serve(req, self.wfile)
+ ret = presp.serve(self.wfile)
+ if ret["disconnect"]:
+ self.close()
+ else:
+ cc = wsgi.ClientConn(self.client_address)
+ req = wsgi.Request(cc, "http", method, path, headers, content)
+ sn = self.connection.getsockname()
+ app = wsgi.WSGIAdaptor(
+ self.server.app,
+ sn[0],
+ self.server.port,
+ version.NAMEVERSION
+ )
+ app.serve(req, self.wfile)
class Pathod(tcp.TCPServer):
diff --git a/libpathod/rparse.py b/libpathod/rparse.py
index 677c6b54..47084520 100644
--- a/libpathod/rparse.py
+++ b/libpathod/rparse.py
@@ -390,6 +390,9 @@ class Response:
return ret
def write_values(self, fp, vals, actions, sofar=0, skip=0, blocksize=BLOCKSIZE):
+ """
+ Return True if connection should disconnect.
+ """
while vals:
part = vals.pop()
for i in range(skip, len(part), blocksize):
@@ -401,18 +404,15 @@ class Response:
if p[1] == "pause":
fp.write(d[:offset])
time.sleep(p[2])
- self.write_values(
+ return self.write_values(
fp, vals, actions,
sofar=sofar+offset,
skip=i+offset,
blocksize=blocksize
)
- return
elif p[1] == "disconnect":
fp.write(d[:offset])
- fp.finish()
- fp.connection.stream.close()
- return
+ return True
fp.write(d)
sofar += len(d)
skip = 0
@@ -447,9 +447,10 @@ class Response:
vals.reverse()
actions = self.ready_actions(self.length(), self.actions)
actions.reverse()
- self.write_values(fp, vals, actions[:])
+ disconnect = self.write_values(fp, vals, actions[:])
duration = time.time() - started
return dict(
+ disconnect = disconnect,
started = started,
duration = duration,
actions = actions,
diff --git a/libpathod/utils.py b/libpathod/utils.py
index 0e3bda9d..f421b8a6 100644
--- a/libpathod/utils.py
+++ b/libpathod/utils.py
@@ -4,31 +4,6 @@ import rparse
class AnchorError(Exception): pass
-class Sponge:
- def __getattr__(self, x):
- return Sponge()
-
- def __call__(self, *args, **kwargs):
- pass
-
-
-class DummyRequest:
- connection = Sponge()
- def __init__(self):
- self.buf = []
-
- def write(self, d, callback=None):
- self.buf.append(str(d))
- if callback:
- callback()
-
- def getvalue(self):
- return "".join(self.buf)
-
- def finish(self):
- return
-
-
def parse_anchor_spec(s, settings):
"""
For now, this is very simple, and you can't have an '=' in your regular
diff --git a/test/test_pathod.py b/test/test_pathod.py
index 3fd2388a..36a2d090 100644
--- a/test/test_pathod.py
+++ b/test/test_pathod.py
@@ -1,7 +1,6 @@
from libpathod import pathod
-from tornado import httpserver
-class TestApplication:
+class _TestApplication:
def test_anchors(self):
a = pathod.PathodApp(staticdir=None)
a.add_anchor("/foo", "200")
@@ -30,6 +29,7 @@ class TestApplication:
assert not a.log_by_id(0)
-def test_make_server():
- app = pathod.PathodApp()
- assert pathod.make_server(app, 0, "127.0.0.1", None)
+class TestPathod:
+ def test_instantiation(self):
+ pathod.Pathod(("127.0.0.1", 0))
+
diff --git a/test/test_rparse.py b/test/test_rparse.py
index f0db75fd..0813f22e 100644
--- a/test/test_rparse.py
+++ b/test/test_rparse.py
@@ -1,4 +1,4 @@
-import os
+import os, cStringIO
from libpathod import rparse, utils
import tutils
@@ -131,7 +131,7 @@ class TestMisc:
assert r.msg.val == "Unknown code"
def test_internal_response(self):
- d = utils.DummyRequest()
+ d = cStringIO.StringIO()
s = rparse.InternalResponse(400, "foo")
s.serve(d)
@@ -245,7 +245,7 @@ class TestResponse:
def test_write_values_disconnects(self):
r = self.dummy_response()
- s = utils.DummyRequest()
+ s = cStringIO.StringIO()
tst = "foo"*100
r.write_values(s, [tst], [(0, "disconnect")], blocksize=5)
assert not s.getvalue()
@@ -254,7 +254,7 @@ class TestResponse:
tst = "foo"*1025
r = rparse.parse({}, "400'msg'")
- s = utils.DummyRequest()
+ s = cStringIO.StringIO()
r.write_values(s, [tst], [])
assert s.getvalue() == tst
@@ -263,29 +263,29 @@ class TestResponse:
r = rparse.parse({}, "400'msg'")
for i in range(2, 10):
- s = utils.DummyRequest()
+ s = cStringIO.StringIO()
r.write_values(s, [tst], [(2, "pause", 0), (1, "pause", 0)], blocksize=i)
assert s.getvalue() == tst
for i in range(2, 10):
- s = utils.DummyRequest()
+ s = cStringIO.StringIO()
r.write_values(s, [tst], [(1, "pause", 0)], blocksize=i)
assert s.getvalue() == tst
tst = ["".join(str(i) for i in range(10))]*5
for i in range(2, 10):
- s = utils.DummyRequest()
+ s = cStringIO.StringIO()
r.write_values(s, tst[:], [(1, "pause", 0)], blocksize=i)
assert s.getvalue() == "".join(tst)
def test_render(self):
- s = utils.DummyRequest()
+ s = cStringIO.StringIO()
r = rparse.parse({}, "400'msg'")
assert r.serve(s)
def test_length(self):
def testlen(x):
- s = utils.DummyRequest()
+ s = cStringIO.StringIO()
x.serve(s)
assert x.length() == len(s.getvalue())
testlen(rparse.parse({}, "400'msg'"))