diff options
| -rw-r--r-- | libmproxy/protocol/__init__.py | 46 | ||||
| -rw-r--r-- | libmproxy/protocol/http.py | 1 | ||||
| -rw-r--r-- | libmproxy/proxy.py | 33 | 
3 files changed, 55 insertions, 25 deletions
| diff --git a/libmproxy/protocol/__init__.py b/libmproxy/protocol/__init__.py index 78930e05..123c31e0 100644 --- a/libmproxy/protocol/__init__.py +++ b/libmproxy/protocol/__init__.py @@ -1,5 +1,6 @@ -KILL = 0  # const for killed requests +from ..proxy import ServerConnection, AddressPriority +KILL = 0  # const for killed requests  class ConnectionTypeChange(Exception):      """ @@ -12,7 +13,7 @@ class ConnectionTypeChange(Exception):  class ProtocolHandler(object):      def __init__(self, c):          self.c = c -        """@type : libmproxy.proxy.ConnectionHandler""" +        """@type: libmproxy.proxy.ConnectionHandler"""      def handle_messages(self):          """ @@ -36,13 +37,46 @@ class TemporaryServerChangeMixin(object):      """      def change_server(self, address, ssl): -        self._backup_server = True -        raise NotImplementedError("You must not change host port port.") +        if address == self.c.server_conn.address(): +            return +        priority = AddressPriority.MANUALLY_CHANGED + +        if self.c.server_conn.priority > priority: +            self.log("Attempt to change server address, " +                     "but priority is too low (is: %s, got: %s)" % (self.server_conn.priority, priority)) +            return + +        self.log("Temporarily change server connection: %s:%s -> %s:%s" % ( +            self.c.server_conn.address.host, +            self.c.server_conn.address.port, +            address.host, +            address.port +        )) + +        if not hasattr(self, "_backup_server_conn"): +            self._backup_server_conn = self.c.server_conn +            self.c.server_conn = None +        else:  # This is at least the second temporary change. We can kill the current connection. +            self.c.del_server_connection() + +        self.c.set_server_address(address, priority) +        if ssl: +            self.establish_ssl(server=True)      def restore_server(self): -        if not hasattr(self,"_backup_server"): +        if not hasattr(self, "_backup_server_conn"):              return -        raise NotImplementedError + +        self.log("Restore original server connection: %s:%s -> %s:%s" % ( +            self.c.server_conn.address.host, +            self.c.server_conn.address.port, +            self._backup_server_conn.host, +            self._backup_server_conn.port +        )) + +        self.c.del_server_connection() +        self.c.server_conn = self._backup_server_conn +        del self._backup_server_conn  from . import http, tcp diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 59bd8900..2a9f9afe 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -1001,7 +1001,6 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):              if directly_addressed_at_mitmproxy:                  self.c.set_server_address((request.host, request.port), AddressPriority.FROM_PROTOCOL)                  request.flow.server_conn = self.c.server_conn  # Update server_conn attribute on the flow -                self.c.establish_server_connection()                  self.c.client_conn.wfile.write(                      'HTTP/1.1 200 Connection established\r\n' +                      ('Proxy-agent: %s\r\n' % self.c.server_version) + diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 6ff02a36..b650181f 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -215,7 +215,6 @@ class ConnectionHandler:          self.close = False          self.conntype = None          self.sni = None -        self.server_address_priority = None          self.mode = "regular"          if self.config.reverse_proxy: @@ -300,7 +299,6 @@ class ConnectionHandler:              self.log("serverdisconnect", ["%s:%s" % (self.server_conn.address.host, self.server_conn.address.port)])              self.channel.tell("serverdisconnect", self)          self.server_conn = None -        self.server_address_priority = None          self.sni = None      def determine_conntype(self): @@ -309,27 +307,25 @@ class ConnectionHandler:      def set_server_address(self, address, priority):          """ -        Sets a new server address with the given priority +        Sets a new server address with the given priority. +        Does not re-establish either connection or SSL handshake.          @type priority: AddressPriority          """          address = tcp.Address.wrap(address) -        self.log("Try to set server address: %s:%s" % (address.host, address.port)) -        if self.server_conn and (self.server_conn.priority > priority): -            self.log("Server address priority too low (is: %s, got: %s)" % (self.server_address_priority, priority)) -            return -        if self.server_conn and (self.server_conn.address == address): -            self.server_conn.priority = priority  # Possibly increase priority -            self.log("Addresses match, skip.") -            return +        if self.server_conn: +            if self.server_conn.priority > priority: +                self.log("Attempt to change server address, " +                         "but priority is too low (is: %s, got: %s)" % (self.server_conn.priority, priority)) +                return +            if self.server_conn.address == address: +                self.server_conn.priority = priority  # Possibly increase priority +                return -        server_conn = ServerConnection(address, priority) -        if self.server_conn and self.server_conn.connection:              self.del_server_connection() -            self.server_conn = server_conn -            self.establish_server_connection() -        else: -            self.server_conn = server_conn + +        self.log("Set new server address: %s:%s" % (address.host, address.port)) +        self.server_conn = ServerConnection(address, priority)      def establish_server_connection(self):          """ @@ -373,6 +369,7 @@ class ConnectionHandler:          if server:              if self.server_conn.ssl_established:                  raise ProxyError(502, "SSL to Server already established.") +            self.establish_server_connection()  # make sure there is a server connection.              self.server_conn.establish_ssl(self.config.clientcerts, self.sni)          if client:              if self.client_conn.ssl_established: @@ -384,7 +381,7 @@ class ConnectionHandler:      def server_reconnect(self, no_ssl=False):          address = self.server_conn.address          had_ssl = self.server_conn.ssl_established -        priority = self.server_address_priority +        priority = self.server_conn.priority          sni = self.sni          self.log("(server reconnect follows)")          self.del_server_connection() | 
