diff options
| author | Aldo Cortesi <aldo@nullcube.com> | 2012-06-21 10:56:30 +1200 | 
|---|---|---|
| committer | Aldo Cortesi <aldo@nullcube.com> | 2012-06-21 10:56:30 +1200 | 
| commit | 1089a52f3d16c4fef504586cae18a5d324e8d75c (patch) | |
| tree | 3616be49d90c63d3366254fa8b6894844c6484c9 | |
| parent | de00497b4098fece043a1550a94e112829bcbceb (diff) | |
| download | mitmproxy-1089a52f3d16c4fef504586cae18a5d324e8d75c.tar.gz mitmproxy-1089a52f3d16c4fef504586cae18a5d324e8d75c.tar.bz2 mitmproxy-1089a52f3d16c4fef504586cae18a5d324e8d75c.zip | |
Disconnect, rest refactoring.
| -rw-r--r-- | libpathod/pathod.py | 36 | ||||
| -rw-r--r-- | libpathod/rparse.py | 13 | ||||
| -rw-r--r-- | libpathod/utils.py | 25 | ||||
| -rw-r--r-- | test/test_pathod.py | 10 | ||||
| -rw-r--r-- | test/test_rparse.py | 18 | 
5 files changed, 39 insertions, 63 deletions
| diff --git a/libpathod/pathod.py b/libpathod/pathod.py index 8a29b9cb..e0a0764f 100644 --- a/libpathod/pathod.py +++ b/libpathod/pathod.py @@ -18,6 +18,10 @@ class PathodHandler(tcp.BaseHandler):              return None          method, path, httpversion = protocol.parse_init_http(line) +        headers = odict.ODictCaseless(protocol.read_headers(self.rfile)) +        content = protocol.read_http_body_request( +                    self.rfile, self.wfile, headers, httpversion, None +                )          if path.startswith(self.server.prefix):              spec = urllib.unquote(path)[len(self.server.prefix):]              try: @@ -27,24 +31,20 @@ class PathodHandler(tcp.BaseHandler):                      800,                      "Error parsing response spec: %s\n"%v.msg + v.marked()                  ) -            presp.serve(self.wfile) -            self.finish() -            return - -        headers = odict.ODictCaseless(protocol.read_headers(self.rfile)) -        content = protocol.read_http_body_request( -                    self.rfile, self.wfile, headers, httpversion, None -                ) -        cc = wsgi.ClientConn(self.client_address) -        req = wsgi.Request(cc, "http", method, path, headers, content) -        sn = self.connection.getsockname() -        app = wsgi.WSGIAdaptor( -            self.server.app, -            sn[0], -            self.server.port, -            version.NAMEVERSION -        ) -        app.serve(req, self.wfile) +            ret = presp.serve(self.wfile) +            if ret["disconnect"]: +                self.close() +        else: +            cc = wsgi.ClientConn(self.client_address) +            req = wsgi.Request(cc, "http", method, path, headers, content) +            sn = self.connection.getsockname() +            app = wsgi.WSGIAdaptor( +                self.server.app, +                sn[0], +                self.server.port, +                version.NAMEVERSION +            ) +            app.serve(req, self.wfile)  class Pathod(tcp.TCPServer): diff --git a/libpathod/rparse.py b/libpathod/rparse.py index 677c6b54..47084520 100644 --- a/libpathod/rparse.py +++ b/libpathod/rparse.py @@ -390,6 +390,9 @@ class Response:          return ret      def write_values(self, fp, vals, actions, sofar=0, skip=0, blocksize=BLOCKSIZE): +        """ +            Return True if connection should disconnect. +        """          while vals:              part = vals.pop()              for i in range(skip, len(part), blocksize): @@ -401,18 +404,15 @@ class Response:                      if p[1] == "pause":                          fp.write(d[:offset])                          time.sleep(p[2]) -                        self.write_values( +                        return self.write_values(                              fp, vals, actions,                              sofar=sofar+offset,                              skip=i+offset,                              blocksize=blocksize                          ) -                        return                      elif p[1] == "disconnect":                          fp.write(d[:offset]) -                        fp.finish() -                        fp.connection.stream.close() -                        return +                        return True                  fp.write(d)                  sofar += len(d)              skip = 0 @@ -447,9 +447,10 @@ class Response:          vals.reverse()          actions = self.ready_actions(self.length(), self.actions)          actions.reverse() -        self.write_values(fp, vals, actions[:]) +        disconnect = self.write_values(fp, vals, actions[:])          duration = time.time() - started          return dict( +            disconnect = disconnect,              started = started,              duration = duration,              actions = actions, diff --git a/libpathod/utils.py b/libpathod/utils.py index 0e3bda9d..f421b8a6 100644 --- a/libpathod/utils.py +++ b/libpathod/utils.py @@ -4,31 +4,6 @@ import rparse  class AnchorError(Exception): pass -class Sponge: -    def __getattr__(self, x): -        return Sponge() - -    def __call__(self, *args, **kwargs): -        pass - - -class DummyRequest: -    connection = Sponge() -    def __init__(self): -        self.buf = [] - -    def write(self, d, callback=None): -        self.buf.append(str(d)) -        if callback: -            callback() - -    def getvalue(self): -        return "".join(self.buf) - -    def finish(self): -        return - -  def parse_anchor_spec(s, settings):      """          For now, this is very simple, and you can't have an '=' in your regular diff --git a/test/test_pathod.py b/test/test_pathod.py index 3fd2388a..36a2d090 100644 --- a/test/test_pathod.py +++ b/test/test_pathod.py @@ -1,7 +1,6 @@  from libpathod import pathod -from tornado import httpserver -class TestApplication: +class _TestApplication:      def test_anchors(self):          a = pathod.PathodApp(staticdir=None)          a.add_anchor("/foo", "200") @@ -30,6 +29,7 @@ class TestApplication:          assert not a.log_by_id(0) -def test_make_server(): -    app = pathod.PathodApp() -    assert pathod.make_server(app, 0, "127.0.0.1", None) +class TestPathod: +    def test_instantiation(self): +        pathod.Pathod(("127.0.0.1", 0)) +         diff --git a/test/test_rparse.py b/test/test_rparse.py index f0db75fd..0813f22e 100644 --- a/test/test_rparse.py +++ b/test/test_rparse.py @@ -1,4 +1,4 @@ -import os +import os, cStringIO  from libpathod import rparse, utils  import tutils @@ -131,7 +131,7 @@ class TestMisc:          assert r.msg.val == "Unknown code"      def test_internal_response(self): -        d = utils.DummyRequest() +        d = cStringIO.StringIO()          s = rparse.InternalResponse(400, "foo")          s.serve(d) @@ -245,7 +245,7 @@ class TestResponse:      def test_write_values_disconnects(self):          r = self.dummy_response() -        s = utils.DummyRequest() +        s = cStringIO.StringIO()          tst = "foo"*100          r.write_values(s, [tst], [(0, "disconnect")], blocksize=5)          assert not s.getvalue() @@ -254,7 +254,7 @@ class TestResponse:          tst = "foo"*1025          r = rparse.parse({}, "400'msg'") -        s = utils.DummyRequest() +        s = cStringIO.StringIO()          r.write_values(s, [tst], [])          assert s.getvalue() == tst @@ -263,29 +263,29 @@ class TestResponse:          r = rparse.parse({}, "400'msg'")          for i in range(2, 10): -            s = utils.DummyRequest() +            s = cStringIO.StringIO()              r.write_values(s, [tst], [(2, "pause", 0), (1, "pause", 0)], blocksize=i)              assert s.getvalue() == tst          for i in range(2, 10): -            s = utils.DummyRequest() +            s = cStringIO.StringIO()              r.write_values(s, [tst], [(1, "pause", 0)], blocksize=i)              assert s.getvalue() == tst          tst = ["".join(str(i) for i in range(10))]*5          for i in range(2, 10): -            s = utils.DummyRequest() +            s = cStringIO.StringIO()              r.write_values(s, tst[:], [(1, "pause", 0)], blocksize=i)              assert s.getvalue() == "".join(tst)      def test_render(self): -        s = utils.DummyRequest() +        s = cStringIO.StringIO()          r = rparse.parse({}, "400'msg'")          assert r.serve(s)      def test_length(self):          def testlen(x): -            s = utils.DummyRequest() +            s = cStringIO.StringIO()              x.serve(s)              assert x.length() == len(s.getvalue())          testlen(rparse.parse({}, "400'msg'")) | 
