diff options
author | Aldo Cortesi <aldo@nullcube.com> | 2016-06-08 11:21:38 +1200 |
---|---|---|
committer | Aldo Cortesi <aldo@nullcube.com> | 2016-06-08 11:21:38 +1200 |
commit | b3bf754e539555351230cbb0887f8838c12fd23c (patch) | |
tree | 862b77357a28b85472bb48387e98e1d3519b625e /mitmproxy/script | |
parent | a388ddfd781fd05a414c07cac8446ef151cbd1d2 (diff) | |
download | mitmproxy-b3bf754e539555351230cbb0887f8838c12fd23c.tar.gz mitmproxy-b3bf754e539555351230cbb0887f8838c12fd23c.tar.bz2 mitmproxy-b3bf754e539555351230cbb0887f8838c12fd23c.zip |
Simplify script concurrency helpers
We now have take() to prevent double-replies.
Diffstat (limited to 'mitmproxy/script')
-rw-r--r-- | mitmproxy/script/concurrent.py | 44 |
1 files changed, 6 insertions, 38 deletions
diff --git a/mitmproxy/script/concurrent.py b/mitmproxy/script/concurrent.py index b81f2ab1..89c835f6 100644 --- a/mitmproxy/script/concurrent.py +++ b/mitmproxy/script/concurrent.py @@ -8,43 +8,6 @@ from mitmproxy import controller import threading -class ReplyProxy(object): - - def __init__(self, reply_func, script_thread): - self.reply_func = reply_func - self.script_thread = script_thread - self.master_reply = None - - def send(self, message): - if self.master_reply is None: - self.master_reply = message - self.script_thread.start() - return - self.reply_func(message) - - def done(self): - self.reply_func.send(self.master_reply) - - def __getattr__(self, k): - return getattr(self.reply_func, k) - - -def _handle_concurrent_reply(fn, o, *args, **kwargs): - # Make first call to o.reply a no op and start the script thread. - # We must not start the script thread before, as this may lead to a nasty race condition - # where the script thread replies a different response before the normal reply, which then gets swallowed. - - def run(): - fn(*args, **kwargs) - # If the script did not call .reply(), we have to do it now. - reply_proxy.done() - - script_thread = ScriptThread(target=run) - - reply_proxy = ReplyProxy(o.reply, script_thread) - o.reply = reply_proxy - - class ScriptThread(threading.Thread): name = "ScriptThread" @@ -56,5 +19,10 @@ def concurrent(fn): ) def _concurrent(ctx, obj): - _handle_concurrent_reply(fn, obj, ctx, obj) + def run(): + fn(ctx, obj) + if not obj.reply.acked: + obj.reply.ack() + obj.reply.take() + ScriptThread(target=run).start() return _concurrent |