aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAldo Cortesi <aldo@nullcube.com>2013-02-24 17:35:24 +1300
committerAldo Cortesi <aldo@nullcube.com>2013-02-24 17:35:24 +1300
commit705559d65e5dc5883395efb85bacbf1459eb243c (patch)
tree2fb5143c73753b8541a47fb8ca20d10e96438bf6
parentd0639e8925541bd6f6f386386c982d23b3828d3d (diff)
downloadmitmproxy-705559d65e5dc5883395efb85bacbf1459eb243c.tar.gz
mitmproxy-705559d65e5dc5883395efb85bacbf1459eb243c.tar.bz2
mitmproxy-705559d65e5dc5883395efb85bacbf1459eb243c.zip
Refactor to prepare for SNI fixes.
-rw-r--r--libmproxy/proxy.py99
-rw-r--r--test/test_proxy.py12
2 files changed, 55 insertions, 56 deletions
diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py
index 088fe94c..d92e2da9 100644
--- a/libmproxy/proxy.py
+++ b/libmproxy/proxy.py
@@ -50,36 +50,13 @@ class ProxyConfig:
self.certstore = certutils.CertStore(certdir)
-class RequestReplayThread(threading.Thread):
- def __init__(self, config, flow, masterq):
- self.config, self.flow, self.channel = config, flow, controller.Channel(masterq)
- threading.Thread.__init__(self)
-
- def run(self):
- try:
- r = self.flow.request
- server = ServerConnection(self.config, r.host, r.port)
- server.connect(r.scheme)
- server.send(r)
- httpversion, code, msg, headers, content = http.read_response(
- server.rfile, r.method, self.config.body_size_limit
- )
- response = flow.Response(
- self.flow.request, httpversion, code, msg, headers, content, server.cert
- )
- self.channel.ask(response)
- except (ProxyError, http.HttpError, tcp.NetLibError), v:
- err = flow.Error(self.flow.request, str(v))
- self.channel.ask(err)
-
-
class ServerConnection(tcp.TCPClient):
def __init__(self, config, host, port):
tcp.TCPClient.__init__(self, host, port)
self.config = config
self.requestcount = 0
- def connect(self, scheme):
+ def connect(self, scheme, sni):
tcp.TCPClient.connect(self)
if scheme == "https":
clientcert = None
@@ -88,7 +65,7 @@ class ServerConnection(tcp.TCPClient):
if os.path.exists(path):
clientcert = path
try:
- self.convert_to_ssl(clientcert=clientcert, sni=self.host)
+ self.convert_to_ssl(cert=clientcert, sni=sni)
except tcp.NetLibError, v:
raise ProxyError(400, str(v))
@@ -109,12 +86,35 @@ class ServerConnection(tcp.TCPClient):
pass
+class RequestReplayThread(threading.Thread):
+ def __init__(self, config, flow, masterq):
+ self.config, self.flow, self.channel = config, flow, controller.Channel(masterq)
+ threading.Thread.__init__(self)
+
+ def run(self):
+ try:
+ r = self.flow.request
+ server = ServerConnection(self.config, r.host, r.port)
+ server.connect(r.scheme, r.host)
+ server.send(r)
+ httpversion, code, msg, headers, content = http.read_response(
+ server.rfile, r.method, self.config.body_size_limit
+ )
+ response = flow.Response(
+ self.flow.request, httpversion, code, msg, headers, content, server.cert
+ )
+ self.channel.ask(response)
+ except (ProxyError, http.HttpError, tcp.NetLibError), v:
+ err = flow.Error(self.flow.request, str(v))
+ self.channel.ask(err)
+
+
class ServerConnectionPool:
def __init__(self, config):
self.config = config
self.conn = None
- def get_connection(self, scheme, host, port):
+ def get_connection(self, scheme, host, port, sni):
sc = self.conn
if self.conn and (host, port) != (sc.host, sc.port):
sc.terminate()
@@ -122,7 +122,7 @@ class ServerConnectionPool:
if not self.conn:
try:
self.conn = ServerConnection(self.config, host, port)
- self.conn.connect(scheme)
+ self.conn.connect(scheme, sni)
except tcp.NetLibError, v:
raise ProxyError(502, v)
return self.conn
@@ -190,18 +190,18 @@ class ProxyHandler(tcp.BaseHandler):
# the case, we want to reconnect without sending an error
# to the client.
while 1:
+ sc = self.server_conn_pool.get_connection(scheme, host, port, host)
+ sc.send(request)
+ sc.rfile.reset_timestamps()
try:
- sc = self.server_conn_pool.get_connection(scheme, host, port)
- sc.send(request)
- sc.rfile.reset_timestamps()
httpversion, code, msg, headers, content = http.read_response(
sc.rfile,
request.method,
self.config.body_size_limit
)
except http.HttpErrorConnClosed, v:
+ self.server_conn_pool.del_connection(scheme, host, port)
if sc.requestcount > 1:
- self.server_conn_pool.del_connection(scheme, host, port)
continue
else:
raise
@@ -324,25 +324,6 @@ class ProxyHandler(tcp.BaseHandler):
self.rfile.first_byte_timestamp, utils.timestamp()
)
- def read_request_reverse(self, client_conn):
- line = self.get_line(self.rfile)
- if line == "":
- return None
- scheme, host, port = self.config.reverse_proxy
- r = http.parse_init_http(line)
- if not r:
- raise ProxyError(400, "Bad HTTP request line: %s"%repr(line))
- method, path, httpversion = r
- headers = self.read_headers(authenticate=False)
- content = http.read_http_body_request(
- self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit
- )
- return flow.Request(
- client_conn, httpversion, host, port, "http", method, path, headers, content,
- self.rfile.first_byte_timestamp, utils.timestamp()
- )
-
-
def read_request_proxy(self, client_conn):
line = self.get_line(self.rfile)
if line == "":
@@ -398,6 +379,24 @@ class ProxyHandler(tcp.BaseHandler):
self.rfile.first_byte_timestamp, utils.timestamp()
)
+ def read_request_reverse(self, client_conn):
+ line = self.get_line(self.rfile)
+ if line == "":
+ return None
+ scheme, host, port = self.config.reverse_proxy
+ r = http.parse_init_http(line)
+ if not r:
+ raise ProxyError(400, "Bad HTTP request line: %s"%repr(line))
+ method, path, httpversion = r
+ headers = self.read_headers(authenticate=False)
+ content = http.read_http_body_request(
+ self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit
+ )
+ return flow.Request(
+ client_conn, httpversion, host, port, "http", method, path, headers, content,
+ self.rfile.first_byte_timestamp, utils.timestamp()
+ )
+
def read_request(self, client_conn):
self.rfile.reset_timestamps()
if self.config.transparent_proxy:
diff --git a/test/test_proxy.py b/test/test_proxy.py
index bdac8697..b575a1d0 100644
--- a/test/test_proxy.py
+++ b/test/test_proxy.py
@@ -40,7 +40,7 @@ class TestServerConnection:
def test_simple(self):
sc = proxy.ServerConnection(proxy.ProxyConfig(), self.d.IFACE, self.d.port)
- sc.connect("http")
+ sc.connect("http", "host.com")
r = tutils.treq()
r.path = "/p/200:da"
sc.send(r)
@@ -54,7 +54,7 @@ class TestServerConnection:
def test_terminate_error(self):
sc = proxy.ServerConnection(proxy.ProxyConfig(), self.d.IFACE, self.d.port)
- sc.connect("http")
+ sc.connect("http", "host.com")
sc.connection = mock.Mock()
sc.connection.close = mock.Mock(side_effect=IOError)
sc.terminate()
@@ -75,14 +75,14 @@ class TestServerConnectionPool:
@mock.patch("libmproxy.proxy.ServerConnection", _dummysc)
def test_pooling(self):
p = proxy.ServerConnectionPool(proxy.ProxyConfig())
- c = p.get_connection("http", "localhost", 80)
- c2 = p.get_connection("http", "localhost", 80)
+ c = p.get_connection("http", "localhost", 80, "localhost")
+ c2 = p.get_connection("http", "localhost", 80, "localhost")
assert c is c2
- c3 = p.get_connection("http", "foo", 80)
+ c3 = p.get_connection("http", "foo", 80, "localhost")
assert not c is c3
@mock.patch("libmproxy.proxy.ServerConnection", _errsc)
def test_connection_error(self):
p = proxy.ServerConnectionPool(proxy.ProxyConfig())
- tutils.raises("502", p.get_connection, "http", "localhost", 80)
+ tutils.raises("502", p.get_connection, "http", "localhost", 80, "localhost")