diff options
-rw-r--r-- | mitmproxy/console/__init__.py | 21 | ||||
-rw-r--r-- | mitmproxy/controller.py | 81 | ||||
-rw-r--r-- | mitmproxy/dump.py | 3 | ||||
-rw-r--r-- | mitmproxy/flow.py | 31 | ||||
-rw-r--r-- | mitmproxy/proxy/root_context.py | 1 | ||||
-rw-r--r-- | mitmproxy/web/__init__.py | 8 | ||||
-rw-r--r-- | netlib/tcp.py | 2 | ||||
-rw-r--r-- | test/mitmproxy/test_controller.py | 31 | ||||
-rw-r--r-- | test/mitmproxy/test_flow.py | 14 | ||||
-rw-r--r-- | test/mitmproxy/test_server.py | 36 | ||||
-rw-r--r-- | test/mitmproxy/tservers.py | 4 |
11 files changed, 154 insertions, 78 deletions
diff --git a/mitmproxy/console/__init__.py b/mitmproxy/console/__init__.py index 1dd032be..9ce02e72 100644 --- a/mitmproxy/console/__init__.py +++ b/mitmproxy/console/__init__.py @@ -16,7 +16,7 @@ import weakref from netlib import tcp -from .. import flow, script, contentviews +from .. import flow, script, contentviews, controller from . import flowlist, flowview, help, window, signals, options from . import grideditor, palettes, statusbar, palettepicker from ..exceptions import FlowReadException, ScriptException @@ -713,14 +713,15 @@ class ConsoleMaster(flow.FlowMaster): ) def process_flow(self, f): - if self.state.intercept and f.match(self.state.intercept) and not f.request.is_replay: + should_intercept = any( + [ + self.state.intercept and f.match(self.state.intercept) and not f.request.is_replay, + f.intercepted, + ] + ) + if should_intercept: f.intercept(self) - else: - # check if flow was intercepted within an inline script by flow.intercept() - if f.intercepted: - f.intercept(self) - else: - f.reply() + f.reply.take() signals.flowlist_change.send(self) signals.flow_change.send(self, flow = f) @@ -728,24 +729,28 @@ class ConsoleMaster(flow.FlowMaster): self.eventlist[:] = [] # Handlers + @controller.handler def handle_error(self, f): f = flow.FlowMaster.handle_error(self, f) if f: self.process_flow(f) return f + @controller.handler def handle_request(self, f): f = flow.FlowMaster.handle_request(self, f) if f: self.process_flow(f) return f + @controller.handler def handle_response(self, f): f = flow.FlowMaster.handle_response(self, f) if f: self.process_flow(f) return f + @controller.handler def handle_script_change(self, script): if super(ConsoleMaster, self).handle_script_change(script): signals.status_message.send(message='"{}" reloaded.'.format(script.filename)) diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py index af8a77bd..c43fbb84 100644 --- a/mitmproxy/controller.py +++ b/mitmproxy/controller.py @@ -1,8 +1,14 @@ from __future__ import absolute_import from six.moves import queue import threading +import functools +import sys -from .exceptions import Kill +from . import exceptions + + +class ControlError(Exception): + pass class Master(object): @@ -37,6 +43,12 @@ class Master(object): while True: mtype, obj = self.event_queue.get(timeout=timeout) handle_func = getattr(self, "handle_" + mtype) + if not handle_func.func_dict.get("handler"): + raise ControlError( + "Handler function %s is not decorated with controller.handler"%( + handle_func + ) + ) handle_func(obj) self.event_queue.task_done() changed = True @@ -104,7 +116,7 @@ class Channel(object): master. Then wait for a response. Raises: - Kill: All connections should be closed immediately. + exceptions.Kill: All connections should be closed immediately. """ m.reply = Reply(m) self.q.put((mtype, m)) @@ -114,11 +126,10 @@ class Channel(object): g = m.reply.q.get(timeout=0.5) except queue.Empty: # pragma: no cover continue - if g == Kill: - raise Kill() + if g == exceptions.Kill: + raise exceptions.Kill() return g - - raise Kill() + raise exceptions.Kill() def tell(self, mtype, m): """ @@ -138,6 +149,11 @@ class DummyReply(object): def __init__(self): self.acked = False + self.taken = False + self.handled = False + + def take(self): + self.taken = True def __call__(self, msg=False): self.acked = True @@ -147,6 +163,34 @@ class DummyReply(object): NO_REPLY = object() +def handler(f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + if len(args) == 1: + message = args[0] + elif len(args) == 2: + message = args[1] + else: + raise ControlError("Handler takes one argument: a message") + + if not hasattr(message, "reply"): + raise ControlError("Message %s has no reply attribute"%message) + + handling = False + # We're the first handler - ack responsibility is ours + if not message.reply.handled: + handling = True + message.reply.handled = True + + ret = f(*args, **kwargs) + + if handling and not message.reply.acked and not message.reply.taken: + message.reply() + return ret + wrapper.func_dict["handler"] = True + return wrapper + + class Reply(object): """ @@ -154,16 +198,29 @@ class Reply(object): This object is used to respond to the message through the return channel. """ - def __init__(self, obj): self.obj = obj self.q = queue.Queue() + # Has this message been acked? self.acked = False + # Has the user taken responsibility for ack-ing? + self.taken = False + # Has a handler taken responsibility for ack-ing? + self.handled = False + + def take(self): + self.taken = True def __call__(self, msg=NO_REPLY): + if self.acked: + raise ControlError("Message already acked.") + self.acked = True + if msg is NO_REPLY: + self.q.put(self.obj) + else: + self.q.put(msg) + + def __del__(self): if not self.acked: - self.acked = True - if msg is NO_REPLY: - self.q.put(self.obj) - else: - self.q.put(msg) + # This will be ignored by the interpreter, but emit a warning + raise ControlError("Un-acked message") diff --git a/mitmproxy/dump.py b/mitmproxy/dump.py index 8f9488be..cbf4b3da 100644 --- a/mitmproxy/dump.py +++ b/mitmproxy/dump.py @@ -6,10 +6,11 @@ import itertools from netlib import tcp from netlib.utils import bytes_to_escaped_str, pretty_size -from . import flow, filt, contentviews +from . import flow, filt, contentviews, controller from .exceptions import ContentViewException, FlowReadException, ScriptException + class DumpError(Exception): pass diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py index d70ec2d9..1b4a999a 100644 --- a/mitmproxy/flow.py +++ b/mitmproxy/flow.py @@ -546,7 +546,8 @@ class FlowStore(FlowList): def kill_all(self, master): for f in self._list: - f.kill(master) + if not f.reply.acked: + f.kill(master) class State(object): @@ -985,38 +986,39 @@ class FlowMaster(controller.ServerMaster): if block: rt.join() + @controller.handler def handle_log(self, l): self.add_event(l.msg, l.level) - l.reply() + @controller.handler def handle_clientconnect(self, root_layer): self.run_script_hook("clientconnect", root_layer) - root_layer.reply() + @controller.handler def handle_clientdisconnect(self, root_layer): self.run_script_hook("clientdisconnect", root_layer) - root_layer.reply() + @controller.handler def handle_serverconnect(self, server_conn): self.run_script_hook("serverconnect", server_conn) - server_conn.reply() + @controller.handler def handle_serverdisconnect(self, server_conn): self.run_script_hook("serverdisconnect", server_conn) - server_conn.reply() + @controller.handler def handle_next_layer(self, top_layer): self.run_script_hook("next_layer", top_layer) - top_layer.reply() + @controller.handler def handle_error(self, f): self.state.update_flow(f) self.run_script_hook("error", f) if self.client_playback: self.client_playback.clear(f) - f.reply() return f + @controller.handler def handle_request(self, f): if f.live: app = self.apps.get(f.request) @@ -1039,6 +1041,7 @@ class FlowMaster(controller.ServerMaster): self.run_script_hook("request", f) return f + @controller.handler def handle_responseheaders(self, f): try: if self.stream_large_bodies: @@ -1046,12 +1049,10 @@ class FlowMaster(controller.ServerMaster): except HttpException: f.reply(Kill) return - self.run_script_hook("responseheaders", f) - - f.reply() return f + @controller.handler def handle_response(self, f): self.active_flows.discard(f) self.state.update_flow(f) @@ -1099,13 +1100,14 @@ class FlowMaster(controller.ServerMaster): self.add_event('"{}" reloaded.'.format(s.filename), 'info') return ok + @controller.handler def handle_tcp_open(self, flow): # TODO: This would break mitmproxy currently. # self.state.add_flow(flow) self.active_flows.add(flow) self.run_script_hook("tcp_open", flow) - flow.reply() + @controller.handler def handle_tcp_message(self, flow): self.run_script_hook("tcp_message", flow) message = flow.messages[-1] @@ -1116,22 +1118,21 @@ class FlowMaster(controller.ServerMaster): direction=direction, ), "info") self.add_event(clean_bin(message.content), "debug") - flow.reply() + @controller.handler def handle_tcp_error(self, flow): self.add_event("Error in TCP connection to {}: {}".format( repr(flow.server_conn.address), flow.error ), "info") self.run_script_hook("tcp_error", flow) - flow.reply() + @controller.handler def handle_tcp_close(self, flow): self.active_flows.discard(flow) if self.stream: self.stream.add(flow) self.run_script_hook("tcp_close", flow) - flow.reply() def shutdown(self): super(FlowMaster, self).shutdown() diff --git a/mitmproxy/proxy/root_context.py b/mitmproxy/proxy/root_context.py index 96e7aab6..9b4e2963 100644 --- a/mitmproxy/proxy/root_context.py +++ b/mitmproxy/proxy/root_context.py @@ -132,7 +132,6 @@ class RootContext(object): class Log(object): - def __init__(self, msg, level="info"): self.msg = msg self.level = level diff --git a/mitmproxy/web/__init__.py b/mitmproxy/web/__init__.py index 956d221d..7fef6df1 100644 --- a/mitmproxy/web/__init__.py +++ b/mitmproxy/web/__init__.py @@ -6,7 +6,7 @@ import sys from netlib.http import authentication -from .. import flow +from .. import flow, controller from ..exceptions import FlowReadException from . import app @@ -194,17 +194,19 @@ class WebMaster(flow.FlowMaster): if self.state.intercept and self.state.intercept( f) and not f.request.is_replay: f.intercept(self) - else: - f.reply() + f.reply.take() + @controller.handler def handle_request(self, f): super(WebMaster, self).handle_request(f) self._process_flow(f) + @controller.handler def handle_response(self, f): super(WebMaster, self).handle_response(f) self._process_flow(f) + @controller.handler def handle_error(self, f): super(WebMaster, self).handle_error(f) self._process_flow(f) diff --git a/netlib/tcp.py b/netlib/tcp.py index ad75cff8..c7231dbb 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -901,7 +901,7 @@ class TCPServer(object): """ # If a thread has persisted after interpreter exit, the module might be # none. - if traceback: + if traceback and six: exc = six.text_type(traceback.format_exc()) print(u'-' * 40, file=fp) print( diff --git a/test/mitmproxy/test_controller.py b/test/mitmproxy/test_controller.py index f7bf615a..c9a8e2f4 100644 --- a/test/mitmproxy/test_controller.py +++ b/test/mitmproxy/test_controller.py @@ -2,7 +2,7 @@ from threading import Thread, Event from mock import Mock -from mitmproxy.controller import Reply, DummyReply, Channel, ServerThread, ServerMaster, Master +from mitmproxy import controller from six.moves import queue from mitmproxy.exceptions import Kill @@ -10,10 +10,15 @@ from mitmproxy.proxy import DummyServer from netlib.tutils import raises +class TMsg: + pass + + class TestMaster(object): def test_simple(self): - class DummyMaster(Master): + class DummyMaster(controller.Master): + @controller.handler def handle_panic(self, _): m.should_exit.set() @@ -23,14 +28,16 @@ class TestMaster(object): m = DummyMaster() assert not m.should_exit.is_set() - m.event_queue.put(("panic", 42)) + msg = TMsg() + msg.reply = controller.DummyReply() + m.event_queue.put(("panic", msg)) m.run() assert m.should_exit.is_set() class TestServerMaster(object): def test_simple(self): - m = ServerMaster() + m = controller.ServerMaster() s = DummyServer(None) m.add_server(s) m.start() @@ -42,7 +49,7 @@ class TestServerMaster(object): class TestServerThread(object): def test_simple(self): m = Mock() - t = ServerThread(m) + t = controller.ServerThread(m) t.run() assert m.serve_forever.called @@ -50,7 +57,7 @@ class TestServerThread(object): class TestChannel(object): def test_tell(self): q = queue.Queue() - channel = Channel(q, Event()) + channel = controller.Channel(q, Event()) m = Mock() channel.tell("test", m) assert q.get() == ("test", m) @@ -66,21 +73,21 @@ class TestChannel(object): Thread(target=reply).start() - channel = Channel(q, Event()) + channel = controller.Channel(q, Event()) assert channel.ask("test", Mock()) == 42 def test_ask_shutdown(self): q = queue.Queue() done = Event() done.set() - channel = Channel(q, done) + channel = controller.Channel(q, done) with raises(Kill): channel.ask("test", Mock()) class TestDummyReply(object): def test_simple(self): - reply = DummyReply() + reply = controller.DummyReply() assert not reply.acked reply() assert reply.acked @@ -88,18 +95,18 @@ class TestDummyReply(object): class TestReply(object): def test_simple(self): - reply = Reply(42) + reply = controller.Reply(42) assert not reply.acked reply("foo") assert reply.acked assert reply.q.get() == "foo" def test_default(self): - reply = Reply(42) + reply = controller.Reply(42) reply() assert reply.q.get() == 42 def test_reply_none(self): - reply = Reply(42) + reply = controller.Reply(42) reply(None) assert reply.q.get() is None diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index 5441ea59..f8338dcb 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -460,25 +460,20 @@ class TestFlow(object): fm = flow.FlowMaster(None, s) f = tutils.tflow() f.intercept(mock.Mock()) - assert not f.reply.acked f.kill(fm) - assert f.reply.acked + for i in s.view: + assert "killed" in str(i.error) def test_killall(self): s = flow.State() fm = flow.FlowMaster(None, s) f = tutils.tflow() - fm.handle_request(f) - - f = tutils.tflow() - fm.handle_request(f) + f.intercept(fm) - for i in s.view: - assert not i.reply.acked s.killall(fm) for i in s.view: - assert i.reply.acked + assert "killed" in str(i.error) def test_accept_intercept(self): f = tutils.tflow() @@ -865,7 +860,6 @@ class TestFlowMaster: f.response = HTTPResponse.wrap(netlib.tutils.tresp()) fm.handle_response(f) - assert not fm.handle_response(None) assert s.flow_count() == 1 fm.handle_clientdisconnect(f.client_conn) diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py index 0701d52b..f4e7452f 100644 --- a/test/mitmproxy/test_server.py +++ b/test/mitmproxy/test_server.py @@ -12,6 +12,7 @@ from netlib.http import authentication, http1 from netlib.tutils import raises from pathod import pathoc, pathod +from mitmproxy import controller from mitmproxy.proxy.config import HostMatcher from mitmproxy.exceptions import Kill from mitmproxy.models import Error, HTTPResponse, HTTPFlow @@ -623,6 +624,7 @@ class TestProxySSL(tservers.HTTPProxyTest): class MasterRedirectRequest(tservers.TestMaster): redirect_port = None # Set by TestRedirectRequest + @controller.handler def handle_request(self, f): if f.request.path == "/p/201": @@ -636,6 +638,7 @@ class MasterRedirectRequest(tservers.TestMaster): f.request.port = self.redirect_port super(MasterRedirectRequest, self).handle_request(f) + @controller.handler def handle_response(self, f): f.response.content = str(f.client_conn.address.port) f.response.headers["server-conn-id"] = str(f.server_conn.source_address.port) @@ -689,10 +692,9 @@ class MasterStreamRequest(tservers.TestMaster): """ Enables the stream flag on the flow for all requests """ - + @controller.handler def handle_responseheaders(self, f): f.response.stream = True - f.reply() class TestStreamRequest(tservers.HTTPProxyTest): @@ -739,7 +741,7 @@ class TestStreamRequest(tservers.HTTPProxyTest): class MasterFakeResponse(tservers.TestMaster): - + @controller.handler def handle_request(self, f): resp = HTTPResponse.wrap(netlib.tutils.tresp()) f.reply(resp) @@ -767,6 +769,7 @@ class TestServerConnect(tservers.HTTPProxyTest): class MasterKillRequest(tservers.TestMaster): + @controller.handler def handle_request(self, f): f.reply(Kill) @@ -783,6 +786,7 @@ class TestKillRequest(tservers.HTTPProxyTest): class MasterKillResponse(tservers.TestMaster): + @controller.handler def handle_response(self, f): f.reply(Kill) @@ -812,6 +816,7 @@ class TestTransparentResolveError(tservers.TransparentProxyTest): class MasterIncomplete(tservers.TestMaster): + @controller.handler def handle_request(self, f): resp = HTTPResponse.wrap(netlib.tutils.tresp()) resp.content = None @@ -930,7 +935,9 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxyTest): k = [0] # variable scope workaround: put into array _func = getattr(master, attr) - def handler(f): + @controller.handler + def handler(*args): + f = args[-1] k[0] += 1 if not (k[0] in exclude): f.client_conn.finish() @@ -940,11 +947,14 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxyTest): setattr(master, attr, handler) - kill_requests(self.chain[1].tmaster, "handle_request", - exclude=[ - # fail first request - 2, # allow second request - ]) + kill_requests( + self.chain[1].tmaster, + "handle_request", + exclude = [ + # fail first request + 2, # allow second request + ] + ) kill_requests(self.chain[0].tmaster, "handle_request", exclude=[ @@ -1004,10 +1014,10 @@ class AddUpstreamCertsToClientChainMixin: ssl = True servercert = tutils.test_data.path("data/trusted-server.crt") ssloptions = pathod.SSLOptions( - cn="trusted-cert", - certs=[ - ("trusted-cert", servercert) - ] + cn="trusted-cert", + certs=[ + ("trusted-cert", servercert) + ] ) def test_add_upstream_certs_to_client_chain(self): diff --git a/test/mitmproxy/tservers.py b/test/mitmproxy/tservers.py index c9d68cfd..51f4109d 100644 --- a/test/mitmproxy/tservers.py +++ b/test/mitmproxy/tservers.py @@ -39,13 +39,13 @@ class TestMaster(flow.FlowMaster): self.apps.add(errapp, "errapp", 80) self.clear_log() + @controller.handler def handle_request(self, f): flow.FlowMaster.handle_request(self, f) - f.reply() + @controller.handler def handle_response(self, f): flow.FlowMaster.handle_response(self, f) - f.reply() def clear_log(self): self.log = [] |