diff options
author | Aldo Cortesi <aldo@corte.si> | 2016-06-08 14:09:59 +1200 |
---|---|---|
committer | Aldo Cortesi <aldo@corte.si> | 2016-06-08 14:09:59 +1200 |
commit | e93fe9d4fa37ec1aae60eee612be7a9cd989891c (patch) | |
tree | b6d29eb42c7989e5b44fb9221a358fecfe672ad4 | |
parent | db11fe0087776c2bf5d95f5aeb751c6c35d67f4b (diff) | |
parent | a5cb241c7cb1035b4d9ff43fb1c8958b7b3dac1d (diff) | |
download | mitmproxy-e93fe9d4fa37ec1aae60eee612be7a9cd989891c.tar.gz mitmproxy-e93fe9d4fa37ec1aae60eee612be7a9cd989891c.tar.bz2 mitmproxy-e93fe9d4fa37ec1aae60eee612be7a9cd989891c.zip |
Merge pull request #1228 from cortesi/controller2
Controller refactoring
-rw-r--r-- | examples/nonblocking.py | 4 | ||||
-rw-r--r-- | examples/redirect_requests.py | 2 | ||||
-rw-r--r-- | examples/tls_passthrough.py | 2 | ||||
-rw-r--r-- | mitmproxy/controller.py | 59 | ||||
-rw-r--r-- | mitmproxy/flow/master.py | 56 | ||||
-rw-r--r-- | mitmproxy/models/flow.py | 5 | ||||
-rw-r--r-- | mitmproxy/script/concurrent.py | 65 | ||||
-rw-r--r-- | test/mitmproxy/mastertest.py | 2 | ||||
-rw-r--r-- | test/mitmproxy/script/test_concurrent.py | 26 | ||||
-rw-r--r-- | test/mitmproxy/test_controller.py | 10 | ||||
-rw-r--r-- | test/mitmproxy/test_flow.py | 8 | ||||
-rw-r--r-- | test/mitmproxy/test_server.py | 11 |
12 files changed, 115 insertions, 135 deletions
diff --git a/examples/nonblocking.py b/examples/nonblocking.py index 41674b2a..4609f389 100644 --- a/examples/nonblocking.py +++ b/examples/nonblocking.py @@ -4,6 +4,6 @@ from mitmproxy.script import concurrent @concurrent # Remove this and see what happens def request(context, flow): - print("handle request: %s%s" % (flow.request.host, flow.request.path)) + context.log("handle request: %s%s" % (flow.request.host, flow.request.path)) time.sleep(5) - print("start request: %s%s" % (flow.request.host, flow.request.path)) + context.log("start request: %s%s" % (flow.request.host, flow.request.path)) diff --git a/examples/redirect_requests.py b/examples/redirect_requests.py index 3ff8f9e4..d7db3f1c 100644 --- a/examples/redirect_requests.py +++ b/examples/redirect_requests.py @@ -16,7 +16,7 @@ def request(context, flow): "HTTP/1.1", 200, "OK", Headers(Content_Type="text/html"), "helloworld") - flow.reply(resp) + flow.reply.send(resp) # Method 2: Redirect the request to a different server if flow.request.pretty_host.endswith("example.org"): diff --git a/examples/tls_passthrough.py b/examples/tls_passthrough.py index 23afe3ff..0c6d450d 100644 --- a/examples/tls_passthrough.py +++ b/examples/tls_passthrough.py @@ -134,5 +134,5 @@ def next_layer(context, next_layer): # We don't intercept - reply with a pass-through layer and add a "skipped" entry. context.log("TLS passthrough for %s" % repr(next_layer.server_conn.address), "info") next_layer_replacement = RawTCPLayer(next_layer.ctx, logging=False) - next_layer.reply(next_layer_replacement) + next_layer.reply.send(next_layer_replacement) context.tls_strategy.record_skipped(server_address) diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py index 1498c3ad..084702a6 100644 --- a/mitmproxy/controller.py +++ b/mitmproxy/controller.py @@ -145,27 +145,6 @@ class Channel(object): self.q.put((mtype, m)) -class DummyReply(object): - """ - A reply object that does nothing. Useful when we need an object to seem - like it has a channel, and during testing. - """ - def __init__(self): - self.acked = False - self.taken = False - self.handled = False - - def take(self): - self.taken = True - - def __call__(self, msg=False): - self.acked = True - - -# Special value to distinguish the case where no reply was sent -NO_REPLY = object() - - def handler(f): @functools.wraps(f) def wrapper(*args, **kwargs): @@ -192,7 +171,7 @@ def handler(f): ret = f(*args, **kwargs) if handling and not message.reply.acked and not message.reply.taken: - message.reply() + message.reply.ack() return ret # Mark this function as a handler wrapper wrapper.func_dict["__handler"] = True @@ -215,19 +194,45 @@ class Reply(object): # Has a handler taken responsibility for ack-ing? self.handled = False + def ack(self): + self.send(self.obj) + + def kill(self): + self.send(exceptions.Kill) + def take(self): self.taken = True - def __call__(self, msg=NO_REPLY): + def send(self, msg): if self.acked: raise exceptions.ControlException("Message already acked.") self.acked = True - if msg is NO_REPLY: - self.q.put(self.obj) - else: - self.q.put(msg) + self.q.put(msg) def __del__(self): if not self.acked: # This will be ignored by the interpreter, but emit a warning raise exceptions.ControlException("Un-acked message") + + +class DummyReply(object): + """ + A reply object that does nothing. Useful when we need an object to seem + like it has a channel, and during testing. + """ + def __init__(self): + self.acked = False + self.taken = False + self.handled = False + + def kill(self): + self.send(None) + + def ack(self): + self.send(None) + + def take(self): + self.taken = True + + def send(self, msg): + self.acked = True diff --git a/mitmproxy/flow/master.py b/mitmproxy/flow/master.py index ec0bf36d..289102a1 100644 --- a/mitmproxy/flow/master.py +++ b/mitmproxy/flow/master.py @@ -103,9 +103,10 @@ class FlowMaster(controller.Master): except script.ScriptException as e: self.add_event("Script error:\n{}".format(e), "error") - def run_script_hook(self, name, *args, **kwargs): + def run_scripts(self, name, msg): for script_obj in self.scripts: - self._run_single_script_hook(script_obj, name, *args, **kwargs) + if not msg.reply.acked: + self._run_single_script_hook(script_obj, name, msg) def get_ignore_filter(self): return self.server.config.check_ignore.patterns @@ -373,28 +374,28 @@ class FlowMaster(controller.Master): @controller.handler def clientconnect(self, root_layer): - self.run_script_hook("clientconnect", root_layer) + self.run_scripts("clientconnect", root_layer) @controller.handler def clientdisconnect(self, root_layer): - self.run_script_hook("clientdisconnect", root_layer) + self.run_scripts("clientdisconnect", root_layer) @controller.handler def serverconnect(self, server_conn): - self.run_script_hook("serverconnect", server_conn) + self.run_scripts("serverconnect", server_conn) @controller.handler def serverdisconnect(self, server_conn): - self.run_script_hook("serverdisconnect", server_conn) + self.run_scripts("serverdisconnect", server_conn) @controller.handler def next_layer(self, top_layer): - self.run_script_hook("next_layer", top_layer) + self.run_scripts("next_layer", top_layer) @controller.handler def error(self, f): self.state.update_flow(f) - self.run_script_hook("error", f) + self.run_scripts("error", f) if self.client_playback: self.client_playback.clear(f) return f @@ -411,15 +412,19 @@ class FlowMaster(controller.Master): ) if err: self.add_event("Error in wsgi app. %s" % err, "error") - f.reply(exceptions.Kill) + f.reply.kill() return if f not in self.state.flows: # don't add again on replay self.state.add_flow(f) self.active_flows.add(f) - self.replacehooks.run(f) - self.setheaders.run(f) - self.process_new_request(f) - self.run_script_hook("request", f) + if not f.reply.acked: + self.replacehooks.run(f) + if not f.reply.acked: + self.setheaders.run(f) + if not f.reply.acked: + self.process_new_request(f) + if not f.reply.acked: + self.run_scripts("request", f) return f @controller.handler @@ -428,20 +433,23 @@ class FlowMaster(controller.Master): if self.stream_large_bodies: self.stream_large_bodies.run(f, False) except netlib.exceptions.HttpException: - f.reply(exceptions.Kill) + f.reply.kill() return - self.run_script_hook("responseheaders", f) + self.run_scripts("responseheaders", f) return f @controller.handler def response(self, f): self.active_flows.discard(f) self.state.update_flow(f) - self.replacehooks.run(f) - self.setheaders.run(f) - self.run_script_hook("response", f) - if self.client_playback: - self.client_playback.clear(f) + if not f.reply.acked: + self.replacehooks.run(f) + if not f.reply.acked: + self.setheaders.run(f) + self.run_scripts("response", f) + if not f.reply.acked: + if self.client_playback: + self.client_playback.clear(f) self.process_new_response(f) if self.stream: self.stream.add(f) @@ -487,11 +495,11 @@ class FlowMaster(controller.Master): # TODO: This would break mitmproxy currently. # self.state.add_flow(flow) self.active_flows.add(flow) - self.run_script_hook("tcp_open", flow) + self.run_scripts("tcp_open", flow) @controller.handler def tcp_message(self, flow): - self.run_script_hook("tcp_message", flow) + self.run_scripts("tcp_message", flow) message = flow.messages[-1] direction = "->" if message.from_client else "<-" self.add_event("{client} {direction} tcp {direction} {server}".format( @@ -507,14 +515,14 @@ class FlowMaster(controller.Master): repr(flow.server_conn.address), flow.error ), "info") - self.run_script_hook("tcp_error", flow) + self.run_scripts("tcp_error", flow) @controller.handler def tcp_close(self, flow): self.active_flows.discard(flow) if self.stream: self.stream.add(flow) - self.run_script_hook("tcp_close", flow) + self.run_scripts("tcp_close", flow) def shutdown(self): super(FlowMaster, self).shutdown() diff --git a/mitmproxy/models/flow.py b/mitmproxy/models/flow.py index e2dac221..de86e451 100644 --- a/mitmproxy/models/flow.py +++ b/mitmproxy/models/flow.py @@ -4,7 +4,6 @@ import time import copy import uuid -from mitmproxy import exceptions from mitmproxy import stateobject from mitmproxy import version from mitmproxy.models.connections import ClientConnection @@ -155,7 +154,7 @@ class Flow(stateobject.StateObject): """ self.error = Error("Connection killed") self.intercepted = False - self.reply(exceptions.Kill) + self.reply.kill() master.error(self) def intercept(self, master): @@ -175,5 +174,5 @@ class Flow(stateobject.StateObject): if not self.intercepted: return self.intercepted = False - self.reply() + self.reply.ack() master.handle_accept_intercept(self) diff --git a/mitmproxy/script/concurrent.py b/mitmproxy/script/concurrent.py index 43d0d328..89c835f6 100644 --- a/mitmproxy/script/concurrent.py +++ b/mitmproxy/script/concurrent.py @@ -4,62 +4,25 @@ offload computations from mitmproxy's main master thread. """ from __future__ import absolute_import, print_function, division +from mitmproxy import controller import threading -class ReplyProxy(object): - - def __init__(self, reply_func, script_thread): - self.reply_func = reply_func - self.script_thread = script_thread - self.master_reply = None - - def __call__(self, *args): - if self.master_reply is None: - self.master_reply = args - self.script_thread.start() - return - self.reply_func(*args) - - def done(self): - self.reply_func(*self.master_reply) - - def __getattr__(self, k): - return getattr(self.reply_func, k) - - -def _handle_concurrent_reply(fn, o, *args, **kwargs): - # Make first call to o.reply a no op and start the script thread. - # We must not start the script thread before, as this may lead to a nasty race condition - # where the script thread replies a different response before the normal reply, which then gets swallowed. - - def run(): - fn(*args, **kwargs) - # If the script did not call .reply(), we have to do it now. - reply_proxy.done() - - script_thread = ScriptThread(target=run) - - reply_proxy = ReplyProxy(o.reply, script_thread) - o.reply = reply_proxy - - class ScriptThread(threading.Thread): name = "ScriptThread" def concurrent(fn): - if fn.__name__ in ( - "request", - "response", - "error", - "clientconnect", - "serverconnect", - "clientdisconnect", - "next_layer"): - def _concurrent(ctx, obj): - _handle_concurrent_reply(fn, obj, ctx, obj) - - return _concurrent - raise NotImplementedError( - "Concurrent decorator not supported for '%s' method." % fn.__name__) + if fn.__name__ not in controller.Events: + raise NotImplementedError( + "Concurrent decorator not supported for '%s' method." % fn.__name__ + ) + + def _concurrent(ctx, obj): + def run(): + fn(ctx, obj) + if not obj.reply.acked: + obj.reply.ack() + obj.reply.take() + ScriptThread(target=run).start() + return _concurrent diff --git a/test/mitmproxy/mastertest.py b/test/mitmproxy/mastertest.py index 9bb8826d..4d04f337 100644 --- a/test/mitmproxy/mastertest.py +++ b/test/mitmproxy/mastertest.py @@ -16,7 +16,9 @@ class MasterTest: master.request(f) if not f.error: f.response = models.HTTPResponse.wrap(netlib.tutils.tresp(content=content)) + f.reply.acked = False f = master.response(f) + f.client_conn.reply.acked = False master.clientdisconnect(f.client_conn) return f diff --git a/test/mitmproxy/script/test_concurrent.py b/test/mitmproxy/script/test_concurrent.py index c2f169ad..62541f3f 100644 --- a/test/mitmproxy/script/test_concurrent.py +++ b/test/mitmproxy/script/test_concurrent.py @@ -1,29 +1,25 @@ -from threading import Event - from mitmproxy.script import Script from test.mitmproxy import tutils +from mitmproxy import controller +import time -class Dummy: - def __init__(self, reply): - self.reply = reply +class Thing: + def __init__(self): + self.reply = controller.DummyReply() @tutils.skip_appveyor def test_concurrent(): with Script(tutils.test_data.path("data/scripts/concurrent_decorator.py"), None) as s: - def reply(): - reply.acked.set() - reply.acked = Event() - - f1, f2 = Dummy(reply), Dummy(reply) + f1, f2 = Thing(), Thing() s.run("request", f1) - f1.reply() s.run("request", f2) - f2.reply() - assert f1.reply.acked == reply.acked - assert not reply.acked.is_set() - assert reply.acked.wait(10) + start = time.time() + while time.time() - start < 5: + if f1.reply.acked and f2.reply.acked: + return + raise ValueError("Script never acked") def test_concurrent_err(): diff --git a/test/mitmproxy/test_controller.py b/test/mitmproxy/test_controller.py index 83ad428e..5a68e15b 100644 --- a/test/mitmproxy/test_controller.py +++ b/test/mitmproxy/test_controller.py @@ -66,7 +66,7 @@ class TestChannel(object): def reply(): m, obj = q.get() assert m == "test" - obj.reply(42) + obj.reply.send(42) Thread(target=reply).start() @@ -86,7 +86,7 @@ class TestDummyReply(object): def test_simple(self): reply = controller.DummyReply() assert not reply.acked - reply() + reply.ack() assert reply.acked @@ -94,16 +94,16 @@ class TestReply(object): def test_simple(self): reply = controller.Reply(42) assert not reply.acked - reply("foo") + reply.send("foo") assert reply.acked assert reply.q.get() == "foo" def test_default(self): reply = controller.Reply(42) - reply() + reply.ack() assert reply.q.get() == 42 def test_reply_none(self): reply = controller.Reply(42) - reply(None) + reply.send(None) assert reply.q.get() is None diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index 1b1f03f9..af8256c4 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -807,17 +807,22 @@ class TestFlowMaster: fm.load_script(tutils.test_data.path("data/scripts/all.py")) f = tutils.tflow(resp=True) + f.client_conn.acked = False fm.clientconnect(f.client_conn) assert fm.scripts[0].ns["log"][-1] == "clientconnect" + f.server_conn.acked = False fm.serverconnect(f.server_conn) assert fm.scripts[0].ns["log"][-1] == "serverconnect" + f.reply.acked = False fm.request(f) assert fm.scripts[0].ns["log"][-1] == "request" + f.reply.acked = False fm.response(f) assert fm.scripts[0].ns["log"][-1] == "response" # load second script fm.load_script(tutils.test_data.path("data/scripts/all.py")) assert len(fm.scripts) == 2 + f.server_conn.reply.acked = False fm.clientdisconnect(f.server_conn) assert fm.scripts[0].ns["log"][-1] == "clientdisconnect" assert fm.scripts[1].ns["log"][-1] == "clientdisconnect" @@ -828,6 +833,7 @@ class TestFlowMaster: fm.load_script(tutils.test_data.path("data/scripts/all.py")) f.error = tutils.terr() + f.reply.acked = False fm.error(f) assert fm.scripts[0].ns["log"][-1] == "error" @@ -977,10 +983,12 @@ class TestFlowMaster: f = tutils.tflow(resp=True) f.response.headers["set-cookie"] = "foo=bar" fm.request(f) + f.reply.acked = False fm.response(f) assert fm.stickycookie_state.jar assert "cookie" not in f.request.headers f = f.copy() + f.reply.acked = False fm.request(f) assert f.request.headers["cookie"] == "foo=bar" diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py index b58c4f44..432340c0 100644 --- a/test/mitmproxy/test_server.py +++ b/test/mitmproxy/test_server.py @@ -14,7 +14,6 @@ from pathod import pathoc, pathod from mitmproxy import controller from mitmproxy.proxy.config import HostMatcher -from mitmproxy.exceptions import Kill from mitmproxy.models import Error, HTTPResponse, HTTPFlow from . import tutils, tservers @@ -744,7 +743,7 @@ class MasterFakeResponse(tservers.TestMaster): @controller.handler def request(self, f): resp = HTTPResponse.wrap(netlib.tutils.tresp()) - f.reply(resp) + f.reply.send(resp) class TestFakeResponse(tservers.HTTPProxyTest): @@ -771,7 +770,7 @@ class MasterKillRequest(tservers.TestMaster): @controller.handler def request(self, f): - f.reply(Kill) + f.reply.kill() class TestKillRequest(tservers.HTTPProxyTest): @@ -788,7 +787,7 @@ class MasterKillResponse(tservers.TestMaster): @controller.handler def response(self, f): - f.reply(Kill) + f.reply.kill() class TestKillResponse(tservers.HTTPProxyTest): @@ -820,7 +819,7 @@ class MasterIncomplete(tservers.TestMaster): def request(self, f): resp = HTTPResponse.wrap(netlib.tutils.tresp()) resp.content = None - f.reply(resp) + f.reply.send(resp) class TestIncompleteResponse(tservers.HTTPProxyTest): @@ -942,7 +941,7 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxyTest): if not (k[0] in exclude): f.client_conn.finish() f.error = Error("terminated") - f.reply(Kill) + f.reply.kill() return _func(f) setattr(master, attr, handler) |