aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAldo Cortesi <aldo@corte.si>2016-06-08 14:09:59 +1200
committerAldo Cortesi <aldo@corte.si>2016-06-08 14:09:59 +1200
commite93fe9d4fa37ec1aae60eee612be7a9cd989891c (patch)
treeb6d29eb42c7989e5b44fb9221a358fecfe672ad4
parentdb11fe0087776c2bf5d95f5aeb751c6c35d67f4b (diff)
parenta5cb241c7cb1035b4d9ff43fb1c8958b7b3dac1d (diff)
downloadmitmproxy-e93fe9d4fa37ec1aae60eee612be7a9cd989891c.tar.gz
mitmproxy-e93fe9d4fa37ec1aae60eee612be7a9cd989891c.tar.bz2
mitmproxy-e93fe9d4fa37ec1aae60eee612be7a9cd989891c.zip
Merge pull request #1228 from cortesi/controller2
Controller refactoring
-rw-r--r--examples/nonblocking.py4
-rw-r--r--examples/redirect_requests.py2
-rw-r--r--examples/tls_passthrough.py2
-rw-r--r--mitmproxy/controller.py59
-rw-r--r--mitmproxy/flow/master.py56
-rw-r--r--mitmproxy/models/flow.py5
-rw-r--r--mitmproxy/script/concurrent.py65
-rw-r--r--test/mitmproxy/mastertest.py2
-rw-r--r--test/mitmproxy/script/test_concurrent.py26
-rw-r--r--test/mitmproxy/test_controller.py10
-rw-r--r--test/mitmproxy/test_flow.py8
-rw-r--r--test/mitmproxy/test_server.py11
12 files changed, 115 insertions, 135 deletions
diff --git a/examples/nonblocking.py b/examples/nonblocking.py
index 41674b2a..4609f389 100644
--- a/examples/nonblocking.py
+++ b/examples/nonblocking.py
@@ -4,6 +4,6 @@ from mitmproxy.script import concurrent
@concurrent # Remove this and see what happens
def request(context, flow):
- print("handle request: %s%s" % (flow.request.host, flow.request.path))
+ context.log("handle request: %s%s" % (flow.request.host, flow.request.path))
time.sleep(5)
- print("start request: %s%s" % (flow.request.host, flow.request.path))
+ context.log("start request: %s%s" % (flow.request.host, flow.request.path))
diff --git a/examples/redirect_requests.py b/examples/redirect_requests.py
index 3ff8f9e4..d7db3f1c 100644
--- a/examples/redirect_requests.py
+++ b/examples/redirect_requests.py
@@ -16,7 +16,7 @@ def request(context, flow):
"HTTP/1.1", 200, "OK",
Headers(Content_Type="text/html"),
"helloworld")
- flow.reply(resp)
+ flow.reply.send(resp)
# Method 2: Redirect the request to a different server
if flow.request.pretty_host.endswith("example.org"):
diff --git a/examples/tls_passthrough.py b/examples/tls_passthrough.py
index 23afe3ff..0c6d450d 100644
--- a/examples/tls_passthrough.py
+++ b/examples/tls_passthrough.py
@@ -134,5 +134,5 @@ def next_layer(context, next_layer):
# We don't intercept - reply with a pass-through layer and add a "skipped" entry.
context.log("TLS passthrough for %s" % repr(next_layer.server_conn.address), "info")
next_layer_replacement = RawTCPLayer(next_layer.ctx, logging=False)
- next_layer.reply(next_layer_replacement)
+ next_layer.reply.send(next_layer_replacement)
context.tls_strategy.record_skipped(server_address)
diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py
index 1498c3ad..084702a6 100644
--- a/mitmproxy/controller.py
+++ b/mitmproxy/controller.py
@@ -145,27 +145,6 @@ class Channel(object):
self.q.put((mtype, m))
-class DummyReply(object):
- """
- A reply object that does nothing. Useful when we need an object to seem
- like it has a channel, and during testing.
- """
- def __init__(self):
- self.acked = False
- self.taken = False
- self.handled = False
-
- def take(self):
- self.taken = True
-
- def __call__(self, msg=False):
- self.acked = True
-
-
-# Special value to distinguish the case where no reply was sent
-NO_REPLY = object()
-
-
def handler(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
@@ -192,7 +171,7 @@ def handler(f):
ret = f(*args, **kwargs)
if handling and not message.reply.acked and not message.reply.taken:
- message.reply()
+ message.reply.ack()
return ret
# Mark this function as a handler wrapper
wrapper.func_dict["__handler"] = True
@@ -215,19 +194,45 @@ class Reply(object):
# Has a handler taken responsibility for ack-ing?
self.handled = False
+ def ack(self):
+ self.send(self.obj)
+
+ def kill(self):
+ self.send(exceptions.Kill)
+
def take(self):
self.taken = True
- def __call__(self, msg=NO_REPLY):
+ def send(self, msg):
if self.acked:
raise exceptions.ControlException("Message already acked.")
self.acked = True
- if msg is NO_REPLY:
- self.q.put(self.obj)
- else:
- self.q.put(msg)
+ self.q.put(msg)
def __del__(self):
if not self.acked:
# This will be ignored by the interpreter, but emit a warning
raise exceptions.ControlException("Un-acked message")
+
+
+class DummyReply(object):
+ """
+ A reply object that does nothing. Useful when we need an object to seem
+ like it has a channel, and during testing.
+ """
+ def __init__(self):
+ self.acked = False
+ self.taken = False
+ self.handled = False
+
+ def kill(self):
+ self.send(None)
+
+ def ack(self):
+ self.send(None)
+
+ def take(self):
+ self.taken = True
+
+ def send(self, msg):
+ self.acked = True
diff --git a/mitmproxy/flow/master.py b/mitmproxy/flow/master.py
index ec0bf36d..289102a1 100644
--- a/mitmproxy/flow/master.py
+++ b/mitmproxy/flow/master.py
@@ -103,9 +103,10 @@ class FlowMaster(controller.Master):
except script.ScriptException as e:
self.add_event("Script error:\n{}".format(e), "error")
- def run_script_hook(self, name, *args, **kwargs):
+ def run_scripts(self, name, msg):
for script_obj in self.scripts:
- self._run_single_script_hook(script_obj, name, *args, **kwargs)
+ if not msg.reply.acked:
+ self._run_single_script_hook(script_obj, name, msg)
def get_ignore_filter(self):
return self.server.config.check_ignore.patterns
@@ -373,28 +374,28 @@ class FlowMaster(controller.Master):
@controller.handler
def clientconnect(self, root_layer):
- self.run_script_hook("clientconnect", root_layer)
+ self.run_scripts("clientconnect", root_layer)
@controller.handler
def clientdisconnect(self, root_layer):
- self.run_script_hook("clientdisconnect", root_layer)
+ self.run_scripts("clientdisconnect", root_layer)
@controller.handler
def serverconnect(self, server_conn):
- self.run_script_hook("serverconnect", server_conn)
+ self.run_scripts("serverconnect", server_conn)
@controller.handler
def serverdisconnect(self, server_conn):
- self.run_script_hook("serverdisconnect", server_conn)
+ self.run_scripts("serverdisconnect", server_conn)
@controller.handler
def next_layer(self, top_layer):
- self.run_script_hook("next_layer", top_layer)
+ self.run_scripts("next_layer", top_layer)
@controller.handler
def error(self, f):
self.state.update_flow(f)
- self.run_script_hook("error", f)
+ self.run_scripts("error", f)
if self.client_playback:
self.client_playback.clear(f)
return f
@@ -411,15 +412,19 @@ class FlowMaster(controller.Master):
)
if err:
self.add_event("Error in wsgi app. %s" % err, "error")
- f.reply(exceptions.Kill)
+ f.reply.kill()
return
if f not in self.state.flows: # don't add again on replay
self.state.add_flow(f)
self.active_flows.add(f)
- self.replacehooks.run(f)
- self.setheaders.run(f)
- self.process_new_request(f)
- self.run_script_hook("request", f)
+ if not f.reply.acked:
+ self.replacehooks.run(f)
+ if not f.reply.acked:
+ self.setheaders.run(f)
+ if not f.reply.acked:
+ self.process_new_request(f)
+ if not f.reply.acked:
+ self.run_scripts("request", f)
return f
@controller.handler
@@ -428,20 +433,23 @@ class FlowMaster(controller.Master):
if self.stream_large_bodies:
self.stream_large_bodies.run(f, False)
except netlib.exceptions.HttpException:
- f.reply(exceptions.Kill)
+ f.reply.kill()
return
- self.run_script_hook("responseheaders", f)
+ self.run_scripts("responseheaders", f)
return f
@controller.handler
def response(self, f):
self.active_flows.discard(f)
self.state.update_flow(f)
- self.replacehooks.run(f)
- self.setheaders.run(f)
- self.run_script_hook("response", f)
- if self.client_playback:
- self.client_playback.clear(f)
+ if not f.reply.acked:
+ self.replacehooks.run(f)
+ if not f.reply.acked:
+ self.setheaders.run(f)
+ self.run_scripts("response", f)
+ if not f.reply.acked:
+ if self.client_playback:
+ self.client_playback.clear(f)
self.process_new_response(f)
if self.stream:
self.stream.add(f)
@@ -487,11 +495,11 @@ class FlowMaster(controller.Master):
# TODO: This would break mitmproxy currently.
# self.state.add_flow(flow)
self.active_flows.add(flow)
- self.run_script_hook("tcp_open", flow)
+ self.run_scripts("tcp_open", flow)
@controller.handler
def tcp_message(self, flow):
- self.run_script_hook("tcp_message", flow)
+ self.run_scripts("tcp_message", flow)
message = flow.messages[-1]
direction = "->" if message.from_client else "<-"
self.add_event("{client} {direction} tcp {direction} {server}".format(
@@ -507,14 +515,14 @@ class FlowMaster(controller.Master):
repr(flow.server_conn.address),
flow.error
), "info")
- self.run_script_hook("tcp_error", flow)
+ self.run_scripts("tcp_error", flow)
@controller.handler
def tcp_close(self, flow):
self.active_flows.discard(flow)
if self.stream:
self.stream.add(flow)
- self.run_script_hook("tcp_close", flow)
+ self.run_scripts("tcp_close", flow)
def shutdown(self):
super(FlowMaster, self).shutdown()
diff --git a/mitmproxy/models/flow.py b/mitmproxy/models/flow.py
index e2dac221..de86e451 100644
--- a/mitmproxy/models/flow.py
+++ b/mitmproxy/models/flow.py
@@ -4,7 +4,6 @@ import time
import copy
import uuid
-from mitmproxy import exceptions
from mitmproxy import stateobject
from mitmproxy import version
from mitmproxy.models.connections import ClientConnection
@@ -155,7 +154,7 @@ class Flow(stateobject.StateObject):
"""
self.error = Error("Connection killed")
self.intercepted = False
- self.reply(exceptions.Kill)
+ self.reply.kill()
master.error(self)
def intercept(self, master):
@@ -175,5 +174,5 @@ class Flow(stateobject.StateObject):
if not self.intercepted:
return
self.intercepted = False
- self.reply()
+ self.reply.ack()
master.handle_accept_intercept(self)
diff --git a/mitmproxy/script/concurrent.py b/mitmproxy/script/concurrent.py
index 43d0d328..89c835f6 100644
--- a/mitmproxy/script/concurrent.py
+++ b/mitmproxy/script/concurrent.py
@@ -4,62 +4,25 @@ offload computations from mitmproxy's main master thread.
"""
from __future__ import absolute_import, print_function, division
+from mitmproxy import controller
import threading
-class ReplyProxy(object):
-
- def __init__(self, reply_func, script_thread):
- self.reply_func = reply_func
- self.script_thread = script_thread
- self.master_reply = None
-
- def __call__(self, *args):
- if self.master_reply is None:
- self.master_reply = args
- self.script_thread.start()
- return
- self.reply_func(*args)
-
- def done(self):
- self.reply_func(*self.master_reply)
-
- def __getattr__(self, k):
- return getattr(self.reply_func, k)
-
-
-def _handle_concurrent_reply(fn, o, *args, **kwargs):
- # Make first call to o.reply a no op and start the script thread.
- # We must not start the script thread before, as this may lead to a nasty race condition
- # where the script thread replies a different response before the normal reply, which then gets swallowed.
-
- def run():
- fn(*args, **kwargs)
- # If the script did not call .reply(), we have to do it now.
- reply_proxy.done()
-
- script_thread = ScriptThread(target=run)
-
- reply_proxy = ReplyProxy(o.reply, script_thread)
- o.reply = reply_proxy
-
-
class ScriptThread(threading.Thread):
name = "ScriptThread"
def concurrent(fn):
- if fn.__name__ in (
- "request",
- "response",
- "error",
- "clientconnect",
- "serverconnect",
- "clientdisconnect",
- "next_layer"):
- def _concurrent(ctx, obj):
- _handle_concurrent_reply(fn, obj, ctx, obj)
-
- return _concurrent
- raise NotImplementedError(
- "Concurrent decorator not supported for '%s' method." % fn.__name__)
+ if fn.__name__ not in controller.Events:
+ raise NotImplementedError(
+ "Concurrent decorator not supported for '%s' method." % fn.__name__
+ )
+
+ def _concurrent(ctx, obj):
+ def run():
+ fn(ctx, obj)
+ if not obj.reply.acked:
+ obj.reply.ack()
+ obj.reply.take()
+ ScriptThread(target=run).start()
+ return _concurrent
diff --git a/test/mitmproxy/mastertest.py b/test/mitmproxy/mastertest.py
index 9bb8826d..4d04f337 100644
--- a/test/mitmproxy/mastertest.py
+++ b/test/mitmproxy/mastertest.py
@@ -16,7 +16,9 @@ class MasterTest:
master.request(f)
if not f.error:
f.response = models.HTTPResponse.wrap(netlib.tutils.tresp(content=content))
+ f.reply.acked = False
f = master.response(f)
+ f.client_conn.reply.acked = False
master.clientdisconnect(f.client_conn)
return f
diff --git a/test/mitmproxy/script/test_concurrent.py b/test/mitmproxy/script/test_concurrent.py
index c2f169ad..62541f3f 100644
--- a/test/mitmproxy/script/test_concurrent.py
+++ b/test/mitmproxy/script/test_concurrent.py
@@ -1,29 +1,25 @@
-from threading import Event
-
from mitmproxy.script import Script
from test.mitmproxy import tutils
+from mitmproxy import controller
+import time
-class Dummy:
- def __init__(self, reply):
- self.reply = reply
+class Thing:
+ def __init__(self):
+ self.reply = controller.DummyReply()
@tutils.skip_appveyor
def test_concurrent():
with Script(tutils.test_data.path("data/scripts/concurrent_decorator.py"), None) as s:
- def reply():
- reply.acked.set()
- reply.acked = Event()
-
- f1, f2 = Dummy(reply), Dummy(reply)
+ f1, f2 = Thing(), Thing()
s.run("request", f1)
- f1.reply()
s.run("request", f2)
- f2.reply()
- assert f1.reply.acked == reply.acked
- assert not reply.acked.is_set()
- assert reply.acked.wait(10)
+ start = time.time()
+ while time.time() - start < 5:
+ if f1.reply.acked and f2.reply.acked:
+ return
+ raise ValueError("Script never acked")
def test_concurrent_err():
diff --git a/test/mitmproxy/test_controller.py b/test/mitmproxy/test_controller.py
index 83ad428e..5a68e15b 100644
--- a/test/mitmproxy/test_controller.py
+++ b/test/mitmproxy/test_controller.py
@@ -66,7 +66,7 @@ class TestChannel(object):
def reply():
m, obj = q.get()
assert m == "test"
- obj.reply(42)
+ obj.reply.send(42)
Thread(target=reply).start()
@@ -86,7 +86,7 @@ class TestDummyReply(object):
def test_simple(self):
reply = controller.DummyReply()
assert not reply.acked
- reply()
+ reply.ack()
assert reply.acked
@@ -94,16 +94,16 @@ class TestReply(object):
def test_simple(self):
reply = controller.Reply(42)
assert not reply.acked
- reply("foo")
+ reply.send("foo")
assert reply.acked
assert reply.q.get() == "foo"
def test_default(self):
reply = controller.Reply(42)
- reply()
+ reply.ack()
assert reply.q.get() == 42
def test_reply_none(self):
reply = controller.Reply(42)
- reply(None)
+ reply.send(None)
assert reply.q.get() is None
diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py
index 1b1f03f9..af8256c4 100644
--- a/test/mitmproxy/test_flow.py
+++ b/test/mitmproxy/test_flow.py
@@ -807,17 +807,22 @@ class TestFlowMaster:
fm.load_script(tutils.test_data.path("data/scripts/all.py"))
f = tutils.tflow(resp=True)
+ f.client_conn.acked = False
fm.clientconnect(f.client_conn)
assert fm.scripts[0].ns["log"][-1] == "clientconnect"
+ f.server_conn.acked = False
fm.serverconnect(f.server_conn)
assert fm.scripts[0].ns["log"][-1] == "serverconnect"
+ f.reply.acked = False
fm.request(f)
assert fm.scripts[0].ns["log"][-1] == "request"
+ f.reply.acked = False
fm.response(f)
assert fm.scripts[0].ns["log"][-1] == "response"
# load second script
fm.load_script(tutils.test_data.path("data/scripts/all.py"))
assert len(fm.scripts) == 2
+ f.server_conn.reply.acked = False
fm.clientdisconnect(f.server_conn)
assert fm.scripts[0].ns["log"][-1] == "clientdisconnect"
assert fm.scripts[1].ns["log"][-1] == "clientdisconnect"
@@ -828,6 +833,7 @@ class TestFlowMaster:
fm.load_script(tutils.test_data.path("data/scripts/all.py"))
f.error = tutils.terr()
+ f.reply.acked = False
fm.error(f)
assert fm.scripts[0].ns["log"][-1] == "error"
@@ -977,10 +983,12 @@ class TestFlowMaster:
f = tutils.tflow(resp=True)
f.response.headers["set-cookie"] = "foo=bar"
fm.request(f)
+ f.reply.acked = False
fm.response(f)
assert fm.stickycookie_state.jar
assert "cookie" not in f.request.headers
f = f.copy()
+ f.reply.acked = False
fm.request(f)
assert f.request.headers["cookie"] == "foo=bar"
diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py
index b58c4f44..432340c0 100644
--- a/test/mitmproxy/test_server.py
+++ b/test/mitmproxy/test_server.py
@@ -14,7 +14,6 @@ from pathod import pathoc, pathod
from mitmproxy import controller
from mitmproxy.proxy.config import HostMatcher
-from mitmproxy.exceptions import Kill
from mitmproxy.models import Error, HTTPResponse, HTTPFlow
from . import tutils, tservers
@@ -744,7 +743,7 @@ class MasterFakeResponse(tservers.TestMaster):
@controller.handler
def request(self, f):
resp = HTTPResponse.wrap(netlib.tutils.tresp())
- f.reply(resp)
+ f.reply.send(resp)
class TestFakeResponse(tservers.HTTPProxyTest):
@@ -771,7 +770,7 @@ class MasterKillRequest(tservers.TestMaster):
@controller.handler
def request(self, f):
- f.reply(Kill)
+ f.reply.kill()
class TestKillRequest(tservers.HTTPProxyTest):
@@ -788,7 +787,7 @@ class MasterKillResponse(tservers.TestMaster):
@controller.handler
def response(self, f):
- f.reply(Kill)
+ f.reply.kill()
class TestKillResponse(tservers.HTTPProxyTest):
@@ -820,7 +819,7 @@ class MasterIncomplete(tservers.TestMaster):
def request(self, f):
resp = HTTPResponse.wrap(netlib.tutils.tresp())
resp.content = None
- f.reply(resp)
+ f.reply.send(resp)
class TestIncompleteResponse(tservers.HTTPProxyTest):
@@ -942,7 +941,7 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxyTest):
if not (k[0] in exclude):
f.client_conn.finish()
f.error = Error("terminated")
- f.reply(Kill)
+ f.reply.kill()
return _func(f)
setattr(master, attr, handler)