From 2aa175a6ca657db0b4faa2aeb84a78b7ef3c4761 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 29 Jan 2013 10:55:19 +1300 Subject: Stub implementation of a server connection pool. --- libmproxy/proxy.py | 48 +++++++++++++++++++++++++++--------------------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index f14e4e3e..3bbb82ba 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -107,12 +107,30 @@ class ServerConnection(tcp.TCPClient): except IOError: pass +class ServerConnectionPool: + def __init__(self, config): + self.config = config + self.conn = None + + def get_connection(self, scheme, host, port): + sc = self.conn + if self.conn and (host, port) != (sc.host, sc.port): + sc.terminate() + self.conn = None + if not self.conn: + try: + self.conn = ServerConnection(self.config, host, port) + self.conn.connect(scheme) + except tcp.NetLibError, v: + raise ProxyError(502, v) + return self.conn + class ProxyHandler(tcp.BaseHandler): def __init__(self, config, connection, client_address, server, mqueue, server_version): self.mqueue, self.server_version = mqueue, server_version self.config = config - self.server_conn = None + self.server_conn_pool = ServerConnectionPool(config) self.proxy_connect_state = None self.sni = None tcp.BaseHandler.__init__(self, connection, client_address, server) @@ -133,18 +151,6 @@ class ProxyHandler(tcp.BaseHandler): ) cd._send(self.mqueue) - def server_connect(self, scheme, host, port): - sc = self.server_conn - if sc and (host, port) != (sc.host, sc.port): - sc.terminate() - self.server_conn = None - if not self.server_conn: - try: - self.server_conn = ServerConnection(self.config, host, port) - self.server_conn.connect(scheme) - except tcp.NetLibError, v: - raise ProxyError(502, v) - def handle_request(self, cc): try: request, err = None, None @@ -173,21 +179,21 @@ class ProxyHandler(tcp.BaseHandler): scheme, host, port = self.config.reverse_proxy else: scheme, host, port = request.scheme, request.host, request.port - self.server_connect(scheme, host, port) - self.server_conn.send(request) - self.server_conn.rfile.reset_timestamps() + sc = self.server_conn_pool.get_connection(scheme, host, port) + sc.send(request) + sc.rfile.reset_timestamps() httpversion, code, msg, headers, content = http.read_response( - self.server_conn.rfile, + sc.rfile, request.method, self.config.body_size_limit ) response = flow.Response( - request, httpversion, code, msg, headers, content, self.server_conn.cert, self.server_conn.rfile.first_byte_timestamp, utils.timestamp() + request, httpversion, code, msg, headers, content, sc.cert, + sc.rfile.first_byte_timestamp, utils.timestamp() ) - response = response._send(self.mqueue) if response is None: - self.server_conn.terminate() + sc.terminate() if response is None: return self.send_response(response) @@ -310,7 +316,7 @@ class ProxyHandler(tcp.BaseHandler): self.rfile.first_byte_timestamp, utils.timestamp() ) - + def read_request_proxy(self, client_conn): line = self.get_line(self.rfile) if line == "": -- cgit v1.2.3 From 782bbee8c0a7d14be25e87d61c9c6db03b532eb7 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 29 Jan 2013 11:35:57 +1300 Subject: Unit tests for ServerConnectionPool --- libmproxy/proxy.py | 1 + test/test_proxy.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 3bbb82ba..f2c6db1f 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -107,6 +107,7 @@ class ServerConnection(tcp.TCPClient): except IOError: pass + class ServerConnectionPool: def __init__(self, config): self.config = config diff --git a/test/test_proxy.py b/test/test_proxy.py index c73f61d8..0edb2fee 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -1,7 +1,7 @@ from libmproxy import proxy, flow import tutils from libpathod import test -from netlib import http +from netlib import http, tcp import mock @@ -58,3 +58,31 @@ class TestServerConnection: sc.connection = mock.Mock() sc.connection.close = mock.Mock(side_effect=IOError) sc.terminate() + + + +def _dummysc(config, host, port): + return mock.MagicMock(config=config, host=host, port=port) + + +def _errsc(config, host, port): + m = mock.MagicMock(config=config, host=host, port=port) + m.connect = mock.MagicMock(side_effect=tcp.NetLibError()) + return m + + +class TestServerConnectionPool: + @mock.patch("libmproxy.proxy.ServerConnection", _dummysc) + def test_pooling(self): + p = proxy.ServerConnectionPool(proxy.ProxyConfig()) + c = p.get_connection("http", "localhost", 80) + c2 = p.get_connection("http", "localhost", 80) + assert c is c2 + c3 = p.get_connection("http", "foo", 80) + assert not c is c3 + + @mock.patch("libmproxy.proxy.ServerConnection", _errsc) + def test_connection_error(self): + p = proxy.ServerConnectionPool(proxy.ProxyConfig()) + tutils.raises("502", p.get_connection, "http", "localhost", 80) + -- cgit v1.2.3 From 1ccb2c5dea9530682aae83d489f1738d9286fa4e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 16 Feb 2013 16:46:16 +1300 Subject: Test WSGI app calling. - Factor out test servers into a separate file - Adjust docs to note new Flask dependency --- README.mkd | 26 ++++---- test/test_proxy.py | 2 +- test/test_server.py | 24 +++++-- test/tservers.py | 183 ++++++++++++++++++++++++++++++++++++++++++++++++++++ test/tutils.py | 167 +---------------------------------------------- 5 files changed, 215 insertions(+), 187 deletions(-) create mode 100644 test/tservers.py diff --git a/README.mkd b/README.mkd index b7d5f4ee..4179ce2b 100644 --- a/README.mkd +++ b/README.mkd @@ -50,28 +50,24 @@ Requirements ------------ * [Python](http://www.python.org) 2.7.x. +* [netlib](http://pypi.python.org/pypi/netlib) 0.2.2 or newer. * [PyOpenSSL](http://pypi.python.org/pypi/pyOpenSSL) 0.13 or newer. * [pyasn1](http://pypi.python.org/pypi/pyasn1) 0.1.2 or newer. * [urwid](http://excess.org/urwid/) version 1.1 or newer. * [PIL](http://www.pythonware.com/products/pil/) version 1.1 or newer. * [lxml](http://lxml.de/) version 2.3 or newer. -* [netlib](http://pypi.python.org/pypi/netlib) 0.2.2 or newer. -The following auxiliary components may be needed if you plan to hack on -mitmproxy: +__mitmproxy__ is tested and developed on OSX, Linux and OpenBSD. Windows is not +officially supported at the moment. -* The test suite uses the [nose](http://readthedocs.org/docs/nose/en/latest/) unit testing - framework and requires [human_curl](https://github.com/Lispython/human_curl) and - [pathod](http://pathod.org). -* Rendering the documentation requires [countershape](http://github.com/cortesi/countershape). -__mitmproxy__ is tested and developed on OSX, Linux and OpenBSD. Windows is not -supported at the moment. +Hacking +------- + +The following components are needed if you plan to hack on mitmproxy: -You should also make sure that your console environment is set up with the -following: +* The test suite uses the [nose](http://readthedocs.org/docs/nose/en/latest/) unit testing + framework and requires [human_curl](https://github.com/Lispython/human_curl), + [pathod](http://pathod.org) and [flask](http://flask.pocoo.org/). +* Rendering the documentation requires [countershape](http://github.com/cortesi/countershape). -* EDITOR environment variable to determine the external editor. -* PAGER environment variable to determine the external pager. -* Appropriate entries in your mailcap files to determine external - viewers for request and response contents. diff --git a/test/test_proxy.py b/test/test_proxy.py index 0edb2fee..bdac8697 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -67,7 +67,7 @@ def _dummysc(config, host, port): def _errsc(config, host, port): m = mock.MagicMock(config=config, host=host, port=port) - m.connect = mock.MagicMock(side_effect=tcp.NetLibError()) + m.connect = mock.MagicMock(side_effect=tcp.NetLibError()) return m diff --git a/test/test_server.py b/test/test_server.py index 0a2f142e..58dfee58 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -1,7 +1,7 @@ import socket, time from netlib import tcp from libpathod import pathoc -import tutils +import tutils, tservers """ Note that the choice of response code in these tests matters more than you @@ -39,7 +39,19 @@ class SanityMixin: assert l.error -class TestHTTP(tutils.HTTPProxTest, SanityMixin): +class TestHTTP(tservers.HTTPProxTest, SanityMixin): + def test_app(self): + p = self.pathoc() + ret = p.request("get:'http://testapp/'") + assert ret[1] == 200 + assert ret[4] == "testapp" + + def test_app_err(self): + p = self.pathoc() + ret = p.request("get:'http://errapp/'") + assert ret[1] == 500 + assert "ValueError" in ret[4] + def test_invalid_http(self): t = tcp.TCPClient("127.0.0.1", self.proxy.port) t.connect() @@ -69,7 +81,7 @@ class TestHTTP(tutils.HTTPProxTest, SanityMixin): assert l.response.code == 304 -class TestHTTPS(tutils.HTTPProxTest, SanityMixin): +class TestHTTPS(tservers.HTTPProxTest, SanityMixin): ssl = True clientcerts = True def test_clientcert(self): @@ -77,15 +89,15 @@ class TestHTTPS(tutils.HTTPProxTest, SanityMixin): assert self.last_log()["request"]["clientcert"]["keyinfo"] -class TestReverse(tutils.ReverseProxTest, SanityMixin): +class TestReverse(tservers.ReverseProxTest, SanityMixin): reverse = True -class TestTransparent(tutils.TransparentProxTest, SanityMixin): +class TestTransparent(tservers.TransparentProxTest, SanityMixin): transparent = True -class TestProxy(tutils.HTTPProxTest): +class TestProxy(tservers.HTTPProxTest): def test_http(self): f = self.pathod("304") assert f.status_code == 304 diff --git a/test/tservers.py b/test/tservers.py new file mode 100644 index 00000000..2966a436 --- /dev/null +++ b/test/tservers.py @@ -0,0 +1,183 @@ +import threading, Queue +import flask +import human_curl as hurl +import libpathod.test, libpathod.pathoc +from libmproxy import proxy, flow, controller +import tutils + +testapp = flask.Flask(__name__) + +@testapp.route("/") +def hello(): + return "testapp" + +@testapp.route("/error") +def error(): + raise ValueError("An exception...") + + +def errapp(environ, start_response): + raise ValueError("errapp") + + +class TestMaster(flow.FlowMaster): + def __init__(self, testq, config): + s = proxy.ProxyServer(config, 0) + s.apps.add(testapp, "testapp", 80) + s.apps.add(errapp, "errapp", 80) + state = flow.State() + flow.FlowMaster.__init__(self, s, state) + self.testq = testq + + def handle(self, m): + flow.FlowMaster.handle(self, m) + m._ack() + + +class ProxyThread(threading.Thread): + def __init__(self, testq, config): + self.tmaster = TestMaster(testq, config) + controller.should_exit = False + threading.Thread.__init__(self) + + @property + def port(self): + return self.tmaster.server.port + + def run(self): + self.tmaster.run() + + def shutdown(self): + self.tmaster.shutdown() + + +class ProxTestBase: + @classmethod + def setupAll(cls): + cls.tqueue = Queue.Queue() + cls.server = libpathod.test.Daemon(ssl=cls.ssl) + pconf = cls.get_proxy_config() + config = proxy.ProxyConfig( + certfile=tutils.test_data.path("data/testkey.pem"), + **pconf + ) + cls.proxy = ProxyThread(cls.tqueue, config) + cls.proxy.start() + + @property + def master(cls): + return cls.proxy.tmaster + + @classmethod + def teardownAll(cls): + cls.proxy.shutdown() + cls.server.shutdown() + + def setUp(self): + self.master.state.clear() + + @property + def scheme(self): + return "https" if self.ssl else "http" + + @property + def proxies(self): + """ + The URL base for the server instance. + """ + return ( + (self.scheme, ("127.0.0.1", self.proxy.port)) + ) + + @property + def urlbase(self): + """ + The URL base for the server instance. + """ + return self.server.urlbase + + def last_log(self): + return self.server.last_log() + + +class HTTPProxTest(ProxTestBase): + ssl = None + clientcerts = False + @classmethod + def get_proxy_config(cls): + d = dict() + if cls.clientcerts: + d["clientcerts"] = tutils.test_data.path("data/clientcert") + return d + + def pathoc(self, connect_to = None): + p = libpathod.pathoc.Pathoc("localhost", self.proxy.port) + p.connect(connect_to) + return p + + def pathod(self, spec): + """ + Constructs a pathod request, with the appropriate base and proxy. + """ + return hurl.get( + self.urlbase + "/p/" + spec, + proxy=self.proxies, + validate_cert=False, + #debug=hurl.utils.stdout_debug + ) + + +class TResolver: + def __init__(self, port): + self.port = port + + def original_addr(self, sock): + return ("127.0.0.1", self.port) + + +class TransparentProxTest(ProxTestBase): + ssl = None + @classmethod + def get_proxy_config(cls): + return dict( + transparent_proxy = dict( + resolver = TResolver(cls.server.port), + sslports = [] + ) + ) + + def pathod(self, spec): + """ + Constructs a pathod request, with the appropriate base and proxy. + """ + r = hurl.get( + "http://127.0.0.1:%s"%self.proxy.port + "/p/" + spec, + validate_cert=False, + #debug=hurl.utils.stdout_debug + ) + return r + + +class ReverseProxTest(ProxTestBase): + ssl = None + @classmethod + def get_proxy_config(cls): + return dict( + reverse_proxy = ( + "https" if cls.ssl else "http", + "127.0.0.1", + cls.server.port + ) + ) + + def pathod(self, spec): + """ + Constructs a pathod request, with the appropriate base and proxy. + """ + r = hurl.get( + "http://127.0.0.1:%s"%self.proxy.port + "/p/" + spec, + validate_cert=False, + #debug=hurl.utils.stdout_debug + ) + return r + diff --git a/test/tutils.py b/test/tutils.py index 9868c778..d5497bae 100644 --- a/test/tutils.py +++ b/test/tutils.py @@ -1,10 +1,8 @@ -import threading, Queue import os, shutil, tempfile from contextlib import contextmanager -from libmproxy import proxy, flow, controller, utils +from libmproxy import flow, utils from netlib import certutils -import human_curl as hurl -import libpathod.test, libpathod.pathoc + def treq(conn=None): if not conn: @@ -42,166 +40,6 @@ def tflow_err(): return f -class TestMaster(flow.FlowMaster): - def __init__(self, testq, config): - s = proxy.ProxyServer(config, 0) - state = flow.State() - flow.FlowMaster.__init__(self, s, state) - self.testq = testq - - def handle(self, m): - flow.FlowMaster.handle(self, m) - m._ack() - - -class ProxyThread(threading.Thread): - def __init__(self, testq, config): - self.tmaster = TestMaster(testq, config) - controller.should_exit = False - threading.Thread.__init__(self) - - @property - def port(self): - return self.tmaster.server.port - - def run(self): - self.tmaster.run() - - def shutdown(self): - self.tmaster.shutdown() - - -class ProxTestBase: - @classmethod - def setupAll(cls): - cls.tqueue = Queue.Queue() - cls.server = libpathod.test.Daemon(ssl=cls.ssl) - pconf = cls.get_proxy_config() - config = proxy.ProxyConfig( - certfile=test_data.path("data/testkey.pem"), - **pconf - ) - cls.proxy = ProxyThread(cls.tqueue, config) - cls.proxy.start() - - @property - def master(cls): - return cls.proxy.tmaster - - @classmethod - def teardownAll(cls): - cls.proxy.shutdown() - cls.server.shutdown() - - def setUp(self): - self.master.state.clear() - - @property - def scheme(self): - return "https" if self.ssl else "http" - - @property - def proxies(self): - """ - The URL base for the server instance. - """ - return ( - (self.scheme, ("127.0.0.1", self.proxy.port)) - ) - - @property - def urlbase(self): - """ - The URL base for the server instance. - """ - return self.server.urlbase - - def last_log(self): - return self.server.last_log() - - -class HTTPProxTest(ProxTestBase): - ssl = None - clientcerts = False - @classmethod - def get_proxy_config(cls): - d = dict() - if cls.clientcerts: - d["clientcerts"] = test_data.path("data/clientcert") - return d - - def pathoc(self, connect_to = None): - p = libpathod.pathoc.Pathoc("localhost", self.proxy.port) - p.connect(connect_to) - return p - - def pathod(self, spec): - """ - Constructs a pathod request, with the appropriate base and proxy. - """ - return hurl.get( - self.urlbase + "/p/" + spec, - proxy=self.proxies, - validate_cert=False, - #debug=hurl.utils.stdout_debug - ) - - -class TResolver: - def __init__(self, port): - self.port = port - - def original_addr(self, sock): - return ("127.0.0.1", self.port) - - -class TransparentProxTest(ProxTestBase): - ssl = None - @classmethod - def get_proxy_config(cls): - return dict( - transparent_proxy = dict( - resolver = TResolver(cls.server.port), - sslports = [] - ) - ) - - def pathod(self, spec): - """ - Constructs a pathod request, with the appropriate base and proxy. - """ - r = hurl.get( - "http://127.0.0.1:%s"%self.proxy.port + "/p/" + spec, - validate_cert=False, - #debug=hurl.utils.stdout_debug - ) - return r - - -class ReverseProxTest(ProxTestBase): - ssl = None - @classmethod - def get_proxy_config(cls): - return dict( - reverse_proxy = ( - "https" if cls.ssl else "http", - "127.0.0.1", - cls.server.port - ) - ) - - def pathod(self, spec): - """ - Constructs a pathod request, with the appropriate base and proxy. - """ - r = hurl.get( - "http://127.0.0.1:%s"%self.proxy.port + "/p/" + spec, - validate_cert=False, - #debug=hurl.utils.stdout_debug - ) - return r - - @contextmanager def tmpdir(*args, **kwargs): orig_workdir = os.getcwd() @@ -252,5 +90,4 @@ def raises(exc, obj, *args, **kwargs): ) raise AssertionError("No exception raised.") - test_data = utils.Data(__name__) -- cgit v1.2.3 From aaf892e3afc682b2dc2a166a96872420e50092cd Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 17 Feb 2013 12:42:48 +1300 Subject: Significantly refactor the master/slave message passing interface. --- libmproxy/console/__init__.py | 6 +-- libmproxy/console/common.py | 4 +- libmproxy/controller.py | 85 +++++++++++++++++++++++++++++++------------ libmproxy/dump.py | 22 +++++------ libmproxy/flow.py | 50 ++++++++++++------------- libmproxy/proxy.py | 43 +++++++++++----------- test/test_dump.py | 6 ++- test/test_flow.py | 32 +++++++++------- test/tservers.py | 2 +- test/tutils.py | 15 ++++++-- 10 files changed, 158 insertions(+), 107 deletions(-) diff --git a/libmproxy/console/__init__.py b/libmproxy/console/__init__.py index d6c7f5a2..a16cc4dc 100644 --- a/libmproxy/console/__init__.py +++ b/libmproxy/console/__init__.py @@ -580,7 +580,7 @@ class ConsoleMaster(flow.FlowMaster): self.view_flowlist() - self.server.start_slave(controller.Slave, self.masterq) + self.server.start_slave(controller.Slave, controller.Channel(self.masterq)) if self.options.rfile: ret = self.load_flows(self.options.rfile) @@ -1002,7 +1002,7 @@ class ConsoleMaster(flow.FlowMaster): if self.state.intercept and f.match(self.state.intercept) and not f.request.is_replay(): f.intercept() else: - r._ack() + r.reply() self.sync_list_view() self.refresh_flow(f) @@ -1023,7 +1023,7 @@ class ConsoleMaster(flow.FlowMaster): # Handlers def handle_log(self, l): self.add_event(l.msg) - l._ack() + l.reply() def handle_error(self, r): f = flow.FlowMaster.handle_error(self, r) diff --git a/libmproxy/console/common.py b/libmproxy/console/common.py index 2da7f802..1cc0b5b9 100644 --- a/libmproxy/console/common.py +++ b/libmproxy/console/common.py @@ -184,7 +184,7 @@ def format_flow(f, focus, extended=False, padding=2): req_timestamp = f.request.timestamp_start, req_is_replay = f.request.is_replay(), req_method = f.request.method, - req_acked = f.request.acked, + req_acked = f.request.reply.acked, req_url = f.request.get_url(), err_msg = f.error.msg if f.error else None, @@ -200,7 +200,7 @@ def format_flow(f, focus, extended=False, padding=2): d.update(dict( resp_code = f.response.code, resp_is_replay = f.response.is_replay(), - resp_acked = f.response.acked, + resp_acked = f.response.reply.acked, resp_clen = contentdesc )) t = f.response.headers["content-type"] diff --git a/libmproxy/controller.py b/libmproxy/controller.py index f38d1edb..c36bb9df 100644 --- a/libmproxy/controller.py +++ b/libmproxy/controller.py @@ -17,37 +17,73 @@ import Queue, threading should_exit = False -class Msg: + +class DummyReply: + """ + A reply object that does nothing. Useful when we need an object to seem + like it has a channel, and during testing. + """ def __init__(self): + self.acked = False + + def __call__(self, msg=False): + self.acked = True + + +class Reply: + """ + Messages sent through a channel are decorated with a "reply" attribute. + This object is used to respond to the message through the return + channel. + """ + def __init__(self, obj): + self.obj = obj self.q = Queue.Queue() self.acked = False - def _ack(self, data=False): + def __call__(self, msg=False): if not self.acked: self.acked = True - if data is None: - self.q.put(data) + if msg is None: + self.q.put(msg) else: - self.q.put(data or self) + self.q.put(msg or self.obj) - def _send(self, masterq): - self.acked = False - try: - masterq.put(self, timeout=3) - while not should_exit: # pragma: no cover - try: - g = self.q.get(timeout=0.5) - except Queue.Empty: - continue - return g - except (Queue.Empty, Queue.Full): # pragma: no cover - return None + +class Channel: + def __init__(self, q): + self.q = q + + def ask(self, m): + """ + Send a message to the master, and wait for a response. + """ + m.reply = Reply(m) + self.q.put(m) + while not should_exit: + try: + # The timeout is here so we can handle a should_exit event. + g = m.reply.q.get(timeout=0.5) + except Queue.Empty: + continue + return g + + def tell(self, m): + """ + Send a message to the master, and keep going. + """ + m.reply = None + self.q.put(m) class Slave(threading.Thread): - def __init__(self, masterq, server): - self.masterq, self.server = masterq, server - self.server.set_mqueue(masterq) + """ + Slaves get a channel end-point through which they can send messages to + the master. + """ + def __init__(self, channel, server): + self.channel, self.server = channel, server + self.server.set_channel(channel) threading.Thread.__init__(self) def run(self): @@ -55,6 +91,9 @@ class Slave(threading.Thread): class Master: + """ + Masters get and respond to messages from slaves. + """ def __init__(self, server): """ server may be None if no server is needed. @@ -81,18 +120,18 @@ class Master: def run(self): global should_exit should_exit = False - self.server.start_slave(Slave, self.masterq) + self.server.start_slave(Slave, Channel(self.masterq)) while not should_exit: self.tick(self.masterq) self.shutdown() - def handle(self, msg): # pragma: no cover + def handle(self, msg): c = "handle_" + msg.__class__.__name__.lower() m = getattr(self, c, None) if m: m(msg) else: - msg._ack() + msg.reply() def shutdown(self): global should_exit diff --git a/libmproxy/dump.py b/libmproxy/dump.py index 170c701d..3c7eee71 100644 --- a/libmproxy/dump.py +++ b/libmproxy/dump.py @@ -150,16 +150,6 @@ class DumpMaster(flow.FlowMaster): print >> self.outfile, e self.outfile.flush() - def handle_log(self, l): - self.add_event(l.msg) - l._ack() - - def handle_request(self, r): - f = flow.FlowMaster.handle_request(self, r) - if f: - r._ack() - return f - def indent(self, n, t): l = str(t).strip().split("\n") return "\n".join(" "*n + i for i in l) @@ -210,10 +200,20 @@ class DumpMaster(flow.FlowMaster): self.outfile.flush() self.state.delete_flow(f) + def handle_log(self, l): + self.add_event(l.msg) + l.reply() + + def handle_request(self, r): + f = flow.FlowMaster.handle_request(self, r) + if f: + r.reply() + return f + def handle_response(self, msg): f = flow.FlowMaster.handle_response(self, msg) if f: - msg._ack() + msg.reply() self._process_flow(f) return f diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 9238cfbf..0f5fb563 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -196,7 +196,7 @@ class decoded(object): self.o.encode(self.ce) -class HTTPMsg(controller.Msg): +class HTTPMsg: def get_decoded_content(self): """ Returns the decoded content based on the current Content-Encoding header. @@ -252,6 +252,7 @@ class HTTPMsg(controller.Msg): return 0 return len(self.content) + class Request(HTTPMsg): """ An HTTP request. @@ -289,7 +290,6 @@ class Request(HTTPMsg): self.timestamp_start = timestamp_start or utils.timestamp() self.timestamp_end = max(timestamp_end or utils.timestamp(), timestamp_start) self.close = False - controller.Msg.__init__(self) # Have this request's cookies been modified by sticky cookies or auth? self.stickycookie = False @@ -396,7 +396,6 @@ class Request(HTTPMsg): Returns a copy of this object. """ c = copy.copy(self) - c.acked = True c.headers = self.headers.copy() return c @@ -603,7 +602,6 @@ class Response(HTTPMsg): self.cert = cert self.timestamp_start = timestamp_start or utils.timestamp() self.timestamp_end = max(timestamp_end or utils.timestamp(), timestamp_start) - controller.Msg.__init__(self) self.replay = False def _refresh_cookie(self, c, delta): @@ -708,7 +706,6 @@ class Response(HTTPMsg): Returns a copy of this object. """ c = copy.copy(self) - c.acked = True c.headers = self.headers.copy() return c @@ -773,7 +770,7 @@ class Response(HTTPMsg): cookies.append((cookie_name, (cookie_value, cookie_parameters))) return dict(cookies) -class ClientDisconnect(controller.Msg): +class ClientDisconnect: """ A client disconnection event. @@ -782,11 +779,10 @@ class ClientDisconnect(controller.Msg): client_conn: ClientConnect object. """ def __init__(self, client_conn): - controller.Msg.__init__(self) self.client_conn = client_conn -class ClientConnect(controller.Msg): +class ClientConnect: """ A single client connection. Each connection can result in multiple HTTP Requests. @@ -807,7 +803,6 @@ class ClientConnect(controller.Msg): self.close = False self.requestcount = 0 self.error = None - controller.Msg.__init__(self) def __eq__(self, other): return self._get_state() == other._get_state() @@ -838,11 +833,10 @@ class ClientConnect(controller.Msg): Returns a copy of this object. """ c = copy.copy(self) - c.acked = True return c -class Error(controller.Msg): +class Error: """ An Error. @@ -860,7 +854,6 @@ class Error(controller.Msg): def __init__(self, request, msg, timestamp=None): self.request, self.msg = request, msg self.timestamp = timestamp or utils.timestamp() - controller.Msg.__init__(self) def _load_state(self, state): self.msg = state["msg"] @@ -871,7 +864,6 @@ class Error(controller.Msg): Returns a copy of this object. """ c = copy.copy(self) - c.acked = True return c def _get_state(self): @@ -1180,10 +1172,11 @@ class Flow: Kill this request. """ self.error = Error(self.request, "Connection killed") - if self.request and not self.request.acked: - self.request._ack(None) - elif self.response and not self.response.acked: - self.response._ack(None) + self.error.reply = controller.DummyReply() + if self.request and not self.request.reply.acked: + self.request.reply(None) + elif self.response and not self.response.reply.acked: + self.response.reply(None) master.handle_error(self.error) self.intercepting = False @@ -1199,10 +1192,10 @@ class Flow: Continue with the flow - called after an intercept(). """ if self.request: - if not self.request.acked: - self.request._ack() - elif self.response and not self.response.acked: - self.response._ack() + if not self.request.reply.acked: + self.request.reply() + elif self.response and not self.response.reply.acked: + self.response.reply() self.intercepting = False def replace(self, pattern, repl, *args, **kwargs): @@ -1464,7 +1457,7 @@ class FlowMaster(controller.Master): flow.response = response if self.refresh_server_playback: response.refresh() - flow.request._ack(response) + flow.request.reply(response) if self.server_playback.count() == 0: self.stop_server_playback() return True @@ -1491,10 +1484,13 @@ class FlowMaster(controller.Master): Loads a flow, and returns a new flow object. """ if f.request: + f.request.reply = controller.DummyReply() fr = self.handle_request(f.request) if f.response: + f.response.reply = controller.DummyReply() self.handle_response(f.response) if f.error: + f.error.reply = controller.DummyReply() self.handle_error(f.error) return fr @@ -1522,7 +1518,7 @@ class FlowMaster(controller.Master): if self.kill_nonreplay: f.kill(self) else: - f.request._ack() + f.request.reply() def process_new_response(self, f): if self.stickycookie_state: @@ -1561,11 +1557,11 @@ class FlowMaster(controller.Master): def handle_clientconnect(self, cc): self.run_script_hook("clientconnect", cc) - cc._ack() + cc.reply() def handle_clientdisconnect(self, r): self.run_script_hook("clientdisconnect", r) - r._ack() + r.reply() def handle_error(self, r): f = self.state.add_error(r) @@ -1573,7 +1569,7 @@ class FlowMaster(controller.Master): self.run_script_hook("error", f) if self.client_playback: self.client_playback.clear(f) - r._ack() + r.reply() return f def handle_request(self, r): @@ -1596,7 +1592,7 @@ class FlowMaster(controller.Master): if self.stream: self.stream.add(f) else: - r._ack() + r.reply() return f def shutdown(self): diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index f2c6db1f..1fbb6d58 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -29,9 +29,8 @@ class ProxyError(Exception): return "ProxyError(%s, %s)"%(self.code, self.msg) -class Log(controller.Msg): +class Log: def __init__(self, msg): - controller.Msg.__init__(self) self.msg = msg @@ -51,7 +50,7 @@ class ProxyConfig: class RequestReplayThread(threading.Thread): def __init__(self, config, flow, masterq): - self.config, self.flow, self.masterq = config, flow, masterq + self.config, self.flow, self.channel = config, flow, controller.Channel(masterq) threading.Thread.__init__(self) def run(self): @@ -66,10 +65,10 @@ class RequestReplayThread(threading.Thread): response = flow.Response( self.flow.request, httpversion, code, msg, headers, content, server.cert ) - response._send(self.masterq) + self.channel.ask(response) except (ProxyError, http.HttpError, tcp.NetLibError), v: err = flow.Error(self.flow.request, str(v)) - err._send(self.masterq) + self.channel.ask(err) class ServerConnection(tcp.TCPClient): @@ -128,8 +127,8 @@ class ServerConnectionPool: class ProxyHandler(tcp.BaseHandler): - def __init__(self, config, connection, client_address, server, mqueue, server_version): - self.mqueue, self.server_version = mqueue, server_version + def __init__(self, config, connection, client_address, server, channel, server_version): + self.channel, self.server_version = channel, server_version self.config = config self.server_conn_pool = ServerConnectionPool(config) self.proxy_connect_state = None @@ -139,18 +138,18 @@ class ProxyHandler(tcp.BaseHandler): def handle(self): cc = flow.ClientConnect(self.client_address) self.log(cc, "connect") - cc._send(self.mqueue) + self.channel.ask(cc) while self.handle_request(cc) and not cc.close: pass cc.close = True - cd = flow.ClientDisconnect(cc) + cd = flow.ClientDisconnect(cc) self.log( cc, "disconnect", [ "handled %s requests"%cc.requestcount] ) - cd._send(self.mqueue) + self.channel.ask(cd) def handle_request(self, cc): try: @@ -167,14 +166,14 @@ class ProxyHandler(tcp.BaseHandler): self.log(cc, "Error in wsgi app.", err.split("\n")) return else: - request = request._send(self.mqueue) + request = self.channel.ask(request) if request is None: return if isinstance(request, flow.Response): response = request request = False - response = response._send(self.mqueue) + response = self.channel.ask(response) else: if self.config.reverse_proxy: scheme, host, port = self.config.reverse_proxy @@ -192,7 +191,7 @@ class ProxyHandler(tcp.BaseHandler): request, httpversion, code, msg, headers, content, sc.cert, sc.rfile.first_byte_timestamp, utils.timestamp() ) - response = response._send(self.mqueue) + response = self.channel.ask(response) if response is None: sc.terminate() if response is None: @@ -214,7 +213,7 @@ class ProxyHandler(tcp.BaseHandler): if request: err = flow.Error(request, cc.error) - err._send(self.mqueue) + self.channel.ask(err) self.log( cc, cc.error, ["url: %s"%request.get_url()] @@ -235,7 +234,7 @@ class ProxyHandler(tcp.BaseHandler): msg.append(" -> "+i) msg = "\n".join(msg) l = Log(msg) - l._send(self.mqueue) + self.channel.ask(l) def find_cert(self, host, port, sni): if self.config.certfile: @@ -438,18 +437,18 @@ class ProxyServer(tcp.TCPServer): tcp.TCPServer.__init__(self, (address, port)) except socket.error, v: raise ProxyServerError('Error starting proxy server: ' + v.strerror) - self.masterq = None + self.channel = None self.apps = AppRegistry() - def start_slave(self, klass, masterq): - slave = klass(masterq, self) + def start_slave(self, klass, channel): + slave = klass(channel, self) slave.start() - def set_mqueue(self, q): - self.masterq = q + def set_channel(self, channel): + self.channel = channel def handle_connection(self, request, client_address): - h = ProxyHandler(self.config, request, client_address, self, self.masterq, self.server_version) + h = ProxyHandler(self.config, request, client_address, self, self.channel, self.server_version) h.handle() try: h.finish() @@ -487,7 +486,7 @@ class DummyServer: def __init__(self, config): self.config = config - def start_slave(self, klass, masterq): + def start_slave(self, klass, channel): pass def shutdown(self): diff --git a/test/test_dump.py b/test/test_dump.py index e1241e29..5d3f9133 100644 --- a/test/test_dump.py +++ b/test/test_dump.py @@ -3,6 +3,7 @@ from cStringIO import StringIO import libpry from libmproxy import dump, flow, proxy import tutils +import mock def test_strfuncs(): t = tutils.tresp() @@ -21,6 +22,7 @@ class TestDumpMaster: req = tutils.treq() req.content = content l = proxy.Log("connect") + l.reply = mock.MagicMock() m.handle_log(l) cc = req.client_conn cc.connection_error = "error" @@ -29,7 +31,9 @@ class TestDumpMaster: m.handle_clientconnect(cc) m.handle_request(req) f = m.handle_response(resp) - m.handle_clientdisconnect(flow.ClientDisconnect(cc)) + cd = flow.ClientDisconnect(cc) + cd.reply = mock.MagicMock() + m.handle_clientdisconnect(cd) return f def _dummy_cycle(self, n, filt, content, **options): diff --git a/test/test_flow.py b/test/test_flow.py index da5b095e..6aa898ad 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -223,16 +223,16 @@ class TestFlow: f = tutils.tflow() f.request = tutils.treq() f.intercept() - assert not f.request.acked + assert not f.request.reply.acked f.kill(fm) - assert f.request.acked + assert f.request.reply.acked f.intercept() f.response = tutils.tresp() f.request = f.response.request - f.request._ack() - assert not f.response.acked + f.request.reply() + assert not f.response.reply.acked f.kill(fm) - assert f.response.acked + assert f.response.reply.acked def test_killall(self): s = flow.State() @@ -245,25 +245,25 @@ class TestFlow: fm.handle_request(r) for i in s.view: - assert not i.request.acked + assert not i.request.reply.acked s.killall(fm) for i in s.view: - assert i.request.acked + assert i.request.reply.acked def test_accept_intercept(self): f = tutils.tflow() f.request = tutils.treq() f.intercept() - assert not f.request.acked + assert not f.request.reply.acked f.accept_intercept() - assert f.request.acked + assert f.request.reply.acked f.response = tutils.tresp() f.request = f.response.request f.intercept() - f.request._ack() - assert not f.response.acked + f.request.reply() + assert not f.response.reply.acked f.accept_intercept() - assert f.response.acked + assert f.response.reply.acked def test_serialization(self): f = flow.Flow(None) @@ -562,9 +562,11 @@ class TestFlowMaster: fm.handle_response(resp) assert fm.script.ns["log"][-1] == "response" dc = flow.ClientDisconnect(req.client_conn) + dc.reply = controller.DummyReply() fm.handle_clientdisconnect(dc) assert fm.script.ns["log"][-1] == "clientdisconnect" err = flow.Error(f.request, "msg") + err.reply = controller.DummyReply() fm.handle_error(err) assert fm.script.ns["log"][-1] == "error" @@ -598,10 +600,12 @@ class TestFlowMaster: assert not fm.handle_response(rx) dc = flow.ClientDisconnect(req.client_conn) + dc.reply = controller.DummyReply() req.client_conn.requestcount = 1 fm.handle_clientdisconnect(dc) err = flow.Error(f.request, "msg") + err.reply = controller.DummyReply() fm.handle_error(err) fm.load_script(tutils.test_data.path("scripts/a.py")) @@ -621,7 +625,9 @@ class TestFlowMaster: fm.tick(q) assert fm.state.flow_count() - fm.handle_error(flow.Error(f.request, "error")) + err = flow.Error(f.request, "error") + err.reply = controller.DummyReply() + fm.handle_error(err) def test_server_playback(self): controller.should_exit = False diff --git a/test/tservers.py b/test/tservers.py index 2966a436..4cbdc942 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -31,7 +31,7 @@ class TestMaster(flow.FlowMaster): def handle(self, m): flow.FlowMaster.handle(self, m) - m._ack() + m.reply() class ProxyThread(threading.Thread): diff --git a/test/tutils.py b/test/tutils.py index d5497bae..1a1c8724 100644 --- a/test/tutils.py +++ b/test/tutils.py @@ -1,15 +1,18 @@ import os, shutil, tempfile from contextlib import contextmanager -from libmproxy import flow, utils +from libmproxy import flow, utils, controller from netlib import certutils - +import mock def treq(conn=None): if not conn: conn = flow.ClientConnect(("address", 22)) + conn.reply = controller.DummyReply() headers = flow.ODictCaseless() headers["header"] = ["qvalue"] - return flow.Request(conn, (1, 1), "host", 80, "http", "GET", "/path", headers, "content") + r = flow.Request(conn, (1, 1), "host", 80, "http", "GET", "/path", headers, "content") + r.reply = controller.DummyReply() + return r def tresp(req=None): @@ -18,7 +21,9 @@ def tresp(req=None): headers = flow.ODictCaseless() headers["header_response"] = ["svalue"] cert = certutils.SSLCert.from_der(file(test_data.path("data/dercert")).read()) - return flow.Response(req, (1, 1), 200, "message", headers, "content_response", cert) + resp = flow.Response(req, (1, 1), 200, "message", headers, "content_response", cert) + resp.reply = controller.DummyReply() + return resp def tflow(): @@ -37,9 +42,11 @@ def tflow_err(): r = treq() f = flow.Flow(r) f.error = flow.Error(r, "error") + f.error.reply = controller.DummyReply() return f + @contextmanager def tmpdir(*args, **kwargs): orig_workdir = os.getcwd() -- cgit v1.2.3 From 7800b7c9103ae119a13b81048a84f626850cf04f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 23 Feb 2013 14:08:28 +1300 Subject: Refactor proxy core communications to be clearer. --- libmproxy/controller.py | 10 ++++++---- libmproxy/flow.py | 45 ++++++++++++++------------------------------- libmproxy/proxy.py | 46 ++++++++++++++++++++++++++-------------------- 3 files changed, 46 insertions(+), 55 deletions(-) diff --git a/libmproxy/controller.py b/libmproxy/controller.py index c36bb9df..849d998b 100644 --- a/libmproxy/controller.py +++ b/libmproxy/controller.py @@ -56,13 +56,14 @@ class Channel: def ask(self, m): """ - Send a message to the master, and wait for a response. + Decorate a message with a reply attribute, and send it to the + master. then wait for a response. """ m.reply = Reply(m) self.q.put(m) while not should_exit: try: - # The timeout is here so we can handle a should_exit event. + # The timeout is here so we can handle a should_exit event. g = m.reply.q.get(timeout=0.5) except Queue.Empty: continue @@ -70,9 +71,10 @@ class Channel: def tell(self, m): """ - Send a message to the master, and keep going. + Decorate a message with a dummy reply attribute, send it to the + master, then return immediately. """ - m.reply = None + m.reply = DummyReply() self.q.put(m) diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 0f5fb563..883c7235 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -196,7 +196,15 @@ class decoded(object): self.o.encode(self.ce) -class HTTPMsg: +class StateObject: + def __eq__(self, other): + try: + return self._get_state() == other._get_state() + except AttributeError: + return False + + +class HTTPMsg(StateObject): def get_decoded_content(self): """ Returns the decoded content based on the current Content-Encoding header. @@ -388,13 +396,7 @@ class Request(HTTPMsg): def __hash__(self): return id(self) - def __eq__(self, other): - return self._get_state() == other._get_state() - def copy(self): - """ - Returns a copy of this object. - """ c = copy.copy(self) c.headers = self.headers.copy() return c @@ -698,13 +700,7 @@ class Response(HTTPMsg): state["timestamp_end"], ) - def __eq__(self, other): - return self._get_state() == other._get_state() - def copy(self): - """ - Returns a copy of this object. - """ c = copy.copy(self) c.headers = self.headers.copy() return c @@ -782,7 +778,7 @@ class ClientDisconnect: self.client_conn = client_conn -class ClientConnect: +class ClientConnect(StateObject): """ A single client connection. Each connection can result in multiple HTTP Requests. @@ -804,9 +800,6 @@ class ClientConnect: self.requestcount = 0 self.error = None - def __eq__(self, other): - return self._get_state() == other._get_state() - def _load_state(self, state): self.close = True self.error = state["error"] @@ -829,14 +822,10 @@ class ClientConnect: return None def copy(self): - """ - Returns a copy of this object. - """ - c = copy.copy(self) - return c + return copy.copy(self) -class Error: +class Error(StateObject): """ An Error. @@ -860,9 +849,6 @@ class Error: self.timestamp = state["timestamp"] def copy(self): - """ - Returns a copy of this object. - """ c = copy.copy(self) return c @@ -880,9 +866,6 @@ class Error: state["timestamp"], ) - def __eq__(self, other): - return self._get_state() == other._get_state() - def replace(self, pattern, repl, *args, **kwargs): """ Replaces a regular expression pattern with repl in both the headers @@ -1174,9 +1157,9 @@ class Flow: self.error = Error(self.request, "Connection killed") self.error.reply = controller.DummyReply() if self.request and not self.request.reply.acked: - self.request.reply(None) + self.request.reply(proxy.KILL) elif self.response and not self.response.reply.acked: - self.response.reply(None) + self.response.reply(proxy.KILL) master.handle_error(self.error) self.intercepting = False diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 1fbb6d58..6d476c7b 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -20,6 +20,8 @@ from netlib import odict, tcp, http, wsgi, certutils, http_status import utils, flow, version, platform, controller import authentication +KILL = 0 + class ProxyError(Exception): def __init__(self, code, msg, headers=None): @@ -149,7 +151,7 @@ class ProxyHandler(tcp.BaseHandler): [ "handled %s requests"%cc.requestcount] ) - self.channel.ask(cd) + self.channel.tell(cd) def handle_request(self, cc): try: @@ -166,15 +168,15 @@ class ProxyHandler(tcp.BaseHandler): self.log(cc, "Error in wsgi app.", err.split("\n")) return else: - request = self.channel.ask(request) - if request is None: + request_reply = self.channel.ask(request) + if request_reply == KILL: return - - if isinstance(request, flow.Response): - response = request + elif isinstance(request_reply, flow.Response): request = False - response = self.channel.ask(response) + response = request_reply + response_reply = self.channel.ask(response) else: + request = request_reply if self.config.reverse_proxy: scheme, host, port = self.config.reverse_proxy else: @@ -191,20 +193,24 @@ class ProxyHandler(tcp.BaseHandler): request, httpversion, code, msg, headers, content, sc.cert, sc.rfile.first_byte_timestamp, utils.timestamp() ) - response = self.channel.ask(response) - if response is None: + response_reply = self.channel.ask(response) + # Not replying to the server invalidates the server connection, so we terminate. + if response_reply == KILL: sc.terminate() - if response is None: - return - self.send_response(response) - if request and http.request_connection_close(request.httpversion, request.headers): - return - # We could keep the client connection when the server - # connection needs to go away. However, we want to mimic - # behaviour as closely as possible to the client, so we - # disconnect. - if http.response_connection_close(response.httpversion, response.headers): + + if response_reply == KILL: return + else: + response = response_reply + self.send_response(response) + if request and http.request_connection_close(request.httpversion, request.headers): + return + # We could keep the client connection when the server + # connection needs to go away. However, we want to mimic + # behaviour as closely as possible to the client, so we + # disconnect. + if http.response_connection_close(response.httpversion, response.headers): + return except (IOError, ProxyError, http.HttpError, tcp.NetLibDisconnect), e: if hasattr(e, "code"): cc.error = "%s: %s"%(e.code, e.msg) @@ -234,7 +240,7 @@ class ProxyHandler(tcp.BaseHandler): msg.append(" -> "+i) msg = "\n".join(msg) l = Log(msg) - self.channel.ask(l) + self.channel.tell(l) def find_cert(self, host, port, sni): if self.config.certfile: -- cgit v1.2.3 From f203881b0d7f81a7f8ecbc44b7030060242af01b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 23 Feb 2013 14:13:43 +1300 Subject: Remove redundant clause in controller.Reply --- libmproxy/controller.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/libmproxy/controller.py b/libmproxy/controller.py index 849d998b..da097692 100644 --- a/libmproxy/controller.py +++ b/libmproxy/controller.py @@ -44,10 +44,7 @@ class Reply: def __call__(self, msg=False): if not self.acked: self.acked = True - if msg is None: - self.q.put(msg) - else: - self.q.put(msg or self.obj) + self.q.put(msg or self.obj) class Channel: -- cgit v1.2.3 From 269780c57780d155c4d8bd95fcc2af2104effa5b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 23 Feb 2013 16:34:59 +1300 Subject: Unit test dummy response functions. --- libmproxy/proxy.py | 3 ++- test/test_server.py | 16 ++++++++++++++++ test/tservers.py | 18 ++++++++++++------ 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 6d476c7b..c8fac5f4 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -194,7 +194,8 @@ class ProxyHandler(tcp.BaseHandler): sc.rfile.first_byte_timestamp, utils.timestamp() ) response_reply = self.channel.ask(response) - # Not replying to the server invalidates the server connection, so we terminate. + # Not replying to the server invalidates the server + # connection, so we terminate. if response_reply == KILL: sc.terminate() diff --git a/test/test_server.py b/test/test_server.py index 58dfee58..5cba891c 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -2,6 +2,7 @@ import socket, time from netlib import tcp from libpathod import pathoc import tutils, tservers +from libmproxy import flow """ Note that the choice of response code in these tests matters more than you @@ -144,3 +145,18 @@ class TestProxy(tservers.HTTPProxTest): request = self.master.state.view[1].request assert request.timestamp_end - request.timestamp_start <= 0.1 + + +class MasterFakeResponse(tservers.TestMaster): + def handle_request(self, m): + resp = tutils.tresp() + m.reply(resp) + + +class TestFakeResponse(tservers.HTTPProxTest): + masterclass = MasterFakeResponse + def test_kill(self): + p = self.pathoc() + f = self.pathod("200") + assert "header_response" in f.headers.keys() + diff --git a/test/tservers.py b/test/tservers.py index 4cbdc942..3fdb8d13 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -29,16 +29,20 @@ class TestMaster(flow.FlowMaster): flow.FlowMaster.__init__(self, s, state) self.testq = testq - def handle(self, m): - flow.FlowMaster.handle(self, m) + def handle_request(self, m): + flow.FlowMaster.handle_request(self, m) + m.reply() + + def handle_response(self, m): + flow.FlowMaster.handle_response(self, m) m.reply() class ProxyThread(threading.Thread): - def __init__(self, testq, config): - self.tmaster = TestMaster(testq, config) - controller.should_exit = False + def __init__(self, tmaster): threading.Thread.__init__(self) + self.tmaster = tmaster + controller.should_exit = False @property def port(self): @@ -52,6 +56,7 @@ class ProxyThread(threading.Thread): class ProxTestBase: + masterclass = TestMaster @classmethod def setupAll(cls): cls.tqueue = Queue.Queue() @@ -61,7 +66,8 @@ class ProxTestBase: certfile=tutils.test_data.path("data/testkey.pem"), **pconf ) - cls.proxy = ProxyThread(cls.tqueue, config) + tmaster = cls.masterclass(cls.tqueue, config) + cls.proxy = ProxyThread(tmaster) cls.proxy.start() @property -- cgit v1.2.3 From 05e4d4468ec372adb73649e6980c525a185e9c07 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 23 Feb 2013 21:59:25 +1300 Subject: Test request and response kill functionality. --- .coveragerc | 3 +++ libmproxy/controller.py | 9 ++++++--- test/.gitignore | 1 - test/.pry | 6 ------ test/test_server.py | 32 +++++++++++++++++++++++++++++++- 5 files changed, 40 insertions(+), 11 deletions(-) delete mode 100644 test/.gitignore delete mode 100644 test/.pry diff --git a/.coveragerc b/.coveragerc index 696e0eb8..dd57a6e7 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,3 +1,6 @@ +[rum] +branch = True + [report] omit = *contrib*, *tnetstring*, *platform* include = *libmproxy* diff --git a/libmproxy/controller.py b/libmproxy/controller.py index da097692..bb22597d 100644 --- a/libmproxy/controller.py +++ b/libmproxy/controller.py @@ -41,10 +41,13 @@ class Reply: self.q = Queue.Queue() self.acked = False - def __call__(self, msg=False): + def __call__(self, msg=None): if not self.acked: self.acked = True - self.q.put(msg or self.obj) + if msg is None: + self.q.put(self.obj) + else: + self.q.put(msg) class Channel: @@ -62,7 +65,7 @@ class Channel: try: # The timeout is here so we can handle a should_exit event. g = m.reply.q.get(timeout=0.5) - except Queue.Empty: + except Queue.Empty: # pragma: nocover continue return g diff --git a/test/.gitignore b/test/.gitignore deleted file mode 100644 index 6350e986..00000000 --- a/test/.gitignore +++ /dev/null @@ -1 +0,0 @@ -.coverage diff --git a/test/.pry b/test/.pry deleted file mode 100644 index f6f18e7b..00000000 --- a/test/.pry +++ /dev/null @@ -1,6 +0,0 @@ -base = .. -coverage = ../libmproxy -exclude = . - ../libmproxy/contrib - ../libmproxy/tnetstring.py - diff --git a/test/test_server.py b/test/test_server.py index 5cba891c..8aefa4b8 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -2,7 +2,7 @@ import socket, time from netlib import tcp from libpathod import pathoc import tutils, tservers -from libmproxy import flow +from libmproxy import flow, proxy """ Note that the choice of response code in these tests matters more than you @@ -147,6 +147,7 @@ class TestProxy(tservers.HTTPProxTest): assert request.timestamp_end - request.timestamp_start <= 0.1 + class MasterFakeResponse(tservers.TestMaster): def handle_request(self, m): resp = tutils.tresp() @@ -160,3 +161,32 @@ class TestFakeResponse(tservers.HTTPProxTest): f = self.pathod("200") assert "header_response" in f.headers.keys() + + +class MasterKillRequest(tservers.TestMaster): + def handle_request(self, m): + m.reply(proxy.KILL) + + +class TestKillRequest(tservers.HTTPProxTest): + masterclass = MasterKillRequest + def test_kill(self): + p = self.pathoc() + tutils.raises("empty reply", self.pathod, "200") + # Nothing should have hit the server + assert not self.last_log() + + +class MasterKillResponse(tservers.TestMaster): + def handle_response(self, m): + m.reply(proxy.KILL) + + +class TestKillResponse(tservers.HTTPProxTest): + masterclass = MasterKillResponse + def test_kill(self): + p = self.pathoc() + tutils.raises("empty reply", self.pathod, "200") + # The server should have seen a request + assert self.last_log() + -- cgit v1.2.3 From 51de9f9fdf6ec4cd345e0b2c8607453cc22c5045 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 24 Feb 2013 10:51:14 +1300 Subject: Test client connection close conditions. --- test/test_server.py | 16 ++++++++++++++++ test/tservers.py | 3 +++ 2 files changed, 19 insertions(+) diff --git a/test/test_server.py b/test/test_server.py index 8aefa4b8..a2c65275 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -81,6 +81,22 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin): assert "host" in l.request.headers assert l.response.code == 304 + def test_connection_close(self): + # Add a body, so we have a content-length header, which combined with + # HTTP1.1 means the connection is kept alive. + response = '%s/p/200:b@1'%self.urlbase + + # Lets sanity check that the connection does indeed stay open by + # issuing two requests over the same connection + p = self.pathoc() + assert p.request("get:'%s'"%response) + assert p.request("get:'%s'"%response) + + # Now check that the connection is closed as the client specifies + p = self.pathoc() + assert p.request("get:'%s':h'Connection'='close'"%response) + tutils.raises("disconnect", p.request, "get:'%s'"%response) + class TestHTTPS(tservers.HTTPProxTest, SanityMixin): ssl = True diff --git a/test/tservers.py b/test/tservers.py index 3fdb8d13..ae0bacf5 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -117,6 +117,9 @@ class HTTPProxTest(ProxTestBase): return d def pathoc(self, connect_to = None): + """ + Returns a connected Pathoc instance. + """ p = libpathod.pathoc.Pathoc("localhost", self.proxy.port) p.connect(connect_to) return p -- cgit v1.2.3 From 64285140f959eaa939c4cf35585cfe21cbf1a449 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 24 Feb 2013 11:34:01 +1300 Subject: Test a difficult-to-trigger IOError, fix cert generation in test suite. --- doc-src/howmitmproxy.html | 2 +- test/test_server.py | 9 +++++++++ test/tservers.py | 2 +- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/doc-src/howmitmproxy.html b/doc-src/howmitmproxy.html index 6ea723cd..94c895d7 100644 --- a/doc-src/howmitmproxy.html +++ b/doc-src/howmitmproxy.html @@ -71,7 +71,7 @@ flow of requests and responses are completely opaque to the proxy. ## The MITM in mitmproxy -This is where mitmproxy's fundamental trick comes in to play. The MITM in its +This is where mitmproxy's fundamental trick comes into play. The MITM in its name stands for Man-In-The-Middle - a reference to the process we use to intercept and interfere with these theoretially opaque data streams. The basic idea is to pretend to be the server to the client, and pretend to be the client diff --git a/test/test_server.py b/test/test_server.py index a2c65275..9df88400 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -1,4 +1,5 @@ import socket, time +import mock from netlib import tcp from libpathod import pathoc import tutils, tservers @@ -97,6 +98,14 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin): assert p.request("get:'%s':h'Connection'='close'"%response) tutils.raises("disconnect", p.request, "get:'%s'"%response) + def test_proxy_ioerror(self): + # Tests a difficult-to-trigger condition, where an IOError is raised + # within our read loop. + with mock.patch("libmproxy.proxy.ProxyHandler.read_request") as m: + m.side_effect = IOError("error!") + tutils.raises("empty reply", self.pathod, "304") + + class TestHTTPS(tservers.HTTPProxTest, SanityMixin): ssl = True diff --git a/test/tservers.py b/test/tservers.py index ae0bacf5..262536a7 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -63,7 +63,7 @@ class ProxTestBase: cls.server = libpathod.test.Daemon(ssl=cls.ssl) pconf = cls.get_proxy_config() config = proxy.ProxyConfig( - certfile=tutils.test_data.path("data/testkey.pem"), + cacert = tutils.test_data.path("data/serverkey.pem"), **pconf ) tmaster = cls.masterclass(cls.tqueue, config) -- cgit v1.2.3 From d0639e8925541bd6f6f386386c982d23b3828d3d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 24 Feb 2013 14:04:56 +1300 Subject: Handle server disconnects better. Server connections can be closed for legitimate reasons, like timeouts. If we've already pumped data over a server connection, we reconnect on error. If not, we treat it as a legitimate error and pass it on to the client. Fixes #85 --- libmproxy/proxy.py | 39 +++++++++++++++++++++++++++++---------- test/test_server.py | 14 +++++++++++++- 2 files changed, 42 insertions(+), 11 deletions(-) diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index c8fac5f4..088fe94c 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -117,8 +117,8 @@ class ServerConnectionPool: def get_connection(self, scheme, host, port): sc = self.conn if self.conn and (host, port) != (sc.host, sc.port): - sc.terminate() - self.conn = None + sc.terminate() + self.conn = None if not self.conn: try: self.conn = ServerConnection(self.config, host, port) @@ -127,6 +127,9 @@ class ServerConnectionPool: raise ProxyError(502, v) return self.conn + def del_connection(self, scheme, host, port): + self.conn = None + class ProxyHandler(tcp.BaseHandler): def __init__(self, config, connection, client_address, server, channel, server_version): @@ -181,14 +184,30 @@ class ProxyHandler(tcp.BaseHandler): scheme, host, port = self.config.reverse_proxy else: scheme, host, port = request.scheme, request.host, request.port - sc = self.server_conn_pool.get_connection(scheme, host, port) - sc.send(request) - sc.rfile.reset_timestamps() - httpversion, code, msg, headers, content = http.read_response( - sc.rfile, - request.method, - self.config.body_size_limit - ) + + # If we've already pumped a request over this connection, + # it's possible that the server has timed out. If this is + # the case, we want to reconnect without sending an error + # to the client. + while 1: + try: + sc = self.server_conn_pool.get_connection(scheme, host, port) + sc.send(request) + sc.rfile.reset_timestamps() + httpversion, code, msg, headers, content = http.read_response( + sc.rfile, + request.method, + self.config.body_size_limit + ) + except http.HttpErrorConnClosed, v: + if sc.requestcount > 1: + self.server_conn_pool.del_connection(scheme, host, port) + continue + else: + raise + else: + break + response = flow.Response( request, httpversion, code, msg, headers, content, sc.cert, sc.rfile.first_byte_timestamp, utils.timestamp() diff --git a/test/test_server.py b/test/test_server.py index 9df88400..924b63b7 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -98,6 +98,19 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin): assert p.request("get:'%s':h'Connection'='close'"%response) tutils.raises("disconnect", p.request, "get:'%s'"%response) + def test_reconnect(self): + req = "get:'%s/p/200:b@1:da'"%self.urlbase + p = self.pathoc() + assert p.request(req) + # Server has disconnected. Mitmproxy should detect this, and reconnect. + assert p.request(req) + assert p.request(req) + + # However, if the server disconnects on our first try, it's an error. + req = "get:'%s/p/200:b@1:d0'"%self.urlbase + p = self.pathoc() + tutils.raises("server disconnect", p.request, req) + def test_proxy_ioerror(self): # Tests a difficult-to-trigger condition, where an IOError is raised # within our read loop. @@ -106,7 +119,6 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin): tutils.raises("empty reply", self.pathod, "304") - class TestHTTPS(tservers.HTTPProxTest, SanityMixin): ssl = True clientcerts = True -- cgit v1.2.3 From 705559d65e5dc5883395efb85bacbf1459eb243c Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 24 Feb 2013 17:35:24 +1300 Subject: Refactor to prepare for SNI fixes. --- libmproxy/proxy.py | 99 +++++++++++++++++++++++++++--------------------------- test/test_proxy.py | 12 +++---- 2 files changed, 55 insertions(+), 56 deletions(-) diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 088fe94c..d92e2da9 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -50,36 +50,13 @@ class ProxyConfig: self.certstore = certutils.CertStore(certdir) -class RequestReplayThread(threading.Thread): - def __init__(self, config, flow, masterq): - self.config, self.flow, self.channel = config, flow, controller.Channel(masterq) - threading.Thread.__init__(self) - - def run(self): - try: - r = self.flow.request - server = ServerConnection(self.config, r.host, r.port) - server.connect(r.scheme) - server.send(r) - httpversion, code, msg, headers, content = http.read_response( - server.rfile, r.method, self.config.body_size_limit - ) - response = flow.Response( - self.flow.request, httpversion, code, msg, headers, content, server.cert - ) - self.channel.ask(response) - except (ProxyError, http.HttpError, tcp.NetLibError), v: - err = flow.Error(self.flow.request, str(v)) - self.channel.ask(err) - - class ServerConnection(tcp.TCPClient): def __init__(self, config, host, port): tcp.TCPClient.__init__(self, host, port) self.config = config self.requestcount = 0 - def connect(self, scheme): + def connect(self, scheme, sni): tcp.TCPClient.connect(self) if scheme == "https": clientcert = None @@ -88,7 +65,7 @@ class ServerConnection(tcp.TCPClient): if os.path.exists(path): clientcert = path try: - self.convert_to_ssl(clientcert=clientcert, sni=self.host) + self.convert_to_ssl(cert=clientcert, sni=sni) except tcp.NetLibError, v: raise ProxyError(400, str(v)) @@ -109,12 +86,35 @@ class ServerConnection(tcp.TCPClient): pass +class RequestReplayThread(threading.Thread): + def __init__(self, config, flow, masterq): + self.config, self.flow, self.channel = config, flow, controller.Channel(masterq) + threading.Thread.__init__(self) + + def run(self): + try: + r = self.flow.request + server = ServerConnection(self.config, r.host, r.port) + server.connect(r.scheme, r.host) + server.send(r) + httpversion, code, msg, headers, content = http.read_response( + server.rfile, r.method, self.config.body_size_limit + ) + response = flow.Response( + self.flow.request, httpversion, code, msg, headers, content, server.cert + ) + self.channel.ask(response) + except (ProxyError, http.HttpError, tcp.NetLibError), v: + err = flow.Error(self.flow.request, str(v)) + self.channel.ask(err) + + class ServerConnectionPool: def __init__(self, config): self.config = config self.conn = None - def get_connection(self, scheme, host, port): + def get_connection(self, scheme, host, port, sni): sc = self.conn if self.conn and (host, port) != (sc.host, sc.port): sc.terminate() @@ -122,7 +122,7 @@ class ServerConnectionPool: if not self.conn: try: self.conn = ServerConnection(self.config, host, port) - self.conn.connect(scheme) + self.conn.connect(scheme, sni) except tcp.NetLibError, v: raise ProxyError(502, v) return self.conn @@ -190,18 +190,18 @@ class ProxyHandler(tcp.BaseHandler): # the case, we want to reconnect without sending an error # to the client. while 1: + sc = self.server_conn_pool.get_connection(scheme, host, port, host) + sc.send(request) + sc.rfile.reset_timestamps() try: - sc = self.server_conn_pool.get_connection(scheme, host, port) - sc.send(request) - sc.rfile.reset_timestamps() httpversion, code, msg, headers, content = http.read_response( sc.rfile, request.method, self.config.body_size_limit ) except http.HttpErrorConnClosed, v: + self.server_conn_pool.del_connection(scheme, host, port) if sc.requestcount > 1: - self.server_conn_pool.del_connection(scheme, host, port) continue else: raise @@ -324,25 +324,6 @@ class ProxyHandler(tcp.BaseHandler): self.rfile.first_byte_timestamp, utils.timestamp() ) - def read_request_reverse(self, client_conn): - line = self.get_line(self.rfile) - if line == "": - return None - scheme, host, port = self.config.reverse_proxy - r = http.parse_init_http(line) - if not r: - raise ProxyError(400, "Bad HTTP request line: %s"%repr(line)) - method, path, httpversion = r - headers = self.read_headers(authenticate=False) - content = http.read_http_body_request( - self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit - ) - return flow.Request( - client_conn, httpversion, host, port, "http", method, path, headers, content, - self.rfile.first_byte_timestamp, utils.timestamp() - ) - - def read_request_proxy(self, client_conn): line = self.get_line(self.rfile) if line == "": @@ -398,6 +379,24 @@ class ProxyHandler(tcp.BaseHandler): self.rfile.first_byte_timestamp, utils.timestamp() ) + def read_request_reverse(self, client_conn): + line = self.get_line(self.rfile) + if line == "": + return None + scheme, host, port = self.config.reverse_proxy + r = http.parse_init_http(line) + if not r: + raise ProxyError(400, "Bad HTTP request line: %s"%repr(line)) + method, path, httpversion = r + headers = self.read_headers(authenticate=False) + content = http.read_http_body_request( + self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit + ) + return flow.Request( + client_conn, httpversion, host, port, "http", method, path, headers, content, + self.rfile.first_byte_timestamp, utils.timestamp() + ) + def read_request(self, client_conn): self.rfile.reset_timestamps() if self.config.transparent_proxy: diff --git a/test/test_proxy.py b/test/test_proxy.py index bdac8697..b575a1d0 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -40,7 +40,7 @@ class TestServerConnection: def test_simple(self): sc = proxy.ServerConnection(proxy.ProxyConfig(), self.d.IFACE, self.d.port) - sc.connect("http") + sc.connect("http", "host.com") r = tutils.treq() r.path = "/p/200:da" sc.send(r) @@ -54,7 +54,7 @@ class TestServerConnection: def test_terminate_error(self): sc = proxy.ServerConnection(proxy.ProxyConfig(), self.d.IFACE, self.d.port) - sc.connect("http") + sc.connect("http", "host.com") sc.connection = mock.Mock() sc.connection.close = mock.Mock(side_effect=IOError) sc.terminate() @@ -75,14 +75,14 @@ class TestServerConnectionPool: @mock.patch("libmproxy.proxy.ServerConnection", _dummysc) def test_pooling(self): p = proxy.ServerConnectionPool(proxy.ProxyConfig()) - c = p.get_connection("http", "localhost", 80) - c2 = p.get_connection("http", "localhost", 80) + c = p.get_connection("http", "localhost", 80, "localhost") + c2 = p.get_connection("http", "localhost", 80, "localhost") assert c is c2 - c3 = p.get_connection("http", "foo", 80) + c3 = p.get_connection("http", "foo", 80, "localhost") assert not c is c3 @mock.patch("libmproxy.proxy.ServerConnection", _errsc) def test_connection_error(self): p = proxy.ServerConnectionPool(proxy.ProxyConfig()) - tutils.raises("502", p.get_connection, "http", "localhost", 80) + tutils.raises("502", p.get_connection, "http", "localhost", 80, "localhost") -- cgit v1.2.3 From 02578151410fff4b3c018303290e2f843e244a89 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 24 Feb 2013 22:24:21 +1300 Subject: Significantly simplify server connection handling, and test. --- libmproxy/proxy.py | 66 ++++++++++++++++++++++++++++------------------------- test/test_proxy.py | 35 ++++------------------------ test/test_server.py | 28 ++++++++++++++++++----- test/tservers.py | 23 ++++++++++--------- 4 files changed, 73 insertions(+), 79 deletions(-) diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index d92e2da9..7c229064 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -51,21 +51,22 @@ class ProxyConfig: class ServerConnection(tcp.TCPClient): - def __init__(self, config, host, port): + def __init__(self, config, scheme, host, port, sni): tcp.TCPClient.__init__(self, host, port) self.config = config + self.scheme, self.sni = scheme, sni self.requestcount = 0 - def connect(self, scheme, sni): + def connect(self): tcp.TCPClient.connect(self) - if scheme == "https": + if self.scheme == "https": clientcert = None if self.config.clientcerts: path = os.path.join(self.config.clientcerts, self.host.encode("idna")) + ".pem" if os.path.exists(path): clientcert = path try: - self.convert_to_ssl(cert=clientcert, sni=sni) + self.convert_to_ssl(cert=clientcert, sni=self.sni) except tcp.NetLibError, v: raise ProxyError(400, str(v)) @@ -94,8 +95,8 @@ class RequestReplayThread(threading.Thread): def run(self): try: r = self.flow.request - server = ServerConnection(self.config, r.host, r.port) - server.connect(r.scheme, r.host) + server = ServerConnection(self.config, r.scheme, r.host, r.port, r.host) + server.connect() server.send(r) httpversion, code, msg, headers, content = http.read_response( server.rfile, r.method, self.config.body_size_limit @@ -109,37 +110,40 @@ class RequestReplayThread(threading.Thread): self.channel.ask(err) -class ServerConnectionPool: - def __init__(self, config): - self.config = config - self.conn = None - - def get_connection(self, scheme, host, port, sni): - sc = self.conn - if self.conn and (host, port) != (sc.host, sc.port): - sc.terminate() - self.conn = None - if not self.conn: - try: - self.conn = ServerConnection(self.config, host, port) - self.conn.connect(scheme, sni) - except tcp.NetLibError, v: - raise ProxyError(502, v) - return self.conn - - def del_connection(self, scheme, host, port): - self.conn = None - - class ProxyHandler(tcp.BaseHandler): def __init__(self, config, connection, client_address, server, channel, server_version): self.channel, self.server_version = channel, server_version self.config = config - self.server_conn_pool = ServerConnectionPool(config) self.proxy_connect_state = None self.sni = None + self.server_conn = None tcp.BaseHandler.__init__(self, connection, client_address, server) + def get_server_connection(self, cc, scheme, host, port, sni): + sc = self.server_conn + if sc and (scheme, host, port, sni) != (sc.scheme, sc.host, sc.port, sc.sni): + sc.terminate() + self.server_conn = None + self.log( + cc, + "switching connection", [ + "%s://%s:%s (sni=%s) -> %s://%s:%s (sni=%s)"%( + scheme, host, port, sni, + sc.scheme, sc.host, sc.port, sc.sni + ) + ] + ) + if not self.server_conn: + try: + self.server_conn = ServerConnection(self.config, scheme, host, port, sni) + self.server_conn.connect() + except tcp.NetLibError, v: + raise ProxyError(502, v) + return self.server_conn + + def del_server_connection(self): + self.server_conn = None + def handle(self): cc = flow.ClientConnect(self.client_address) self.log(cc, "connect") @@ -190,7 +194,7 @@ class ProxyHandler(tcp.BaseHandler): # the case, we want to reconnect without sending an error # to the client. while 1: - sc = self.server_conn_pool.get_connection(scheme, host, port, host) + sc = self.get_server_connection(cc, scheme, host, port, host) sc.send(request) sc.rfile.reset_timestamps() try: @@ -200,7 +204,7 @@ class ProxyHandler(tcp.BaseHandler): self.config.body_size_limit ) except http.HttpErrorConnClosed, v: - self.server_conn_pool.del_connection(scheme, host, port) + self.del_server_connection() if sc.requestcount > 1: continue else: diff --git a/test/test_proxy.py b/test/test_proxy.py index b575a1d0..3995b393 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -39,8 +39,8 @@ class TestServerConnection: self.d.shutdown() def test_simple(self): - sc = proxy.ServerConnection(proxy.ProxyConfig(), self.d.IFACE, self.d.port) - sc.connect("http", "host.com") + sc = proxy.ServerConnection(proxy.ProxyConfig(), "http", self.d.IFACE, self.d.port, "host.com") + sc.connect() r = tutils.treq() r.path = "/p/200:da" sc.send(r) @@ -53,36 +53,9 @@ class TestServerConnection: sc.terminate() def test_terminate_error(self): - sc = proxy.ServerConnection(proxy.ProxyConfig(), self.d.IFACE, self.d.port) - sc.connect("http", "host.com") + sc = proxy.ServerConnection(proxy.ProxyConfig(), "http", self.d.IFACE, self.d.port, "host.com") + sc.connect() sc.connection = mock.Mock() sc.connection.close = mock.Mock(side_effect=IOError) sc.terminate() - - -def _dummysc(config, host, port): - return mock.MagicMock(config=config, host=host, port=port) - - -def _errsc(config, host, port): - m = mock.MagicMock(config=config, host=host, port=port) - m.connect = mock.MagicMock(side_effect=tcp.NetLibError()) - return m - - -class TestServerConnectionPool: - @mock.patch("libmproxy.proxy.ServerConnection", _dummysc) - def test_pooling(self): - p = proxy.ServerConnectionPool(proxy.ProxyConfig()) - c = p.get_connection("http", "localhost", 80, "localhost") - c2 = p.get_connection("http", "localhost", 80, "localhost") - assert c is c2 - c3 = p.get_connection("http", "foo", 80, "localhost") - assert not c is c3 - - @mock.patch("libmproxy.proxy.ServerConnection", _errsc) - def test_connection_error(self): - p = proxy.ServerConnectionPool(proxy.ProxyConfig()) - tutils.raises("502", p.get_connection, "http", "localhost", 80, "localhost") - diff --git a/test/test_server.py b/test/test_server.py index 924b63b7..f93ddbb3 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -85,7 +85,7 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin): def test_connection_close(self): # Add a body, so we have a content-length header, which combined with # HTTP1.1 means the connection is kept alive. - response = '%s/p/200:b@1'%self.urlbase + response = '%s/p/200:b@1'%self.server.urlbase # Lets sanity check that the connection does indeed stay open by # issuing two requests over the same connection @@ -99,7 +99,7 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin): tutils.raises("disconnect", p.request, "get:'%s'"%response) def test_reconnect(self): - req = "get:'%s/p/200:b@1:da'"%self.urlbase + req = "get:'%s/p/200:b@1:da'"%self.server.urlbase p = self.pathoc() assert p.request(req) # Server has disconnected. Mitmproxy should detect this, and reconnect. @@ -107,7 +107,7 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin): assert p.request(req) # However, if the server disconnects on our first try, it's an error. - req = "get:'%s/p/200:b@1:d0'"%self.urlbase + req = "get:'%s/p/200:b@1:d0'"%self.server.urlbase p = self.pathoc() tutils.raises("server disconnect", p.request, req) @@ -118,13 +118,29 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin): m.side_effect = IOError("error!") tutils.raises("empty reply", self.pathod, "304") + def test_get_connection_switching(self): + def switched(l): + for i in l: + if "switching" in i: + return True + req = "get:'%s/p/200:b@1'" + p = self.pathoc() + assert p.request(req%self.server.urlbase) + assert p.request(req%self.server2.urlbase) + assert switched(self.proxy.log) + + def test_get_connection_err(self): + p = self.pathoc() + ret = p.request("get:'http://localhost:0'") + assert ret[1] == 502 + class TestHTTPS(tservers.HTTPProxTest, SanityMixin): ssl = True clientcerts = True def test_clientcert(self): f = self.pathod("304") - assert self.last_log()["request"]["clientcert"]["keyinfo"] + assert self.server.last_log()["request"]["clientcert"]["keyinfo"] class TestReverse(tservers.ReverseProxTest, SanityMixin): @@ -211,7 +227,7 @@ class TestKillRequest(tservers.HTTPProxTest): p = self.pathoc() tutils.raises("empty reply", self.pathod, "200") # Nothing should have hit the server - assert not self.last_log() + assert not self.server.last_log() class MasterKillResponse(tservers.TestMaster): @@ -225,5 +241,5 @@ class TestKillResponse(tservers.HTTPProxTest): p = self.pathoc() tutils.raises("empty reply", self.pathod, "200") # The server should have seen a request - assert self.last_log() + assert self.server.last_log() diff --git a/test/tservers.py b/test/tservers.py index 262536a7..9597dab4 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -28,6 +28,7 @@ class TestMaster(flow.FlowMaster): state = flow.State() flow.FlowMaster.__init__(self, s, state) self.testq = testq + self.log = [] def handle_request(self, m): flow.FlowMaster.handle_request(self, m) @@ -37,6 +38,10 @@ class TestMaster(flow.FlowMaster): flow.FlowMaster.handle_response(self, m) m.reply() + def handle_log(self, l): + self.log.append(l.msg) + l.reply() + class ProxyThread(threading.Thread): def __init__(self, tmaster): @@ -48,6 +53,10 @@ class ProxyThread(threading.Thread): def port(self): return self.tmaster.server.port + @property + def log(self): + return self.tmaster.log + def run(self): self.tmaster.run() @@ -61,6 +70,7 @@ class ProxTestBase: def setupAll(cls): cls.tqueue = Queue.Queue() cls.server = libpathod.test.Daemon(ssl=cls.ssl) + cls.server2 = libpathod.test.Daemon(ssl=cls.ssl) pconf = cls.get_proxy_config() config = proxy.ProxyConfig( cacert = tutils.test_data.path("data/serverkey.pem"), @@ -78,6 +88,7 @@ class ProxTestBase: def teardownAll(cls): cls.proxy.shutdown() cls.server.shutdown() + cls.server2.shutdown() def setUp(self): self.master.state.clear() @@ -95,16 +106,6 @@ class ProxTestBase: (self.scheme, ("127.0.0.1", self.proxy.port)) ) - @property - def urlbase(self): - """ - The URL base for the server instance. - """ - return self.server.urlbase - - def last_log(self): - return self.server.last_log() - class HTTPProxTest(ProxTestBase): ssl = None @@ -129,7 +130,7 @@ class HTTPProxTest(ProxTestBase): Constructs a pathod request, with the appropriate base and proxy. """ return hurl.get( - self.urlbase + "/p/" + spec, + self.server.urlbase + "/p/" + spec, proxy=self.proxies, validate_cert=False, #debug=hurl.utils.stdout_debug -- cgit v1.2.3 From b077189dd5230b6c440a200d867c70c6ce031b66 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 24 Feb 2013 22:52:59 +1300 Subject: Test cert file specification, spruce up server testing truss a bit. --- test/test_server.py | 7 +++++++ test/tservers.py | 31 ++++++++++++++++++------------- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/test/test_server.py b/test/test_server.py index f93ddbb3..034fab41 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -143,6 +143,13 @@ class TestHTTPS(tservers.HTTPProxTest, SanityMixin): assert self.server.last_log()["request"]["clientcert"]["keyinfo"] +class TestHTTPSCertfile(tservers.HTTPProxTest, SanityMixin): + ssl = True + certfile = True + def test_certfile(self): + assert self.pathod("304") + + class TestReverse(tservers.ReverseProxTest, SanityMixin): reverse = True diff --git a/test/tservers.py b/test/tservers.py index 9597dab4..998ad6c6 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -65,6 +65,11 @@ class ProxyThread(threading.Thread): class ProxTestBase: + # Test Configuration + ssl = None + clientcerts = False + certfile = None + masterclass = TestMaster @classmethod def setupAll(cls): @@ -106,17 +111,17 @@ class ProxTestBase: (self.scheme, ("127.0.0.1", self.proxy.port)) ) - -class HTTPProxTest(ProxTestBase): - ssl = None - clientcerts = False @classmethod def get_proxy_config(cls): d = dict() if cls.clientcerts: d["clientcerts"] = tutils.test_data.path("data/clientcert") + if cls.certfile: + d["certfile"] =tutils.test_data.path("data/testkey.pem") return d + +class HTTPProxTest(ProxTestBase): def pathoc(self, connect_to = None): """ Returns a connected Pathoc instance. @@ -149,12 +154,12 @@ class TransparentProxTest(ProxTestBase): ssl = None @classmethod def get_proxy_config(cls): - return dict( - transparent_proxy = dict( - resolver = TResolver(cls.server.port), - sslports = [] - ) - ) + d = ProxTestBase.get_proxy_config() + d["transparent_proxy"] = dict( + resolver = TResolver(cls.server.port), + sslports = [] + ) + return d def pathod(self, spec): """ @@ -172,13 +177,13 @@ class ReverseProxTest(ProxTestBase): ssl = None @classmethod def get_proxy_config(cls): - return dict( - reverse_proxy = ( + d = ProxTestBase.get_proxy_config() + d["reverse_proxy"] = ( "https" if cls.ssl else "http", "127.0.0.1", cls.server.port ) - ) + return d def pathod(self, spec): """ -- cgit v1.2.3