diff options
| -rw-r--r-- | examples/stub.py | 13 | ||||
| -rw-r--r-- | libmproxy/flow.py | 4 | ||||
| -rw-r--r-- | libmproxy/protocol/base.py | 10 | ||||
| -rw-r--r-- | libmproxy/protocol/http.py | 6 | ||||
| -rw-r--r-- | libmproxy/protocol/http_replay.py | 4 | ||||
| -rw-r--r-- | libmproxy/proxy/server.py | 5 | ||||
| -rw-r--r-- | test/test_proxy.py | 7 | ||||
| -rw-r--r-- | test/test_server.py | 4 | 
8 files changed, 38 insertions, 15 deletions
diff --git a/examples/stub.py b/examples/stub.py index d5502a47..bd3e7cd0 100644 --- a/examples/stub.py +++ b/examples/stub.py @@ -10,7 +10,7 @@ def start(context, argv):      context.log("start") -def clientconnect(context, conn_handler): +def clientconnect(context, root_layer):      """          Called when a client initiates a connection to the proxy. Note that a          connection can correspond to multiple HTTP requests @@ -18,7 +18,7 @@ def clientconnect(context, conn_handler):      context.log("clientconnect") -def serverconnect(context, conn_handler): +def serverconnect(context, server_connection):      """          Called when the proxy initiates a connection to the target server. Note that a          connection can correspond to multiple HTTP requests @@ -58,7 +58,14 @@ def error(context, flow):      context.log("error") -def clientdisconnect(context, conn_handler): +def serverdisconnect(context, server_connection): +    """ +        Called when the proxy closes the connection to the target server. +    """ +    context.log("serverdisconnect") + + +def clientdisconnect(context, root_layer):      """          Called when a client disconnects from the proxy.      """ diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 00ec83d2..5eac8da9 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -961,6 +961,10 @@ class FlowMaster(controller.Master):          self.run_script_hook("serverconnect", sc)          sc.reply() +    def handle_serverdisconnect(self, sc): +        self.run_script_hook("serverdisconnect", sc) +        sc.reply() +      def handle_error(self, f):          self.state.update_flow(f)          self.run_script_hook("error", f) diff --git a/libmproxy/protocol/base.py b/libmproxy/protocol/base.py index 4eb843e4..40ec0536 100644 --- a/libmproxy/protocol/base.py +++ b/libmproxy/protocol/base.py @@ -48,9 +48,11 @@ class _LayerCodeCompletion(object):          if True:              return          self.config = None -        """@type: libmproxy.proxy.config.ProxyConfig""" +        """@type: libmproxy.proxy.ProxyConfig"""          self.client_conn = None -        """@type: libmproxy.proxy.connection.ClientConnection""" +        """@type: libmproxy.models.ClientConnection""" +        self.server_conn = None +        """@type: libmproxy.models.ServerConnection"""          self.channel = None          """@type: libmproxy.controller.Channel""" @@ -62,6 +64,7 @@ class Layer(_LayerCodeCompletion):              ctx: The (read-only) higher layer.          """          self.ctx = ctx +        """@type: libmproxy.protocol.Layer"""          super(Layer, self).__init__(*args, **kwargs)      def __call__(self): @@ -149,13 +152,14 @@ class ServerConnectionMixin(object):          self.log("serverdisconnect", "debug", [repr(self.server_conn.address)])          self.server_conn.finish()          self.server_conn.close() -        # self.channel.tell("serverdisconnect", self) +        self.channel.tell("serverdisconnect", self.server_conn)          self.server_conn = ServerConnection(None)      def connect(self):          if not self.server_conn.address:              raise ProtocolException("Cannot connect to server, no server address given.")          self.log("serverconnect", "debug", [repr(self.server_conn.address)]) +        self.channel.ask("serverconnect", self.server_conn)          try:              self.server_conn.connect()          except tcp.NetLibError as e: diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 3b62c389..f0f4ac24 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -418,7 +418,7 @@ class HttpLayer(Layer):          # call the appropriate script hook - this is an opportunity for an          # inline script to set flow.stream = True          flow = self.channel.ask("responseheaders", flow) -        if flow is None or flow == Kill: +        if flow == Kill:              raise Kill()          if self.supports_streaming: @@ -442,7 +442,7 @@ class HttpLayer(Layer):              [repr(flow.response)]          )          response_reply = self.channel.ask("response", flow) -        if response_reply is None or response_reply == Kill: +        if response_reply == Kill:              raise Kill()      def process_request_hook(self, flow): @@ -462,7 +462,7 @@ class HttpLayer(Layer):              flow.request.scheme = "https" if self.__original_server_conn.tls_established else "http"          request_reply = self.channel.ask("request", flow) -        if request_reply is None or request_reply == Kill: +        if request_reply == Kill:              raise Kill()          if isinstance(request_reply, HTTPResponse):              flow.response = request_reply diff --git a/libmproxy/protocol/http_replay.py b/libmproxy/protocol/http_replay.py index c37fd131..2759a019 100644 --- a/libmproxy/protocol/http_replay.py +++ b/libmproxy/protocol/http_replay.py @@ -36,7 +36,7 @@ class RequestReplayThread(threading.Thread):              # If we have a channel, run script hooks.              if self.channel:                  request_reply = self.channel.ask("request", self.flow) -                if request_reply is None or request_reply == Kill: +                if request_reply == Kill:                      raise Kill()                  elif isinstance(request_reply, HTTPResponse):                      self.flow.response = request_reply @@ -82,7 +82,7 @@ class RequestReplayThread(threading.Thread):                  )              if self.channel:                  response_reply = self.channel.ask("response", self.flow) -                if response_reply is None or response_reply == Kill: +                if response_reply == Kill:                      raise Kill()          except (HttpError, NetLibError) as v:              self.flow.error = Error(repr(v)) diff --git a/libmproxy/proxy/server.py b/libmproxy/proxy/server.py index b565ef86..e9e8df09 100644 --- a/libmproxy/proxy/server.py +++ b/libmproxy/proxy/server.py @@ -106,6 +106,10 @@ class ConnectionHandler(object):          self.log("clientconnect", "info")          root_layer = self._create_root_layer() +        root_layer = self.channel.ask("clientconnect", root_layer) +        if root_layer == Kill: +            def root_layer(): +                raise Kill()          try:              root_layer() @@ -128,6 +132,7 @@ class ConnectionHandler(object):              print("Please lodge a bug report at: https://github.com/mitmproxy/mitmproxy", file=sys.stderr)          self.log("clientdisconnect", "info") +        self.channel.tell("clientdisconnect", root_layer)          self.client_conn.finish()      def log(self, msg, level): diff --git a/test/test_proxy.py b/test/test_proxy.py index b9ca2cce..cc6a79d0 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -172,11 +172,16 @@ class TestConnectionHandler:          root_layer = mock.Mock()          root_layer.side_effect = RuntimeError          config.mode.return_value = root_layer +        channel = mock.Mock() + +        def ask(_, x): +            return x +        channel.ask = ask          c = ConnectionHandler(              mock.MagicMock(),              ("127.0.0.1", 8080),              config, -            mock.MagicMock() +            channel          )          with tutils.capture_stderr(c.handle) as output:              assert "mitmproxy has crashed" in output diff --git a/test/test_server.py b/test/test_server.py index 23d802ca..a1259b7f 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -611,13 +611,11 @@ class MasterRedirectRequest(tservers.TestMaster):      def handle_request(self, f):          if f.request.path == "/p/201": -            # This part should have no impact, but it should not cause any exceptions. +            # This part should have no impact, but it should also not cause any exceptions.              addr = f.live.server_conn.address              addr2 = Address(("127.0.0.1", self.redirect_port))              f.live.set_server(addr2) -            f.live.connect()              f.live.set_server(addr) -            f.live.connect()              # This is the actual redirection.              f.request.port = self.redirect_port  | 
