diff options
| -rw-r--r-- | mitmproxy/addons.py | 20 | ||||
| -rw-r--r-- | mitmproxy/builtins/serverplayback.py | 13 | ||||
| -rw-r--r-- | mitmproxy/console/master.py | 18 | ||||
| -rw-r--r-- | mitmproxy/console/statusbar.py | 10 | ||||
| -rw-r--r-- | mitmproxy/console/window.py | 20 | ||||
| -rw-r--r-- | mitmproxy/protocol/http_replay.py | 3 | ||||
| -rw-r--r-- | test/mitmproxy/protocol/test_http1.py | 12 | ||||
| -rw-r--r-- | test/mitmproxy/test_addons.py | 4 | ||||
| -rw-r--r-- | test/mitmproxy/test_fuzzing.py | 9 | ||||
| -rw-r--r-- | test/mitmproxy/test_server.py | 333 | ||||
| -rw-r--r-- | test/mitmproxy/tservers.py | 49 | 
11 files changed, 265 insertions, 226 deletions
| diff --git a/mitmproxy/addons.py b/mitmproxy/addons.py index 329d1215..2658c0af 100644 --- a/mitmproxy/addons.py +++ b/mitmproxy/addons.py @@ -4,7 +4,7 @@ import pprint  def _get_name(itm): -    return getattr(itm, "name", itm.__class__.__name__) +    return getattr(itm, "name", itm.__class__.__name__.lower())  class Addons(object): @@ -13,6 +13,16 @@ class Addons(object):          self.master = master          master.options.changed.connect(self.options_update) +    def get(self, name): +        """ +            Retrieve an addon by name. Addon names are equal to the .name +            attribute on the instance, or the lower case class name if that +            does not exist. +        """ +        for i in self.chain: +            if name == _get_name(i): +                return i +      def options_update(self, options, updated):          for i in self.chain:              with self.master.handlecontext(): @@ -39,14 +49,6 @@ class Addons(object):          for i in self.chain:              self.invoke_with_context(i, "done") -    def has_addon(self, name): -        """ -            Is an addon with this name registered? -        """ -        for i in self.chain: -            if _get_name(i) == name: -                return True -      def __len__(self):          return len(self.chain) diff --git a/mitmproxy/builtins/serverplayback.py b/mitmproxy/builtins/serverplayback.py index fe56d68b..be82cad9 100644 --- a/mitmproxy/builtins/serverplayback.py +++ b/mitmproxy/builtins/serverplayback.py @@ -88,13 +88,14 @@ class ServerPlayback(object):      def configure(self, options, updated):          self.options = options -        if options.server_replay and "server_replay" in updated: -            try: -                flows = flow.read_flows_from_paths(options.server_replay) -            except exceptions.FlowReadException as e: -                raise exceptions.OptionsError(str(e)) +        if "server_replay" in updated:              self.clear() -            self.load(flows) +            if options.server_replay: +                try: +                    flows = flow.read_flows_from_paths(options.server_replay) +                except exceptions.FlowReadException as e: +                    raise exceptions.OptionsError(str(e)) +                self.load(flows)          # FIXME: These options have to be renamed to something more sensible -          # prefixed with serverplayback_ where appropriate, and playback_ where diff --git a/mitmproxy/console/master.py b/mitmproxy/console/master.py index a6942ca4..1cb3a32b 100644 --- a/mitmproxy/console/master.py +++ b/mitmproxy/console/master.py @@ -248,9 +248,6 @@ class ConsoleMaster(flow.FlowMaster):          if options.client_replay:              self.client_playback_path(options.client_replay) -        if options.server_replay: -            self.server_playback_path(options.server_replay) -          self.view_stack = []          if options.app: @@ -391,21 +388,6 @@ class ConsoleMaster(flow.FlowMaster):          if flows:              self.start_client_playback(flows, False) -    def server_playback_path(self, path): -        if not isinstance(path, list): -            path = [path] -        flows = self._readflows(path) -        if flows: -            self.start_server_playback( -                flows, -                self.options.kill, self.options.rheaders, -                False, self.options.nopop, -                self.options.replay_ignore_params, -                self.options.replay_ignore_content, -                self.options.replay_ignore_payload_params, -                self.options.replay_ignore_host -            ) -      def spawn_editor(self, data):          text = not isinstance(data, bytes)          fd, name = tempfile.mkstemp('', "mproxy", text=text) diff --git a/mitmproxy/console/statusbar.py b/mitmproxy/console/statusbar.py index 43d68d51..6c4cc8b5 100644 --- a/mitmproxy/console/statusbar.py +++ b/mitmproxy/console/statusbar.py @@ -147,14 +147,12 @@ class StatusBar(urwid.WidgetWrap):          if self.master.client_playback:              r.append("[")              r.append(("heading_key", "cplayback")) -            r.append(":%s to go]" % self.master.client_playback.count()) -        if self.master.server_playback: +            r.append(":%s]" % self.master.client_playback.count()) +        if self.master.options.server_replay:              r.append("[")              r.append(("heading_key", "splayback")) -            if self.master.options.nopop: -                r.append(":%s in file]" % self.master.server_playback.count()) -            else: -                r.append(":%s to go]" % self.master.server_playback.count()) +            a = self.master.addons.get("serverplayback") +            r.append(":%s]" % a.count())          if self.master.options.ignore_hosts:              r.append("[")              r.append(("heading_key", "I")) diff --git a/mitmproxy/console/window.py b/mitmproxy/console/window.py index 35593643..159f68ed 100644 --- a/mitmproxy/console/window.py +++ b/mitmproxy/console/window.py @@ -57,13 +57,11 @@ class Window(urwid.Frame):                      callback = self.master.stop_client_playback_prompt,                  )          elif k == "s": -            if not self.master.server_playback: -                signals.status_prompt_path.send( -                    self, -                    prompt = "Server replay path", -                    callback = self.master.server_playback_path -                ) -            else: +            a = self.master.addons.get("serverplayback") +            if a.count(): +                def stop_server_playback(response): +                    if response == "y": +                        self.master.options.server_replay = []                  signals.status_prompt_onekey.send(                      self,                      prompt = "Stop current server replay?", @@ -71,7 +69,13 @@ class Window(urwid.Frame):                          ("yes", "y"),                          ("no", "n"),                      ), -                    callback = self.master.stop_server_playback_prompt, +                    callback = stop_server_playback +                ) +            else: +                signals.status_prompt_path.send( +                    self, +                    prompt = "Server playback path", +                    callback = lambda x: self.master.options.setter("server_replay")([x])                  )      def keypress(self, size, k): diff --git a/mitmproxy/protocol/http_replay.py b/mitmproxy/protocol/http_replay.py index bfde06c5..877eaa22 100644 --- a/mitmproxy/protocol/http_replay.py +++ b/mitmproxy/protocol/http_replay.py @@ -33,6 +33,7 @@ class RequestReplayThread(basethread.BaseThread):      def run(self):          r = self.flow.request          first_line_format_backup = r.first_line_format +        server = None          try:              self.flow.response = None @@ -103,3 +104,5 @@ class RequestReplayThread(basethread.BaseThread):              self.channel.tell("log", Log(traceback.format_exc(), "error"))          finally:              r.first_line_format = first_line_format_backup +            if server: +                server.finish() diff --git a/test/mitmproxy/protocol/test_http1.py b/test/mitmproxy/protocol/test_http1.py index 7d04c56b..2fc4ac63 100644 --- a/test/mitmproxy/protocol/test_http1.py +++ b/test/mitmproxy/protocol/test_http1.py @@ -18,14 +18,15 @@ class TestInvalidRequests(tservers.HTTPProxyTest):      def test_double_connect(self):          p = self.pathoc() -        r = p.request("connect:'%s:%s'" % ("127.0.0.1", self.server2.port)) +        with p.connect(): +            r = p.request("connect:'%s:%s'" % ("127.0.0.1", self.server2.port))          assert r.status_code == 400          assert b"Invalid HTTP request form" in r.content      def test_relative_request(self):          p = self.pathoc_raw() -        p.connect() -        r = p.request("get:/p/200") +        with p.connect(): +            r = p.request("get:/p/200")          assert r.status_code == 400          assert b"Invalid HTTP request form" in r.content @@ -61,5 +62,8 @@ class TestHeadContentLength(tservers.HTTPProxyTest):      def test_head_content_length(self):          p = self.pathoc() -        resp = p.request("""head:'%s/p/200:h"Content-Length"="42"'""" % self.server.urlbase) +        with p.connect(): +            resp = p.request( +                """head:'%s/p/200:h"Content-Length"="42"'""" % self.server.urlbase +            )          assert resp.headers["Content-Length"] == "42" diff --git a/test/mitmproxy/test_addons.py b/test/mitmproxy/test_addons.py index a5085ea0..52d7f07f 100644 --- a/test/mitmproxy/test_addons.py +++ b/test/mitmproxy/test_addons.py @@ -17,5 +17,5 @@ def test_simple():      m = controller.Master(o)      a = addons.Addons(m)      a.add(o, TAddon("one")) -    assert a.has_addon("one") -    assert not a.has_addon("two") +    assert a.get("one") +    assert not a.get("two") diff --git a/test/mitmproxy/test_fuzzing.py b/test/mitmproxy/test_fuzzing.py index 27ea36a6..905ba1cd 100644 --- a/test/mitmproxy/test_fuzzing.py +++ b/test/mitmproxy/test_fuzzing.py @@ -11,17 +11,20 @@ class TestFuzzy(tservers.HTTPProxyTest):      def test_idna_err(self):          req = r'get:"http://localhost:%s":i10,"\xc6"'          p = self.pathoc() -        assert p.request(req % self.server.port).status_code == 400 +        with p.connect(): +            assert p.request(req % self.server.port).status_code == 400      def test_nullbytes(self):          req = r'get:"http://localhost:%s":i19,"\x00"'          p = self.pathoc() -        assert p.request(req % self.server.port).status_code == 400 +        with p.connect(): +            assert p.request(req % self.server.port).status_code == 400      def test_invalid_ipv6_url(self):          req = 'get:"http://localhost:%s":i13,"["'          p = self.pathoc() -        resp = p.request(req % self.server.port) +        with p.connect(): +            resp = p.request(req % self.server.port)          assert resp.status_code == 400      # def test_invalid_upstream(self): diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py index e0a8da47..321bb11f 100644 --- a/test/mitmproxy/test_server.py +++ b/test/mitmproxy/test_server.py @@ -91,11 +91,11 @@ class CommonMixin:      def test_invalid_http(self):          t = tcp.TCPClient(("127.0.0.1", self.proxy.port)) -        t.connect() -        t.wfile.write(b"invalid\r\n\r\n") -        t.wfile.flush() -        line = t.rfile.readline() -        assert (b"Bad Request" in line) or (b"Bad Gateway" in line) +        with t.connect(): +            t.wfile.write(b"invalid\r\n\r\n") +            t.wfile.flush() +            line = t.rfile.readline() +            assert (b"Bad Request" in line) or (b"Bad Gateway" in line)      def test_sni(self):          if not self.ssl: @@ -208,20 +208,22 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin):      def test_app_err(self):          p = self.pathoc() -        ret = p.request("get:'http://errapp/'") +        with p.connect(): +            ret = p.request("get:'http://errapp/'")          assert ret.status_code == 500          assert b"ValueError" in ret.content      def test_invalid_connect(self):          t = tcp.TCPClient(("127.0.0.1", self.proxy.port)) -        t.connect() -        t.wfile.write(b"CONNECT invalid\n\n") -        t.wfile.flush() -        assert b"Bad Request" in t.rfile.readline() +        with t.connect(): +            t.wfile.write(b"CONNECT invalid\n\n") +            t.wfile.flush() +            assert b"Bad Request" in t.rfile.readline()      def test_upstream_ssl_error(self):          p = self.pathoc() -        ret = p.request("get:'https://localhost:%s/'" % self.server.port) +        with p.connect(): +            ret = p.request("get:'https://localhost:%s/'" % self.server.port)          assert ret.status_code == 400      def test_connection_close(self): @@ -232,25 +234,28 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin):          # Lets sanity check that the connection does indeed stay open by          # issuing two requests over the same connection          p = self.pathoc() -        assert p.request("get:'%s'" % response) -        assert p.request("get:'%s'" % response) +        with p.connect(): +            assert p.request("get:'%s'" % response) +            assert p.request("get:'%s'" % response)          # Now check that the connection is closed as the client specifies          p = self.pathoc() -        assert p.request("get:'%s':h'Connection'='close'" % response) -        # There's a race here, which means we can get any of a number of errors. -        # Rather than introduce yet another sleep into the test suite, we just -        # relax the Exception specification. -        with raises(Exception): -            p.request("get:'%s'" % response) +        with p.connect(): +            assert p.request("get:'%s':h'Connection'='close'" % response) +            # There's a race here, which means we can get any of a number of errors. +            # Rather than introduce yet another sleep into the test suite, we just +            # relax the Exception specification. +            with raises(Exception): +                p.request("get:'%s'" % response)      def test_reconnect(self):          req = "get:'%s/p/200:b@1:da'" % self.server.urlbase          p = self.pathoc() -        assert p.request(req) -        # Server has disconnected. Mitmproxy should detect this, and reconnect. -        assert p.request(req) -        assert p.request(req) +        with p.connect(): +            assert p.request(req) +            # Server has disconnected. Mitmproxy should detect this, and reconnect. +            assert p.request(req) +            assert p.request(req)      def test_get_connection_switching(self):          def switched(l): @@ -260,18 +265,21 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin):          req = "get:'%s/p/200:b@1'"          p = self.pathoc() -        assert p.request(req % self.server.urlbase) -        assert p.request(req % self.server2.urlbase) +        with p.connect(): +            assert p.request(req % self.server.urlbase) +            assert p.request(req % self.server2.urlbase)          assert switched(self.proxy.tlog)      def test_blank_leading_line(self):          p = self.pathoc() -        req = "get:'%s/p/201':i0,'\r\n'" -        assert p.request(req % self.server.urlbase).status_code == 201 +        with p.connect(): +            req = "get:'%s/p/201':i0,'\r\n'" +            assert p.request(req % self.server.urlbase).status_code == 201      def test_invalid_headers(self):          p = self.pathoc() -        resp = p.request("get:'http://foo':h':foo'='bar'") +        with p.connect(): +            resp = p.request("get:'http://foo':h':foo'='bar'")          assert resp.status_code == 400      def test_stream(self): @@ -301,15 +309,16 @@ class TestHTTPAuth(tservers.HTTPProxyTest):          self.master.options.auth_singleuser = "test:test"          assert self.pathod("202").status_code == 407          p = self.pathoc() -        ret = p.request(""" -            get -            'http://localhost:%s/p/202' -            h'%s'='%s' -        """ % ( -            self.server.port, -            http.authentication.BasicProxyAuth.AUTH_HEADER, -            authentication.assemble_http_basic_auth("basic", "test", "test") -        )) +        with p.connect(): +            ret = p.request(""" +                get +                'http://localhost:%s/p/202' +                h'%s'='%s' +            """ % ( +                self.server.port, +                http.authentication.BasicProxyAuth.AUTH_HEADER, +                authentication.assemble_http_basic_auth("basic", "test", "test") +            ))          assert ret.status_code == 202 @@ -318,14 +327,15 @@ class TestHTTPReverseAuth(tservers.ReverseProxyTest):          self.master.options.auth_singleuser = "test:test"          assert self.pathod("202").status_code == 401          p = self.pathoc() -        ret = p.request(""" -            get -            '/p/202' -            h'%s'='%s' -        """ % ( -            http.authentication.BasicWebsiteAuth.AUTH_HEADER, -            authentication.assemble_http_basic_auth("basic", "test", "test") -        )) +        with p.connect(): +            ret = p.request(""" +                get +                '/p/202' +                h'%s'='%s' +            """ % ( +                http.authentication.BasicWebsiteAuth.AUTH_HEADER, +                authentication.assemble_http_basic_auth("basic", "test", "test") +            ))          assert ret.status_code == 202 @@ -354,7 +364,8 @@ class TestHTTPS(tservers.HTTPProxyTest, CommonMixin, TcpMixin):      def test_error_post_connect(self):          p = self.pathoc() -        assert p.request("get:/:i0,'invalid\r\n\r\n'").status_code == 400 +        with p.connect(): +            assert p.request("get:/:i0,'invalid\r\n\r\n'").status_code == 400  class TestHTTPSCertfile(tservers.HTTPProxyTest, CommonMixin): @@ -389,7 +400,8 @@ class TestHTTPSUpstreamServerVerificationWTrustedCert(tservers.HTTPProxyTest):      def _request(self):          p = self.pathoc(sni="example.mitmproxy.org") -        return p.request("get:/p/242") +        with p.connect(): +            return p.request("get:/p/242")      def test_verification_w_cadir(self):          self.config.options.update( @@ -426,7 +438,8 @@ class TestHTTPSUpstreamServerVerificationWBadCert(tservers.HTTPProxyTest):      def _request(self):          p = self.pathoc(sni="example.mitmproxy.org") -        return p.request("get:/p/242") +        with p.connect(): +            return p.request("get:/p/242")      @classmethod      def get_options(cls): @@ -481,13 +494,15 @@ class TestSocks5(tservers.SocksModeTest):      def test_simple(self):          p = self.pathoc() -        p.socks_connect(("localhost", self.server.port)) -        f = p.request("get:/p/200") +        with p.connect(): +            p.socks_connect(("localhost", self.server.port)) +            f = p.request("get:/p/200")          assert f.status_code == 200      def test_with_authentication_only(self):          p = self.pathoc() -        f = p.request("get:/p/200") +        with p.connect(): +            f = p.request("get:/p/200")          assert f.status_code == 502          assert b"SOCKS5 mode failure" in f.content @@ -496,21 +511,21 @@ class TestSocks5(tservers.SocksModeTest):          mitmproxy doesn't support UDP or BIND SOCKS CMDs          """          p = self.pathoc() - -        socks.ClientGreeting( -            socks.VERSION.SOCKS5, -            [socks.METHOD.NO_AUTHENTICATION_REQUIRED] -        ).to_file(p.wfile) -        socks.Message( -            socks.VERSION.SOCKS5, -            socks.CMD.BIND, -            socks.ATYP.DOMAINNAME, -            ("example.com", 8080) -        ).to_file(p.wfile) - -        p.wfile.flush() -        p.rfile.read(2)  # read server greeting -        f = p.request("get:/p/200")  # the request doesn't matter, error response from handshake will be read anyway. +        with p.connect(): +            socks.ClientGreeting( +                socks.VERSION.SOCKS5, +                [socks.METHOD.NO_AUTHENTICATION_REQUIRED] +            ).to_file(p.wfile) +            socks.Message( +                socks.VERSION.SOCKS5, +                socks.CMD.BIND, +                socks.ATYP.DOMAINNAME, +                ("example.com", 8080) +            ).to_file(p.wfile) + +            p.wfile.flush() +            p.rfile.read(2)  # read server greeting +            f = p.request("get:/p/200")  # the request doesn't matter, error response from handshake will be read anyway.          assert f.status_code == 502          assert b"SOCKS5 mode failure" in f.content @@ -531,21 +546,23 @@ class TestHttps2Http(tservers.ReverseProxyTest):          p = pathoc.Pathoc(              ("localhost", self.proxy.port), ssl=True, sni=sni, fp=None          ) -        p.connect()          return p      def test_all(self):          p = self.pathoc(ssl=True) -        assert p.request("get:'/p/200'").status_code == 200 +        with p.connect(): +            assert p.request("get:'/p/200'").status_code == 200      def test_sni(self):          p = self.pathoc(ssl=True, sni="example.com") -        assert p.request("get:'/p/200'").status_code == 200 -        assert all("Error in handle_sni" not in msg for msg in self.proxy.tlog) +        with p.connect(): +            assert p.request("get:'/p/200'").status_code == 200 +            assert all("Error in handle_sni" not in msg for msg in self.proxy.tlog)      def test_http(self):          p = self.pathoc(ssl=False) -        assert p.request("get:'/p/200'").status_code == 200 +        with p.connect(): +            assert p.request("get:'/p/200'").status_code == 200  class TestTransparent(tservers.TransparentProxyTest, CommonMixin, TcpMixin): @@ -703,29 +720,29 @@ class TestRedirectRequest(tservers.HTTPProxyTest):          self.master.redirect_port = self.server2.port          p = self.pathoc() - -        self.server.clear_log() -        self.server2.clear_log() -        r1 = p.request("get:'/p/200'") -        assert r1.status_code == 200 -        assert self.server.last_log() -        assert not self.server2.last_log() - -        self.server.clear_log() -        self.server2.clear_log() -        r2 = p.request("get:'/p/201'") -        assert r2.status_code == 201 -        assert not self.server.last_log() -        assert self.server2.last_log() - -        self.server.clear_log() -        self.server2.clear_log() -        r3 = p.request("get:'/p/202'") -        assert r3.status_code == 202 -        assert self.server.last_log() -        assert not self.server2.last_log() - -        assert r1.content == r2.content == r3.content +        with p.connect(): +            self.server.clear_log() +            self.server2.clear_log() +            r1 = p.request("get:'/p/200'") +            assert r1.status_code == 200 +            assert self.server.last_log() +            assert not self.server2.last_log() + +            self.server.clear_log() +            self.server2.clear_log() +            r2 = p.request("get:'/p/201'") +            assert r2.status_code == 201 +            assert not self.server.last_log() +            assert self.server2.last_log() + +            self.server.clear_log() +            self.server2.clear_log() +            r3 = p.request("get:'/p/202'") +            assert r3.status_code == 202 +            assert self.server.last_log() +            assert not self.server2.last_log() + +            assert r1.content == r2.content == r3.content  class MasterStreamRequest(tservers.TestMaster): @@ -743,22 +760,22 @@ class TestStreamRequest(tservers.HTTPProxyTest):      def test_stream_simple(self):          p = self.pathoc() - -        # a request with 100k of data but without content-length -        r1 = p.request("get:'%s/p/200:r:b@100k:d102400'" % self.server.urlbase) -        assert r1.status_code == 200 -        assert len(r1.content) > 100000 +        with p.connect(): +            # a request with 100k of data but without content-length +            r1 = p.request("get:'%s/p/200:r:b@100k:d102400'" % self.server.urlbase) +            assert r1.status_code == 200 +            assert len(r1.content) > 100000      def test_stream_multiple(self):          p = self.pathoc() +        with p.connect(): +            # simple request with streaming turned on +            r1 = p.request("get:'%s/p/200'" % self.server.urlbase) +            assert r1.status_code == 200 -        # simple request with streaming turned on -        r1 = p.request("get:'%s/p/200'" % self.server.urlbase) -        assert r1.status_code == 200 - -        # now send back 100k of data, streamed but not chunked -        r1 = p.request("get:'%s/p/201:b@100k'" % self.server.urlbase) -        assert r1.status_code == 201 +            # now send back 100k of data, streamed but not chunked +            r1 = p.request("get:'%s/p/201:b@100k'" % self.server.urlbase) +            assert r1.status_code == 201      def test_stream_chunked(self):          connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -887,7 +904,8 @@ class TestUpstreamProxy(tservers.HTTPUpstreamProxyTest, CommonMixin, AppMixin):              ("~s", "baz", "ORLY")          ]          p = self.pathoc() -        req = p.request("get:'%s/p/418:b\"foo\"'" % self.server.urlbase) +        with p.connect(): +            req = p.request("get:'%s/p/418:b\"foo\"'" % self.server.urlbase)          assert req.content == b"ORLY"          assert req.status_code == 418 @@ -948,7 +966,8 @@ class TestUpstreamProxySSL(      def test_simple(self):          p = self.pathoc() -        req = p.request("get:'/p/418:b\"content\"'") +        with p.connect(): +            req = p.request("get:'/p/418:b\"content\"'")          assert req.content == b"content"          assert req.status_code == 418 @@ -1006,48 +1025,49 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxyTest):          ])          p = self.pathoc() -        req = p.request("get:'/p/418:b\"content\"'") -        assert req.content == b"content" -        assert req.status_code == 418 - -        assert self.proxy.tmaster.state.flow_count() == 2  # CONNECT and request -        # CONNECT, failing request, -        assert self.chain[0].tmaster.state.flow_count() == 4 -        # reCONNECT, request -        # failing request, request -        assert self.chain[1].tmaster.state.flow_count() == 2 -        # (doesn't store (repeated) CONNECTs from chain[0] -        #  as it is a regular proxy) - -        assert not self.chain[1].tmaster.state.flows[0].response  # killed -        assert self.chain[1].tmaster.state.flows[1].response - -        assert self.proxy.tmaster.state.flows[0].request.first_line_format == "authority" -        assert self.proxy.tmaster.state.flows[1].request.first_line_format == "relative" - -        assert self.chain[0].tmaster.state.flows[ -            0].request.first_line_format == "authority" -        assert self.chain[0].tmaster.state.flows[ -            1].request.first_line_format == "relative" -        assert self.chain[0].tmaster.state.flows[ -            2].request.first_line_format == "authority" -        assert self.chain[0].tmaster.state.flows[ -            3].request.first_line_format == "relative" - -        assert self.chain[1].tmaster.state.flows[ -            0].request.first_line_format == "relative" -        assert self.chain[1].tmaster.state.flows[ -            1].request.first_line_format == "relative" - -        req = p.request("get:'/p/418:b\"content2\"'") - -        assert req.status_code == 502 -        assert self.proxy.tmaster.state.flow_count() == 3  # + new request -        # + new request, repeated CONNECT from chain[1] -        assert self.chain[0].tmaster.state.flow_count() == 6 -        # (both terminated) -        # nothing happened here -        assert self.chain[1].tmaster.state.flow_count() == 2 +        with p.connect(): +            req = p.request("get:'/p/418:b\"content\"'") +            assert req.content == b"content" +            assert req.status_code == 418 + +            assert self.proxy.tmaster.state.flow_count() == 2  # CONNECT and request +            # CONNECT, failing request, +            assert self.chain[0].tmaster.state.flow_count() == 4 +            # reCONNECT, request +            # failing request, request +            assert self.chain[1].tmaster.state.flow_count() == 2 +            # (doesn't store (repeated) CONNECTs from chain[0] +            #  as it is a regular proxy) + +            assert not self.chain[1].tmaster.state.flows[0].response  # killed +            assert self.chain[1].tmaster.state.flows[1].response + +            assert self.proxy.tmaster.state.flows[0].request.first_line_format == "authority" +            assert self.proxy.tmaster.state.flows[1].request.first_line_format == "relative" + +            assert self.chain[0].tmaster.state.flows[ +                0].request.first_line_format == "authority" +            assert self.chain[0].tmaster.state.flows[ +                1].request.first_line_format == "relative" +            assert self.chain[0].tmaster.state.flows[ +                2].request.first_line_format == "authority" +            assert self.chain[0].tmaster.state.flows[ +                3].request.first_line_format == "relative" + +            assert self.chain[1].tmaster.state.flows[ +                0].request.first_line_format == "relative" +            assert self.chain[1].tmaster.state.flows[ +                1].request.first_line_format == "relative" + +            req = p.request("get:'/p/418:b\"content2\"'") + +            assert req.status_code == 502 +            assert self.proxy.tmaster.state.flow_count() == 3  # + new request +            # + new request, repeated CONNECT from chain[1] +            assert self.chain[0].tmaster.state.flow_count() == 6 +            # (both terminated) +            # nothing happened here +            assert self.chain[1].tmaster.state.flow_count() == 2  class AddUpstreamCertsToClientChainMixin: @@ -1066,12 +1086,13 @@ class AddUpstreamCertsToClientChainMixin:              d = f.read()          upstreamCert = SSLCert.from_pem(d)          p = self.pathoc() -        upstream_cert_found_in_client_chain = False -        for receivedCert in p.server_certs: -            if receivedCert.digest('sha256') == upstreamCert.digest('sha256'): -                upstream_cert_found_in_client_chain = True -                break -        assert(upstream_cert_found_in_client_chain == self.master.options.add_upstream_certs_to_client_chain) +        with p.connect(): +            upstream_cert_found_in_client_chain = False +            for receivedCert in p.server_certs: +                if receivedCert.digest('sha256') == upstreamCert.digest('sha256'): +                    upstream_cert_found_in_client_chain = True +                    break +            assert(upstream_cert_found_in_client_chain == self.master.options.add_upstream_certs_to_client_chain)  class TestHTTPSAddUpstreamCertsToClientChainTrue( diff --git a/test/mitmproxy/tservers.py b/test/mitmproxy/tservers.py index 1597f59c..4291f743 100644 --- a/test/mitmproxy/tservers.py +++ b/test/mitmproxy/tservers.py @@ -3,6 +3,7 @@ import threading  import tempfile  import flask  import mock +import sys  from mitmproxy.proxy.config import ProxyConfig  from mitmproxy.proxy.server import ProxyServer @@ -10,6 +11,7 @@ import pathod.test  import pathod.pathoc  from mitmproxy import flow, controller, options  from mitmproxy import builtins +import netlib.exceptions  testapp = flask.Flask(__name__) @@ -104,6 +106,14 @@ class ProxyTestBase(object):          cls.server.shutdown()          cls.server2.shutdown() +    def teardown(self): +        try: +            self.server.wait_for_silence() +        except netlib.exceptions.Timeout: +            # FIXME: Track down the Windows sync issues +            if sys.platform != "win32": +                raise +      def setup(self):          self.master.clear_log()          self.master.state.clear() @@ -125,6 +135,15 @@ class ProxyTestBase(object):          ) +class LazyPathoc(pathod.pathoc.Pathoc): +    def __init__(self, lazy_connect, *args, **kwargs): +        self.lazy_connect = lazy_connect +        pathod.pathoc.Pathoc.__init__(self, *args, **kwargs) + +    def connect(self): +        return pathod.pathoc.Pathoc.connect(self, self.lazy_connect) + +  class HTTPProxyTest(ProxyTestBase):      def pathoc_raw(self): @@ -134,14 +153,14 @@ class HTTPProxyTest(ProxyTestBase):          """              Returns a connected Pathoc instance.          """ -        p = pathod.pathoc.Pathoc( -            ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None -        )          if self.ssl: -            p.connect(("127.0.0.1", self.server.port)) +            conn = ("127.0.0.1", self.server.port)          else: -            p.connect() -        return p +            conn = None +        return LazyPathoc( +            conn, +            ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None +        )      def pathod(self, spec, sni=None):          """ @@ -152,18 +171,20 @@ class HTTPProxyTest(ProxyTestBase):              q = "get:'/p/%s'" % spec          else:              q = "get:'%s/p/%s'" % (self.server.urlbase, spec) -        return p.request(q) +        with p.connect(): +            return p.request(q)      def app(self, page):          if self.ssl:              p = pathod.pathoc.Pathoc(                  ("127.0.0.1", self.proxy.port), True, fp=None              ) -            p.connect((options.APP_HOST, options.APP_PORT)) -            return p.request("get:'%s'" % page) +            with p.connect((options.APP_HOST, options.APP_PORT)): +                return p.request("get:'%s'" % page)          else:              p = self.pathoc() -            return p.request("get:'http://%s%s'" % (options.APP_HOST, page)) +            with p.connect(): +                return p.request("get:'http://%s%s'" % (options.APP_HOST, page))  class TResolver: @@ -210,7 +231,8 @@ class TransparentProxyTest(ProxyTestBase):          else:              p = self.pathoc()              q = "get:'/p/%s'" % spec -        return p.request(q) +        with p.connect(): +            return p.request(q)      def pathoc(self, sni=None):          """ @@ -219,7 +241,6 @@ class TransparentProxyTest(ProxyTestBase):          p = pathod.pathoc.Pathoc(              ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None          ) -        p.connect()          return p @@ -247,7 +268,6 @@ class ReverseProxyTest(ProxyTestBase):          p = pathod.pathoc.Pathoc(              ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None          ) -        p.connect()          return p      def pathod(self, spec, sni=None): @@ -260,7 +280,8 @@ class ReverseProxyTest(ProxyTestBase):          else:              p = self.pathoc()              q = "get:'/p/%s'" % spec -        return p.request(q) +        with p.connect(): +            return p.request(q)  class SocksModeTest(HTTPProxyTest): | 
