From 982077ec31ddffeab9830a02b425c35cb0b0dac5 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 8 Jun 2016 10:14:34 +1200 Subject: Add reply.ack and reply.kill --- mitmproxy/controller.py | 48 +++++++++++++++++++++++++++---------------- mitmproxy/flow/master.py | 4 ++-- mitmproxy/models/flow.py | 5 ++--- test/mitmproxy/test_server.py | 7 +++---- 4 files changed, 37 insertions(+), 27 deletions(-) diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py index 1498c3ad..1aac82db 100644 --- a/mitmproxy/controller.py +++ b/mitmproxy/controller.py @@ -145,23 +145,6 @@ class Channel(object): self.q.put((mtype, m)) -class DummyReply(object): - """ - A reply object that does nothing. Useful when we need an object to seem - like it has a channel, and during testing. - """ - def __init__(self): - self.acked = False - self.taken = False - self.handled = False - - def take(self): - self.taken = True - - def __call__(self, msg=False): - self.acked = True - - # Special value to distinguish the case where no reply was sent NO_REPLY = object() @@ -192,7 +175,7 @@ def handler(f): ret = f(*args, **kwargs) if handling and not message.reply.acked and not message.reply.taken: - message.reply() + message.reply.ack() return ret # Mark this function as a handler wrapper wrapper.func_dict["__handler"] = True @@ -215,6 +198,12 @@ class Reply(object): # Has a handler taken responsibility for ack-ing? self.handled = False + def ack(self): + self(NO_REPLY) + + def kill(self): + self(exceptions.Kill) + def take(self): self.taken = True @@ -231,3 +220,26 @@ class Reply(object): if not self.acked: # This will be ignored by the interpreter, but emit a warning raise exceptions.ControlException("Un-acked message") + + +class DummyReply(object): + """ + A reply object that does nothing. Useful when we need an object to seem + like it has a channel, and during testing. + """ + def __init__(self): + self.acked = False + self.taken = False + self.handled = False + + def kill(self): + self() + + def ack(self): + self() + + def take(self): + self.taken = True + + def __call__(self, msg=False): + self.acked = True diff --git a/mitmproxy/flow/master.py b/mitmproxy/flow/master.py index ec0bf36d..31475f5b 100644 --- a/mitmproxy/flow/master.py +++ b/mitmproxy/flow/master.py @@ -411,7 +411,7 @@ class FlowMaster(controller.Master): ) if err: self.add_event("Error in wsgi app. %s" % err, "error") - f.reply(exceptions.Kill) + f.reply.kill() return if f not in self.state.flows: # don't add again on replay self.state.add_flow(f) @@ -428,7 +428,7 @@ class FlowMaster(controller.Master): if self.stream_large_bodies: self.stream_large_bodies.run(f, False) except netlib.exceptions.HttpException: - f.reply(exceptions.Kill) + f.reply.kill() return self.run_script_hook("responseheaders", f) return f diff --git a/mitmproxy/models/flow.py b/mitmproxy/models/flow.py index e2dac221..de86e451 100644 --- a/mitmproxy/models/flow.py +++ b/mitmproxy/models/flow.py @@ -4,7 +4,6 @@ import time import copy import uuid -from mitmproxy import exceptions from mitmproxy import stateobject from mitmproxy import version from mitmproxy.models.connections import ClientConnection @@ -155,7 +154,7 @@ class Flow(stateobject.StateObject): """ self.error = Error("Connection killed") self.intercepted = False - self.reply(exceptions.Kill) + self.reply.kill() master.error(self) def intercept(self, master): @@ -175,5 +174,5 @@ class Flow(stateobject.StateObject): if not self.intercepted: return self.intercepted = False - self.reply() + self.reply.ack() master.handle_accept_intercept(self) diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py index b58c4f44..1cd6cb0c 100644 --- a/test/mitmproxy/test_server.py +++ b/test/mitmproxy/test_server.py @@ -14,7 +14,6 @@ from pathod import pathoc, pathod from mitmproxy import controller from mitmproxy.proxy.config import HostMatcher -from mitmproxy.exceptions import Kill from mitmproxy.models import Error, HTTPResponse, HTTPFlow from . import tutils, tservers @@ -771,7 +770,7 @@ class MasterKillRequest(tservers.TestMaster): @controller.handler def request(self, f): - f.reply(Kill) + f.reply.kill() class TestKillRequest(tservers.HTTPProxyTest): @@ -788,7 +787,7 @@ class MasterKillResponse(tservers.TestMaster): @controller.handler def response(self, f): - f.reply(Kill) + f.reply.kill() class TestKillResponse(tservers.HTTPProxyTest): @@ -942,7 +941,7 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxyTest): if not (k[0] in exclude): f.client_conn.finish() f.error = Error("terminated") - f.reply(Kill) + f.reply.kill() return _func(f) setattr(master, attr, handler) -- cgit v1.2.3 From a388ddfd781fd05a414c07cac8446ef151cbd1d2 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 8 Jun 2016 10:44:20 +1200 Subject: A new interface for reply Reply is now explicit - it's no longer a callable itself. Instead, we have: reply.kill() - kill the flow reply.ack() - ack, but don't send anything reply.send(message) - send a response This is part of an incremental move to detach reply from our flow objects, and unify the script and handler interfaces. --- examples/redirect_requests.py | 2 +- examples/tls_passthrough.py | 2 +- mitmproxy/controller.py | 21 +++++++-------------- mitmproxy/script/concurrent.py | 31 +++++++++++++------------------ test/mitmproxy/test_controller.py | 10 +++++----- test/mitmproxy/test_server.py | 4 ++-- 6 files changed, 29 insertions(+), 41 deletions(-) diff --git a/examples/redirect_requests.py b/examples/redirect_requests.py index 3ff8f9e4..d7db3f1c 100644 --- a/examples/redirect_requests.py +++ b/examples/redirect_requests.py @@ -16,7 +16,7 @@ def request(context, flow): "HTTP/1.1", 200, "OK", Headers(Content_Type="text/html"), "helloworld") - flow.reply(resp) + flow.reply.send(resp) # Method 2: Redirect the request to a different server if flow.request.pretty_host.endswith("example.org"): diff --git a/examples/tls_passthrough.py b/examples/tls_passthrough.py index 23afe3ff..0c6d450d 100644 --- a/examples/tls_passthrough.py +++ b/examples/tls_passthrough.py @@ -134,5 +134,5 @@ def next_layer(context, next_layer): # We don't intercept - reply with a pass-through layer and add a "skipped" entry. context.log("TLS passthrough for %s" % repr(next_layer.server_conn.address), "info") next_layer_replacement = RawTCPLayer(next_layer.ctx, logging=False) - next_layer.reply(next_layer_replacement) + next_layer.reply.send(next_layer_replacement) context.tls_strategy.record_skipped(server_address) diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py index 1aac82db..084702a6 100644 --- a/mitmproxy/controller.py +++ b/mitmproxy/controller.py @@ -145,10 +145,6 @@ class Channel(object): self.q.put((mtype, m)) -# Special value to distinguish the case where no reply was sent -NO_REPLY = object() - - def handler(f): @functools.wraps(f) def wrapper(*args, **kwargs): @@ -199,22 +195,19 @@ class Reply(object): self.handled = False def ack(self): - self(NO_REPLY) + self.send(self.obj) def kill(self): - self(exceptions.Kill) + self.send(exceptions.Kill) def take(self): self.taken = True - def __call__(self, msg=NO_REPLY): + def send(self, msg): if self.acked: raise exceptions.ControlException("Message already acked.") self.acked = True - if msg is NO_REPLY: - self.q.put(self.obj) - else: - self.q.put(msg) + self.q.put(msg) def __del__(self): if not self.acked: @@ -233,13 +226,13 @@ class DummyReply(object): self.handled = False def kill(self): - self() + self.send(None) def ack(self): - self() + self.send(None) def take(self): self.taken = True - def __call__(self, msg=False): + def send(self, msg): self.acked = True diff --git a/mitmproxy/script/concurrent.py b/mitmproxy/script/concurrent.py index 43d0d328..b81f2ab1 100644 --- a/mitmproxy/script/concurrent.py +++ b/mitmproxy/script/concurrent.py @@ -4,6 +4,7 @@ offload computations from mitmproxy's main master thread. """ from __future__ import absolute_import, print_function, division +from mitmproxy import controller import threading @@ -14,15 +15,15 @@ class ReplyProxy(object): self.script_thread = script_thread self.master_reply = None - def __call__(self, *args): + def send(self, message): if self.master_reply is None: - self.master_reply = args + self.master_reply = message self.script_thread.start() return - self.reply_func(*args) + self.reply_func(message) def done(self): - self.reply_func(*self.master_reply) + self.reply_func.send(self.master_reply) def __getattr__(self, k): return getattr(self.reply_func, k) @@ -49,17 +50,11 @@ class ScriptThread(threading.Thread): def concurrent(fn): - if fn.__name__ in ( - "request", - "response", - "error", - "clientconnect", - "serverconnect", - "clientdisconnect", - "next_layer"): - def _concurrent(ctx, obj): - _handle_concurrent_reply(fn, obj, ctx, obj) - - return _concurrent - raise NotImplementedError( - "Concurrent decorator not supported for '%s' method." % fn.__name__) + if fn.__name__ not in controller.Events: + raise NotImplementedError( + "Concurrent decorator not supported for '%s' method." % fn.__name__ + ) + + def _concurrent(ctx, obj): + _handle_concurrent_reply(fn, obj, ctx, obj) + return _concurrent diff --git a/test/mitmproxy/test_controller.py b/test/mitmproxy/test_controller.py index 83ad428e..5a68e15b 100644 --- a/test/mitmproxy/test_controller.py +++ b/test/mitmproxy/test_controller.py @@ -66,7 +66,7 @@ class TestChannel(object): def reply(): m, obj = q.get() assert m == "test" - obj.reply(42) + obj.reply.send(42) Thread(target=reply).start() @@ -86,7 +86,7 @@ class TestDummyReply(object): def test_simple(self): reply = controller.DummyReply() assert not reply.acked - reply() + reply.ack() assert reply.acked @@ -94,16 +94,16 @@ class TestReply(object): def test_simple(self): reply = controller.Reply(42) assert not reply.acked - reply("foo") + reply.send("foo") assert reply.acked assert reply.q.get() == "foo" def test_default(self): reply = controller.Reply(42) - reply() + reply.ack() assert reply.q.get() == 42 def test_reply_none(self): reply = controller.Reply(42) - reply(None) + reply.send(None) assert reply.q.get() is None diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py index 1cd6cb0c..432340c0 100644 --- a/test/mitmproxy/test_server.py +++ b/test/mitmproxy/test_server.py @@ -743,7 +743,7 @@ class MasterFakeResponse(tservers.TestMaster): @controller.handler def request(self, f): resp = HTTPResponse.wrap(netlib.tutils.tresp()) - f.reply(resp) + f.reply.send(resp) class TestFakeResponse(tservers.HTTPProxyTest): @@ -819,7 +819,7 @@ class MasterIncomplete(tservers.TestMaster): def request(self, f): resp = HTTPResponse.wrap(netlib.tutils.tresp()) resp.content = None - f.reply(resp) + f.reply.send(resp) class TestIncompleteResponse(tservers.HTTPProxyTest): -- cgit v1.2.3 From b3bf754e539555351230cbb0887f8838c12fd23c Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 8 Jun 2016 11:21:38 +1200 Subject: Simplify script concurrency helpers We now have take() to prevent double-replies. --- examples/nonblocking.py | 4 +-- mitmproxy/script/concurrent.py | 44 +++++--------------------------- test/mitmproxy/script/test_concurrent.py | 26 ++++++++----------- 3 files changed, 19 insertions(+), 55 deletions(-) diff --git a/examples/nonblocking.py b/examples/nonblocking.py index 41674b2a..4609f389 100644 --- a/examples/nonblocking.py +++ b/examples/nonblocking.py @@ -4,6 +4,6 @@ from mitmproxy.script import concurrent @concurrent # Remove this and see what happens def request(context, flow): - print("handle request: %s%s" % (flow.request.host, flow.request.path)) + context.log("handle request: %s%s" % (flow.request.host, flow.request.path)) time.sleep(5) - print("start request: %s%s" % (flow.request.host, flow.request.path)) + context.log("start request: %s%s" % (flow.request.host, flow.request.path)) diff --git a/mitmproxy/script/concurrent.py b/mitmproxy/script/concurrent.py index b81f2ab1..89c835f6 100644 --- a/mitmproxy/script/concurrent.py +++ b/mitmproxy/script/concurrent.py @@ -8,43 +8,6 @@ from mitmproxy import controller import threading -class ReplyProxy(object): - - def __init__(self, reply_func, script_thread): - self.reply_func = reply_func - self.script_thread = script_thread - self.master_reply = None - - def send(self, message): - if self.master_reply is None: - self.master_reply = message - self.script_thread.start() - return - self.reply_func(message) - - def done(self): - self.reply_func.send(self.master_reply) - - def __getattr__(self, k): - return getattr(self.reply_func, k) - - -def _handle_concurrent_reply(fn, o, *args, **kwargs): - # Make first call to o.reply a no op and start the script thread. - # We must not start the script thread before, as this may lead to a nasty race condition - # where the script thread replies a different response before the normal reply, which then gets swallowed. - - def run(): - fn(*args, **kwargs) - # If the script did not call .reply(), we have to do it now. - reply_proxy.done() - - script_thread = ScriptThread(target=run) - - reply_proxy = ReplyProxy(o.reply, script_thread) - o.reply = reply_proxy - - class ScriptThread(threading.Thread): name = "ScriptThread" @@ -56,5 +19,10 @@ def concurrent(fn): ) def _concurrent(ctx, obj): - _handle_concurrent_reply(fn, obj, ctx, obj) + def run(): + fn(ctx, obj) + if not obj.reply.acked: + obj.reply.ack() + obj.reply.take() + ScriptThread(target=run).start() return _concurrent diff --git a/test/mitmproxy/script/test_concurrent.py b/test/mitmproxy/script/test_concurrent.py index c2f169ad..62541f3f 100644 --- a/test/mitmproxy/script/test_concurrent.py +++ b/test/mitmproxy/script/test_concurrent.py @@ -1,29 +1,25 @@ -from threading import Event - from mitmproxy.script import Script from test.mitmproxy import tutils +from mitmproxy import controller +import time -class Dummy: - def __init__(self, reply): - self.reply = reply +class Thing: + def __init__(self): + self.reply = controller.DummyReply() @tutils.skip_appveyor def test_concurrent(): with Script(tutils.test_data.path("data/scripts/concurrent_decorator.py"), None) as s: - def reply(): - reply.acked.set() - reply.acked = Event() - - f1, f2 = Dummy(reply), Dummy(reply) + f1, f2 = Thing(), Thing() s.run("request", f1) - f1.reply() s.run("request", f2) - f2.reply() - assert f1.reply.acked == reply.acked - assert not reply.acked.is_set() - assert reply.acked.wait(10) + start = time.time() + while time.time() - start < 5: + if f1.reply.acked and f2.reply.acked: + return + raise ValueError("Script never acked") def test_concurrent_err(): -- cgit v1.2.3 From a5cb241c7cb1035b4d9ff43fb1c8958b7b3dac1d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 8 Jun 2016 12:58:58 +1200 Subject: If a message has been acked, all other processors are skipped This applies the constraint, but does to clumsily. When we've unified modules and processors it will be much nicer. We also make some exceptions for the master processors that we may want to re-evaluate down the track. --- mitmproxy/flow/master.py | 52 +++++++++++++++++++++++++------------------- test/mitmproxy/mastertest.py | 2 ++ test/mitmproxy/test_flow.py | 8 +++++++ 3 files changed, 40 insertions(+), 22 deletions(-) diff --git a/mitmproxy/flow/master.py b/mitmproxy/flow/master.py index 31475f5b..289102a1 100644 --- a/mitmproxy/flow/master.py +++ b/mitmproxy/flow/master.py @@ -103,9 +103,10 @@ class FlowMaster(controller.Master): except script.ScriptException as e: self.add_event("Script error:\n{}".format(e), "error") - def run_script_hook(self, name, *args, **kwargs): + def run_scripts(self, name, msg): for script_obj in self.scripts: - self._run_single_script_hook(script_obj, name, *args, **kwargs) + if not msg.reply.acked: + self._run_single_script_hook(script_obj, name, msg) def get_ignore_filter(self): return self.server.config.check_ignore.patterns @@ -373,28 +374,28 @@ class FlowMaster(controller.Master): @controller.handler def clientconnect(self, root_layer): - self.run_script_hook("clientconnect", root_layer) + self.run_scripts("clientconnect", root_layer) @controller.handler def clientdisconnect(self, root_layer): - self.run_script_hook("clientdisconnect", root_layer) + self.run_scripts("clientdisconnect", root_layer) @controller.handler def serverconnect(self, server_conn): - self.run_script_hook("serverconnect", server_conn) + self.run_scripts("serverconnect", server_conn) @controller.handler def serverdisconnect(self, server_conn): - self.run_script_hook("serverdisconnect", server_conn) + self.run_scripts("serverdisconnect", server_conn) @controller.handler def next_layer(self, top_layer): - self.run_script_hook("next_layer", top_layer) + self.run_scripts("next_layer", top_layer) @controller.handler def error(self, f): self.state.update_flow(f) - self.run_script_hook("error", f) + self.run_scripts("error", f) if self.client_playback: self.client_playback.clear(f) return f @@ -416,10 +417,14 @@ class FlowMaster(controller.Master): if f not in self.state.flows: # don't add again on replay self.state.add_flow(f) self.active_flows.add(f) - self.replacehooks.run(f) - self.setheaders.run(f) - self.process_new_request(f) - self.run_script_hook("request", f) + if not f.reply.acked: + self.replacehooks.run(f) + if not f.reply.acked: + self.setheaders.run(f) + if not f.reply.acked: + self.process_new_request(f) + if not f.reply.acked: + self.run_scripts("request", f) return f @controller.handler @@ -430,18 +435,21 @@ class FlowMaster(controller.Master): except netlib.exceptions.HttpException: f.reply.kill() return - self.run_script_hook("responseheaders", f) + self.run_scripts("responseheaders", f) return f @controller.handler def response(self, f): self.active_flows.discard(f) self.state.update_flow(f) - self.replacehooks.run(f) - self.setheaders.run(f) - self.run_script_hook("response", f) - if self.client_playback: - self.client_playback.clear(f) + if not f.reply.acked: + self.replacehooks.run(f) + if not f.reply.acked: + self.setheaders.run(f) + self.run_scripts("response", f) + if not f.reply.acked: + if self.client_playback: + self.client_playback.clear(f) self.process_new_response(f) if self.stream: self.stream.add(f) @@ -487,11 +495,11 @@ class FlowMaster(controller.Master): # TODO: This would break mitmproxy currently. # self.state.add_flow(flow) self.active_flows.add(flow) - self.run_script_hook("tcp_open", flow) + self.run_scripts("tcp_open", flow) @controller.handler def tcp_message(self, flow): - self.run_script_hook("tcp_message", flow) + self.run_scripts("tcp_message", flow) message = flow.messages[-1] direction = "->" if message.from_client else "<-" self.add_event("{client} {direction} tcp {direction} {server}".format( @@ -507,14 +515,14 @@ class FlowMaster(controller.Master): repr(flow.server_conn.address), flow.error ), "info") - self.run_script_hook("tcp_error", flow) + self.run_scripts("tcp_error", flow) @controller.handler def tcp_close(self, flow): self.active_flows.discard(flow) if self.stream: self.stream.add(flow) - self.run_script_hook("tcp_close", flow) + self.run_scripts("tcp_close", flow) def shutdown(self): super(FlowMaster, self).shutdown() diff --git a/test/mitmproxy/mastertest.py b/test/mitmproxy/mastertest.py index 9bb8826d..4d04f337 100644 --- a/test/mitmproxy/mastertest.py +++ b/test/mitmproxy/mastertest.py @@ -16,7 +16,9 @@ class MasterTest: master.request(f) if not f.error: f.response = models.HTTPResponse.wrap(netlib.tutils.tresp(content=content)) + f.reply.acked = False f = master.response(f) + f.client_conn.reply.acked = False master.clientdisconnect(f.client_conn) return f diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index 1b1f03f9..af8256c4 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -807,17 +807,22 @@ class TestFlowMaster: fm.load_script(tutils.test_data.path("data/scripts/all.py")) f = tutils.tflow(resp=True) + f.client_conn.acked = False fm.clientconnect(f.client_conn) assert fm.scripts[0].ns["log"][-1] == "clientconnect" + f.server_conn.acked = False fm.serverconnect(f.server_conn) assert fm.scripts[0].ns["log"][-1] == "serverconnect" + f.reply.acked = False fm.request(f) assert fm.scripts[0].ns["log"][-1] == "request" + f.reply.acked = False fm.response(f) assert fm.scripts[0].ns["log"][-1] == "response" # load second script fm.load_script(tutils.test_data.path("data/scripts/all.py")) assert len(fm.scripts) == 2 + f.server_conn.reply.acked = False fm.clientdisconnect(f.server_conn) assert fm.scripts[0].ns["log"][-1] == "clientdisconnect" assert fm.scripts[1].ns["log"][-1] == "clientdisconnect" @@ -828,6 +833,7 @@ class TestFlowMaster: fm.load_script(tutils.test_data.path("data/scripts/all.py")) f.error = tutils.terr() + f.reply.acked = False fm.error(f) assert fm.scripts[0].ns["log"][-1] == "error" @@ -977,10 +983,12 @@ class TestFlowMaster: f = tutils.tflow(resp=True) f.response.headers["set-cookie"] = "foo=bar" fm.request(f) + f.reply.acked = False fm.response(f) assert fm.stickycookie_state.jar assert "cookie" not in f.request.headers f = f.copy() + f.reply.acked = False fm.request(f) assert f.request.headers["cookie"] == "foo=bar" -- cgit v1.2.3