From a9e6121a08c745961992c9fd2b4e4593063192f5 Mon Sep 17 00:00:00 2001
From: Maximilian Hils <git@maximilianhils.com>
Date: Fri, 8 Aug 2014 02:45:24 +0200
Subject: properly express state information on server connections, refs #315

---
 libmproxy/protocol/handle.py     |  6 ++++-
 libmproxy/protocol/http.py       | 50 ++++++++++++++++------------------------
 libmproxy/protocol/primitives.py |  8 +++++++
 libmproxy/proxy/connection.py    |  2 ++
 libmproxy/proxy/server.py        | 12 +++++++---
 5 files changed, 44 insertions(+), 34 deletions(-)

diff --git a/libmproxy/protocol/handle.py b/libmproxy/protocol/handle.py
index 42917ba1..a238b349 100644
--- a/libmproxy/protocol/handle.py
+++ b/libmproxy/protocol/handle.py
@@ -19,4 +19,8 @@ def handle_messages(conntype, connection_handler):
 
 
 def handle_error(conntype, connection_handler, error):
-    return _handler(conntype, connection_handler).handle_error(error)
\ No newline at end of file
+    return _handler(conntype, connection_handler).handle_error(error)
+
+
+def handle_server_reconnect(conntype, connection_handler, state):
+    return _handler(conntype, connection_handler).handle_server_reconnect(state)
\ No newline at end of file
diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py
index 8e470b01..e4f994c9 100644
--- a/libmproxy/protocol/http.py
+++ b/libmproxy/protocol/http.py
@@ -793,6 +793,7 @@ class HTTPFlow(Flow):
         """
         if isinstance(f, basestring):
             from .. import filt
+
             f = filt.parse(f)
             if not f:
                 raise ValueError("Invalid filter expression.")
@@ -974,6 +975,10 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):
 
             if flow.request.form_in == "authority" and flow.response.code == 200:
                 self.ssl_upgrade()
+                # TODO: Eventually add headers (space/usefulness tradeoff)
+                self.c.server_conn.state.append(("http", {"state": "connect",
+                                                          "host": flow.request.host,
+                                                          "port": flow.request.port}))
 
             # If the user has changed the target server on this connection,
             # restore the original target server
@@ -984,6 +989,21 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):
             self.handle_error(e, flow)
         return False
 
+    def handle_server_reconnect(self, state):
+        if state["state"] == "connect":
+            upstream_request = HTTPRequest("authority", "CONNECT", None, state["host"], state["port"], None,
+                                           (1, 1), ODictCaseless(), "")
+            self.c.server_conn.send(upstream_request._assemble())
+            resp = HTTPResponse.from_stream(self.c.server_conn.rfile, upstream_request.method)
+            if resp.code != 200:
+                raise proxy.ProxyError(resp.code,
+                                       "Cannot reestablish SSL " +
+                                       "connection with upstream proxy: \r\n" +
+                                       str(resp._assemble()))
+            return
+        else:  # pragma: nocover
+            raise RuntimeError("Unknown State: %s" % state["state"])
+
     def handle_error(self, error, flow=None):
 
         message = repr(error)
@@ -1024,35 +1044,6 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):
         self.c.client_conn.wfile.write(html_content)
         self.c.client_conn.wfile.flush()
 
-    def hook_reconnect(self, upstream_request):
-        """
-        If the authority request has been forwarded upstream (because we have another proxy server there),
-        money-patch the ConnectionHandler.server_reconnect function to resend the CONNECT request on reconnect.
-        Hooking code isn't particulary beautiful, but it isolates this edge-case from
-        the protocol-agnostic ConnectionHandler
-        """
-        self.c.log("Hook reconnect function", level="debug")
-        original_reconnect_func = self.c.server_reconnect
-
-        def reconnect_http_proxy():
-            self.c.log("Hooked reconnect function", "debug")
-            self.c.log("Hook: Run original reconnect", "debug")
-            original_reconnect_func(no_ssl=True)
-            self.c.log("Hook: Write CONNECT request to upstream proxy", "debug",
-                       [upstream_request._assemble_first_line()])
-            self.c.server_conn.send(upstream_request._assemble())
-            self.c.log("Hook: Read answer to CONNECT request from proxy", "debug")
-            resp = HTTPResponse.from_stream(self.c.server_conn.rfile, upstream_request.method)
-            if resp.code != 200:
-                raise proxy.ProxyError(resp.code,
-                                       "Cannot reestablish SSL " +
-                                       "connection with upstream proxy: \r\n" +
-                                       str(resp.headers))
-            self.c.log("Hook: Establish SSL with upstream proxy", "debug")
-            self.c.establish_ssl(server=True)
-
-        self.c.server_reconnect = reconnect_http_proxy
-
     def ssl_upgrade(self):
         """
         Upgrade the connection to SSL after an authority (CONNECT) request has been made.
@@ -1089,7 +1080,6 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):
                     self.skip_authentication = True
                     return False
                 else:
-                    self.hook_reconnect(request)
                     return True
         elif request.form_in == self.expected_form_in:
             if request.form_in == "absolute":
diff --git a/libmproxy/protocol/primitives.py b/libmproxy/protocol/primitives.py
index f3ecdab7..7b936f7f 100644
--- a/libmproxy/protocol/primitives.py
+++ b/libmproxy/protocol/primitives.py
@@ -148,6 +148,14 @@ class ProtocolHandler(object):
         """
         raise NotImplementedError  # pragma: nocover
 
+    def handle_server_reconnect(self, state):
+        """
+        This method gets called if a server connection needs to reconnect and there's a state associated
+        with the server connection (e.g. a previously-sent CONNECT request or a SOCKS proxy request).
+        This method gets called after the connection has been restablished but before SSL is established.
+        """
+        raise NotImplementedError  # pragma: nocover
+
     def handle_error(self, error):
         """
         This method gets called should there be an uncaught exception during the connection.
diff --git a/libmproxy/proxy/connection.py b/libmproxy/proxy/connection.py
index 372bee2e..38436233 100644
--- a/libmproxy/proxy/connection.py
+++ b/libmproxy/proxy/connection.py
@@ -68,6 +68,7 @@ class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject):
         tcp.TCPClient.__init__(self, address)
         self.priority = priority
 
+        self.state = []  # a list containing (conntype, state) tuples
         self.peername = None
         self.sockname = None
         self.timestamp_start = None
@@ -76,6 +77,7 @@ class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject):
         self.timestamp_ssl_setup = None
 
     _stateobject_attributes = dict(
+        state=list,
         peername=tuple,
         sockname=tuple,
         timestamp_start=float,
diff --git a/libmproxy/proxy/server.py b/libmproxy/proxy/server.py
index c817b3b5..c3f1e048 100644
--- a/libmproxy/proxy/server.py
+++ b/libmproxy/proxy/server.py
@@ -7,7 +7,7 @@ from netlib import tcp
 from .primitives import ProxyServerError, Log, ProxyError, ConnectionTypeChange, \
     AddressPriority
 from .connection import ClientConnection, ServerConnection
-from ..protocol.handle import handle_messages, handle_error
+from ..protocol.handle import handle_messages, handle_error, handle_server_reconnect
 from .. import version
 
 
@@ -207,16 +207,22 @@ class ConnectionHandler:
                 ca_file=self.config.ca_file
             )
 
-    def server_reconnect(self, no_ssl=False):
+    def server_reconnect(self):
         address = self.server_conn.address
         had_ssl = self.server_conn.ssl_established
         priority = self.server_conn.priority
+        state = self.server_conn.state
         sni = self.sni
         self.log("(server reconnect follows)", "debug")
         self.del_server_connection()
         self.set_server_address(address, priority)
         self.establish_server_connection()
-        if had_ssl and not no_ssl:
+
+        for s in state:
+            handle_server_reconnect(s[0], self, s[1])
+        self.server_conn.state = state
+
+        if had_ssl:
             self.sni = sni
             self.establish_ssl(server=True)
 
-- 
cgit v1.2.3