diff options
Diffstat (limited to 'libmproxy/script.py')
-rw-r--r-- | libmproxy/script.py | 86 |
1 files changed, 49 insertions, 37 deletions
diff --git a/libmproxy/script.py b/libmproxy/script.py index 46edb86b..e13f0e2b 100644 --- a/libmproxy/script.py +++ b/libmproxy/script.py @@ -3,7 +3,7 @@ import os import traceback import threading import shlex -from . import controller +import sys class ScriptError(Exception): @@ -55,21 +55,17 @@ class ScriptContext: class Script: """ - The instantiator should do something along this vein: - - s = Script(argv, master) - s.load() + Script object representing an inline script. """ def __init__(self, command, master): - self.command = command - self.argv = self.parse_command(command) + self.args = self.parse_command(command) self.ctx = ScriptContext(master) self.ns = None self.load() @classmethod - def parse_command(klass, command): + def parse_command(cls, command): if not command or not command.strip(): raise ScriptError("Empty script command.") if os.name == "nt": # Windows: escape all backslashes in the path. @@ -89,54 +85,66 @@ class Script: def load(self): """ - Loads a module. + Loads an inline script. + + Returns: + The return value of self.run("start", ...) - Raises ScriptError on failure, with argument equal to an error - message that may be a formatted traceback. + Raises: + ScriptError on failure """ + if self.ns is not None: + self.unload() ns = {} + script_dir = os.path.dirname(os.path.abspath(self.args[0])) + sys.path.append(script_dir) try: - execfile(self.argv[0], ns, ns) - except Exception as v: - raise ScriptError(traceback.format_exc(v)) + execfile(self.args[0], ns, ns) + except Exception as e: + # Python 3: use exception chaining, https://www.python.org/dev/peps/pep-3134/ + raise ScriptError(traceback.format_exc(e)) + sys.path.pop() self.ns = ns - r = self.run("start", self.argv) - if not r[0] and r[1]: - raise ScriptError(r[1][1]) + return self.run("start", self.args) def unload(self): - return self.run("done") + ret = self.run("done") + self.ns = None + return ret def run(self, name, *args, **kwargs): """ - Runs a plugin method. + Runs an inline script hook. Returns: + The return value of the method. + None, if the script does not provide the method. - (True, retval) on success. - (False, None) on nonexistent method. - (False, (exc, traceback string)) if there was an exception. + Raises: + ScriptError if there was an exception. """ f = self.ns.get(name) if f: try: - return (True, f(self.ctx, *args, **kwargs)) - except Exception as v: - return (False, (v, traceback.format_exc(v))) + return f(self.ctx, *args, **kwargs) + except Exception as e: + raise ScriptError(traceback.format_exc(e)) else: - return (False, None) + return None class ReplyProxy(object): - def __init__(self, original_reply): - self._ignore_calls = 1 - self.lock = threading.Lock() + def __init__(self, original_reply, script_thread): self.original_reply = original_reply + self.script_thread = script_thread + self._ignore_call = True + self.lock = threading.Lock() def __call__(self, *args, **kwargs): with self.lock: - if self._ignore_calls > 0: - self._ignore_calls -= 1 + if self._ignore_call: + self.script_thread.start() + self._ignore_call = False return self.original_reply(*args, **kwargs) @@ -145,16 +153,19 @@ class ReplyProxy(object): def _handle_concurrent_reply(fn, o, *args, **kwargs): - # Make first call to o.reply a no op - - reply_proxy = ReplyProxy(o.reply) - o.reply = reply_proxy + # Make first call to o.reply a no op and start the script thread. + # We must not start the script thread before, as this may lead to a nasty race condition + # where the script thread replies a different response before the normal reply, which then gets swallowed. def run(): fn(*args, **kwargs) # If the script did not call .reply(), we have to do it now. reply_proxy() - ScriptThread(target=run).start() + + script_thread = ScriptThread(target=run) + + reply_proxy = ReplyProxy(o.reply, script_thread) + o.reply = reply_proxy class ScriptThread(threading.Thread): @@ -171,6 +182,7 @@ def concurrent(fn): "clientdisconnect"): def _concurrent(ctx, obj): _handle_concurrent_reply(fn, obj, ctx, obj) + return _concurrent raise NotImplementedError( - "Concurrent decorator not supported for this method.") + "Concurrent decorator not supported for '%s' method." % fn.func_name) |