aboutsummaryrefslogtreecommitdiffstats
path: root/libmproxy/script.py
diff options
context:
space:
mode:
Diffstat (limited to 'libmproxy/script.py')
-rw-r--r--libmproxy/script.py86
1 files changed, 49 insertions, 37 deletions
diff --git a/libmproxy/script.py b/libmproxy/script.py
index 46edb86b..e13f0e2b 100644
--- a/libmproxy/script.py
+++ b/libmproxy/script.py
@@ -3,7 +3,7 @@ import os
import traceback
import threading
import shlex
-from . import controller
+import sys
class ScriptError(Exception):
@@ -55,21 +55,17 @@ class ScriptContext:
class Script:
"""
- The instantiator should do something along this vein:
-
- s = Script(argv, master)
- s.load()
+ Script object representing an inline script.
"""
def __init__(self, command, master):
- self.command = command
- self.argv = self.parse_command(command)
+ self.args = self.parse_command(command)
self.ctx = ScriptContext(master)
self.ns = None
self.load()
@classmethod
- def parse_command(klass, command):
+ def parse_command(cls, command):
if not command or not command.strip():
raise ScriptError("Empty script command.")
if os.name == "nt": # Windows: escape all backslashes in the path.
@@ -89,54 +85,66 @@ class Script:
def load(self):
"""
- Loads a module.
+ Loads an inline script.
+
+ Returns:
+ The return value of self.run("start", ...)
- Raises ScriptError on failure, with argument equal to an error
- message that may be a formatted traceback.
+ Raises:
+ ScriptError on failure
"""
+ if self.ns is not None:
+ self.unload()
ns = {}
+ script_dir = os.path.dirname(os.path.abspath(self.args[0]))
+ sys.path.append(script_dir)
try:
- execfile(self.argv[0], ns, ns)
- except Exception as v:
- raise ScriptError(traceback.format_exc(v))
+ execfile(self.args[0], ns, ns)
+ except Exception as e:
+ # Python 3: use exception chaining, https://www.python.org/dev/peps/pep-3134/
+ raise ScriptError(traceback.format_exc(e))
+ sys.path.pop()
self.ns = ns
- r = self.run("start", self.argv)
- if not r[0] and r[1]:
- raise ScriptError(r[1][1])
+ return self.run("start", self.args)
def unload(self):
- return self.run("done")
+ ret = self.run("done")
+ self.ns = None
+ return ret
def run(self, name, *args, **kwargs):
"""
- Runs a plugin method.
+ Runs an inline script hook.
Returns:
+ The return value of the method.
+ None, if the script does not provide the method.
- (True, retval) on success.
- (False, None) on nonexistent method.
- (False, (exc, traceback string)) if there was an exception.
+ Raises:
+ ScriptError if there was an exception.
"""
f = self.ns.get(name)
if f:
try:
- return (True, f(self.ctx, *args, **kwargs))
- except Exception as v:
- return (False, (v, traceback.format_exc(v)))
+ return f(self.ctx, *args, **kwargs)
+ except Exception as e:
+ raise ScriptError(traceback.format_exc(e))
else:
- return (False, None)
+ return None
class ReplyProxy(object):
- def __init__(self, original_reply):
- self._ignore_calls = 1
- self.lock = threading.Lock()
+ def __init__(self, original_reply, script_thread):
self.original_reply = original_reply
+ self.script_thread = script_thread
+ self._ignore_call = True
+ self.lock = threading.Lock()
def __call__(self, *args, **kwargs):
with self.lock:
- if self._ignore_calls > 0:
- self._ignore_calls -= 1
+ if self._ignore_call:
+ self.script_thread.start()
+ self._ignore_call = False
return
self.original_reply(*args, **kwargs)
@@ -145,16 +153,19 @@ class ReplyProxy(object):
def _handle_concurrent_reply(fn, o, *args, **kwargs):
- # Make first call to o.reply a no op
-
- reply_proxy = ReplyProxy(o.reply)
- o.reply = reply_proxy
+ # Make first call to o.reply a no op and start the script thread.
+ # We must not start the script thread before, as this may lead to a nasty race condition
+ # where the script thread replies a different response before the normal reply, which then gets swallowed.
def run():
fn(*args, **kwargs)
# If the script did not call .reply(), we have to do it now.
reply_proxy()
- ScriptThread(target=run).start()
+
+ script_thread = ScriptThread(target=run)
+
+ reply_proxy = ReplyProxy(o.reply, script_thread)
+ o.reply = reply_proxy
class ScriptThread(threading.Thread):
@@ -171,6 +182,7 @@ def concurrent(fn):
"clientdisconnect"):
def _concurrent(ctx, obj):
_handle_concurrent_reply(fn, obj, ctx, obj)
+
return _concurrent
raise NotImplementedError(
- "Concurrent decorator not supported for this method.")
+ "Concurrent decorator not supported for '%s' method." % fn.func_name)