diff options
48 files changed, 267 insertions, 312 deletions
diff --git a/docs/scripting/inlinescripts.rst b/docs/scripting/inlinescripts.rst index 1ee44972..bc9d5ff5 100644 --- a/docs/scripting/inlinescripts.rst +++ b/docs/scripting/inlinescripts.rst @@ -15,9 +15,7 @@ client: :caption: examples/add_header.py :language: python -The first argument to each event method is an instance of -:py:class:`~mitmproxy.script.ScriptContext` that lets the script interact with the global mitmproxy -state. The **response** event also gets an instance of :py:class:`~mitmproxy.models.HTTPFlow`, +All events that deal with an HTTP request get an instance of :py:class:`~mitmproxy.models.HTTPFlow`, which we can use to manipulate the response itself. We can now run this script using mitmdump or mitmproxy as follows: @@ -36,11 +34,6 @@ We encourage you to either browse them locally or on `GitHub`_. Events ------ -The ``context`` argument passed to each event method is always a -:py:class:`~mitmproxy.script.ScriptContext` instance. It is guaranteed to be the same object -for the scripts lifetime and is not shared between multiple inline scripts. You can safely use it -to store any form of state you require. - Script Lifecycle Events ^^^^^^^^^^^^^^^^^^^^^^^ @@ -155,8 +148,9 @@ The canonical API documentation is the code, which you can browse here, locally The main classes you will deal with in writing mitmproxy scripts are: -:py:class:`~mitmproxy.script.ScriptContext` - - A handle for interacting with mitmproxy's Flow Master from within scripts. +:py:class:`mitmproxy.flow.FlowMaster` + - The "heart" of mitmproxy, usually subclassed as :py:class:`mitmproxy.dump.DumpMaster` or + :py:class:`mitmproxy.console.ConsoleMaster`. :py:class:`~mitmproxy.models.ClientConnection` - Describes a client connection. :py:class:`~mitmproxy.models.ServerConnection` @@ -173,16 +167,7 @@ The main classes you will deal with in writing mitmproxy scripts are: - A dictionary-like object for managing HTTP headers. :py:class:`netlib.certutils.SSLCert` - Exposes information SSL certificates. -:py:class:`mitmproxy.flow.FlowMaster` - - The "heart" of mitmproxy, usually subclassed as :py:class:`mitmproxy.dump.DumpMaster` or - :py:class:`mitmproxy.console.ConsoleMaster`. - -Script Context --------------- -.. autoclass:: mitmproxy.script.ScriptContext - :members: - :undoc-members: Running scripts in parallel --------------------------- diff --git a/examples/add_header.py b/examples/add_header.py index cf1b53cc..3e0b5f1e 100644 --- a/examples/add_header.py +++ b/examples/add_header.py @@ -1,2 +1,2 @@ -def response(context, flow): +def response(flow): flow.response.headers["newheader"] = "foo" diff --git a/examples/change_upstream_proxy.py b/examples/change_upstream_proxy.py index 34a6eece..49d5379f 100644 --- a/examples/change_upstream_proxy.py +++ b/examples/change_upstream_proxy.py @@ -14,7 +14,7 @@ def proxy_address(flow): return ("localhost", 8081) -def request(context, flow): +def request(flow): if flow.request.method == "CONNECT": # If the decision is done by domain, one could also modify the server address here. # We do it after CONNECT here to have the request data available as well. diff --git a/examples/custom_contentviews.py b/examples/custom_contentviews.py index 92fb6a58..5a63e2a0 100644 --- a/examples/custom_contentviews.py +++ b/examples/custom_contentviews.py @@ -62,9 +62,9 @@ class ViewPigLatin(contentviews.View): pig_view = ViewPigLatin() -def start(context): - context.add_contentview(pig_view) +def start(): + contentviews.add(pig_view) -def done(context): - context.remove_contentview(pig_view) +def done(): + contentviews.remove(pig_view) diff --git a/examples/dns_spoofing.py b/examples/dns_spoofing.py index 8d715f33..c020047f 100644 --- a/examples/dns_spoofing.py +++ b/examples/dns_spoofing.py @@ -28,7 +28,7 @@ import re parse_host_header = re.compile(r"^(?P<host>[^:]+|\[.+\])(?::(?P<port>\d+))?$") -def request(context, flow): +def request(flow): if flow.client_conn.ssl_established: flow.request.scheme = "https" sni = flow.client_conn.connection.get_servername() diff --git a/examples/dup_and_replay.py b/examples/dup_and_replay.py index 9ba91d3b..b47bf951 100644 --- a/examples/dup_and_replay.py +++ b/examples/dup_and_replay.py @@ -1,4 +1,7 @@ -def request(context, flow): - f = context.duplicate_flow(flow) +from mitmproxy import master + + +def request(flow): + f = master.duplicate_flow(flow) f.request.path = "/changed" - context.replay_request(f) + master.replay_request(f, block=True, run_scripthooks=False) diff --git a/examples/fail_with_500.py b/examples/fail_with_500.py index aec85b50..9710f74a 100644 --- a/examples/fail_with_500.py +++ b/examples/fail_with_500.py @@ -1,3 +1,3 @@ -def response(context, flow): +def response(flow): flow.response.status_code = 500 flow.response.content = b"" diff --git a/examples/filt.py b/examples/filt.py index 1a423845..21744edd 100644 --- a/examples/filt.py +++ b/examples/filt.py @@ -3,14 +3,16 @@ import sys from mitmproxy import filt +state = {} -def start(context): + +def start(): if len(sys.argv) != 2: raise ValueError("Usage: -s 'filt.py FILTER'") - context.filter = filt.parse(sys.argv[1]) + state["filter"] = filt.parse(sys.argv[1]) -def response(context, flow): - if flow.match(context.filter): +def response(flow): + if flow.match(state["filter"]): print("Flow matches filter:") print(flow) diff --git a/examples/flowwriter.py b/examples/flowwriter.py index cb5ccb0d..07c7ca20 100644 --- a/examples/flowwriter.py +++ b/examples/flowwriter.py @@ -3,8 +3,10 @@ import sys from mitmproxy.flow import FlowWriter +state = {} -def start(context): + +def start(): if len(sys.argv) != 2: raise ValueError('Usage: -s "flowriter.py filename"') @@ -12,9 +14,9 @@ def start(context): f = sys.stdout else: f = open(sys.argv[1], "wb") - context.flow_writer = FlowWriter(f) + state["flow_writer"] = FlowWriter(f) -def response(context, flow): +def response(flow): if random.choice([True, False]): - context.flow_writer.add(flow) + state["flow_writer"].add(flow) diff --git a/examples/har_extractor.py b/examples/har_extractor.py index a5c05519..2a69b9af 100644 --- a/examples/har_extractor.py +++ b/examples/har_extractor.py @@ -2,6 +2,7 @@ This inline script utilizes harparser.HAR from https://github.com/JustusW/harparser to generate a HAR log object. """ +import mitmproxy import six import sys import pytz @@ -54,7 +55,13 @@ class _HARLog(HAR.log): return self.__page_list__ -def start(context): +class Context(object): + pass + +context = Context() + + +def start(): """ On start we create a HARLog instance. You will have to adapt this to suit your actual needs of HAR generation. As it will probably be @@ -79,7 +86,7 @@ def start(context): context.seen_server = set() -def response(context, flow): +def response(flow): """ Called when a server response has been received. At the time of this message both a request and a response are present and completely done. @@ -201,7 +208,7 @@ def response(context, flow): context.HARLog.add(entry) -def done(context): +def done(): """ Called once on script shutdown, after any other events. """ @@ -212,17 +219,17 @@ def done(context): compressed_json_dump = context.HARLog.compress() if context.dump_file == '-': - context.log(pprint.pformat(json.loads(json_dump))) + mitmproxy.ctx.log(pprint.pformat(json.loads(json_dump))) elif context.dump_file.endswith('.zhar'): file(context.dump_file, "w").write(compressed_json_dump) else: file(context.dump_file, "w").write(json_dump) - context.log( + mitmproxy.ctx.log( "HAR log finished with %s bytes (%s bytes compressed)" % ( len(json_dump), len(compressed_json_dump) ) ) - context.log( + mitmproxy.ctx.log( "Compression rate is %s%%" % str( 100. * len(compressed_json_dump) / len(json_dump) ) diff --git a/examples/iframe_injector.py b/examples/iframe_injector.py index ebb5fd02..70247d31 100644 --- a/examples/iframe_injector.py +++ b/examples/iframe_injector.py @@ -4,25 +4,27 @@ import sys from bs4 import BeautifulSoup from mitmproxy.models import decoded +iframe_url = None -def start(context): + +def start(): if len(sys.argv) != 2: raise ValueError('Usage: -s "iframe_injector.py url"') - context.iframe_url = sys.argv[1] + global iframe_url + iframe_url = sys.argv[1] -def response(context, flow): - if flow.request.host in context.iframe_url: +def response(flow): + if flow.request.host in iframe_url: return with decoded(flow.response): # Remove content encoding (gzip, ...) html = BeautifulSoup(flow.response.content, "lxml") if html.body: iframe = html.new_tag( "iframe", - src=context.iframe_url, + src=iframe_url, frameborder=0, height=0, width=0) html.body.insert(0, iframe) flow.response.content = str(html).encode("utf8") - context.log("Iframe inserted.") diff --git a/examples/modify_form.py b/examples/modify_form.py index 3fe0cf96..b63a1586 100644 --- a/examples/modify_form.py +++ b/examples/modify_form.py @@ -1,4 +1,4 @@ -def request(context, flow): +def request(flow): if flow.request.urlencoded_form: flow.request.urlencoded_form["mitmproxy"] = "rocks" else: diff --git a/examples/modify_querystring.py b/examples/modify_querystring.py index b89e5c8d..ee8a89ad 100644 --- a/examples/modify_querystring.py +++ b/examples/modify_querystring.py @@ -1,2 +1,2 @@ -def request(context, flow): +def request(flow): flow.request.query["mitmproxy"] = "rocks" diff --git a/examples/modify_response_body.py b/examples/modify_response_body.py index 994932a1..23ad0151 100644 --- a/examples/modify_response_body.py +++ b/examples/modify_response_body.py @@ -5,16 +5,20 @@ import sys from mitmproxy.models import decoded -def start(context): +state = {} + + +def start(): if len(sys.argv) != 3: raise ValueError('Usage: -s "modify_response_body.py old new"') # You may want to use Python's argparse for more sophisticated argument # parsing. - context.old, context.new = sys.argv[1].encode(), sys.argv[2].encode() + state["old"], state["new"] = sys.argv[1].encode(), sys.argv[2].encode() -def response(context, flow): +def response(flow): with decoded(flow.response): # automatically decode gzipped responses. flow.response.content = flow.response.content.replace( - context.old, - context.new) + state["old"], + state["new"] + ) diff --git a/examples/nonblocking.py b/examples/nonblocking.py index 4609f389..b81478df 100644 --- a/examples/nonblocking.py +++ b/examples/nonblocking.py @@ -1,9 +1,10 @@ import time +import mitmproxy from mitmproxy.script import concurrent @concurrent # Remove this and see what happens -def request(context, flow): - context.log("handle request: %s%s" % (flow.request.host, flow.request.path)) +def request(flow): + mitmproxy.ctx.log("handle request: %s%s" % (flow.request.host, flow.request.path)) time.sleep(5) - context.log("start request: %s%s" % (flow.request.host, flow.request.path)) + mitmproxy.ctx.log("start request: %s%s" % (flow.request.host, flow.request.path)) diff --git a/examples/proxapp.py b/examples/proxapp.py index 613d3f8b..2935b587 100644 --- a/examples/proxapp.py +++ b/examples/proxapp.py @@ -4,6 +4,7 @@ instance, we're using the Flask framework (http://flask.pocoo.org/) to expose a single simplest-possible page. """ from flask import Flask +import mitmproxy app = Flask("proxapp") @@ -15,10 +16,10 @@ def hello_world(): # Register the app using the magic domain "proxapp" on port 80. Requests to # this domain and port combination will now be routed to the WSGI app instance. -def start(context): - context.app_registry.add(app, "proxapp", 80) +def start(): + mitmproxy.ctx.master.apps.add(app, "proxapp", 80) # SSL works too, but the magic domain needs to be resolvable from the mitmproxy machine due to mitmproxy's design. # mitmproxy will connect to said domain and use serve its certificate (unless --no-upstream-cert is set) # but won't send any data. - context.app_registry.add(app, "example.com", 443) + mitmproxy.ctx.master.apps.add(app, "example.com", 443) diff --git a/examples/redirect_requests.py b/examples/redirect_requests.py index af2aa907..36594bcd 100644 --- a/examples/redirect_requests.py +++ b/examples/redirect_requests.py @@ -5,7 +5,7 @@ from mitmproxy.models import HTTPResponse from netlib.http import Headers -def request(context, flow): +def request(flow): # pretty_host takes the "Host" header of the request into account, # which is useful in transparent mode where we usually only have the IP # otherwise. diff --git a/examples/sslstrip.py b/examples/sslstrip.py index 8dde8e3e..afc95fc8 100644 --- a/examples/sslstrip.py +++ b/examples/sslstrip.py @@ -2,23 +2,21 @@ from netlib.http import decoded import re from six.moves import urllib +# set of SSL/TLS capable hosts +secure_hosts = set() -def start(context): - # set of SSL/TLS capable hosts - context.secure_hosts = set() - -def request(context, flow): +def request(flow): flow.request.headers.pop('If-Modified-Since', None) flow.request.headers.pop('Cache-Control', None) # proxy connections to SSL-enabled hosts - if flow.request.pretty_host in context.secure_hosts: + if flow.request.pretty_host in secure_hosts: flow.request.scheme = 'https' flow.request.port = 443 -def response(context, flow): +def response(flow): with decoded(flow.response): flow.request.headers.pop('Strict-Transport-Security', None) flow.request.headers.pop('Public-Key-Pins', None) @@ -31,7 +29,7 @@ def response(context, flow): location = flow.response.headers['Location'] hostname = urllib.parse.urlparse(location).hostname if hostname: - context.secure_hosts.add(hostname) + secure_hosts.add(hostname) flow.response.headers['Location'] = location.replace('https://', 'http://', 1) # strip secure flag from 'Set-Cookie' headers diff --git a/examples/stream.py b/examples/stream.py index 3adbe437..8598f329 100644 --- a/examples/stream.py +++ b/examples/stream.py @@ -1,4 +1,4 @@ -def responseheaders(context, flow): +def responseheaders(flow): """ Enables streaming for all responses. """ diff --git a/examples/stream_modify.py b/examples/stream_modify.py index aa395c03..5e5da95b 100644 --- a/examples/stream_modify.py +++ b/examples/stream_modify.py @@ -16,5 +16,5 @@ def modify(chunks): yield chunk.replace("foo", "bar") -def responseheaders(context, flow): +def responseheaders(flow): flow.response.stream = modify diff --git a/examples/stub.py b/examples/stub.py index a0f73538..10b34283 100644 --- a/examples/stub.py +++ b/examples/stub.py @@ -1,79 +1,80 @@ +import mitmproxy """ This is a script stub, with definitions for all events. """ -def start(context): +def start(): """ Called once on script startup, before any other events. """ - context.log("start") + mitmproxy.ctx.log("start") -def clientconnect(context, root_layer): +def clientconnect(root_layer): """ Called when a client initiates a connection to the proxy. Note that a connection can correspond to multiple HTTP requests """ - context.log("clientconnect") + mitmproxy.ctx.log("clientconnect") -def request(context, flow): +def request(flow): """ Called when a client request has been received. """ - context.log("request") + mitmproxy.ctx.log("request") -def serverconnect(context, server_conn): +def serverconnect(server_conn): """ Called when the proxy initiates a connection to the target server. Note that a connection can correspond to multiple HTTP requests """ - context.log("serverconnect") + mitmproxy.ctx.log("serverconnect") -def responseheaders(context, flow): +def responseheaders(flow): """ Called when the response headers for a server response have been received, but the response body has not been processed yet. Can be used to tell mitmproxy to stream the response. """ - context.log("responseheaders") + mitmproxy.ctx.log("responseheaders") -def response(context, flow): +def response(flow): """ Called when a server response has been received. """ - context.log("response") + mitmproxy.ctx.log("response") -def error(context, flow): +def error(flow): """ Called when a flow error has occured, e.g. invalid server responses, or interrupted connections. This is distinct from a valid server HTTP error response, which is simply a response with an HTTP error code. """ - context.log("error") + mitmproxy.ctx.log("error") -def serverdisconnect(context, server_conn): +def serverdisconnect(server_conn): """ Called when the proxy closes the connection to the target server. """ - context.log("serverdisconnect") + mitmproxy.ctx.log("serverdisconnect") -def clientdisconnect(context, root_layer): +def clientdisconnect(root_layer): """ Called when a client disconnects from the proxy. """ - context.log("clientdisconnect") + mitmproxy.ctx.log("clientdisconnect") -def done(context): +def done(): """ Called once on script shutdown, after any other events. """ - context.log("done") + mitmproxy.ctx.log("done") diff --git a/examples/tcp_message.py b/examples/tcp_message.py index 6eced0dc..b431c23f 100644 --- a/examples/tcp_message.py +++ b/examples/tcp_message.py @@ -11,7 +11,7 @@ mitmdump -T --host --tcp ".*" -q -s examples/tcp_message.py from netlib import strutils -def tcp_message(ctx, tcp_msg): +def tcp_message(tcp_msg): modified_msg = tcp_msg.message.replace("foo", "bar") is_modified = False if modified_msg == tcp_msg.message else True diff --git a/examples/tls_passthrough.py b/examples/tls_passthrough.py index 50aab65b..20e8f9be 100644 --- a/examples/tls_passthrough.py +++ b/examples/tls_passthrough.py @@ -20,13 +20,14 @@ Example: Authors: Maximilian Hils, Matthew Tuusberg """ -from __future__ import (absolute_import, print_function, division) +from __future__ import absolute_import, print_function, division import collections import random import sys from enum import Enum +import mitmproxy from mitmproxy.exceptions import TlsProtocolException from mitmproxy.protocol import TlsLayer, RawTCPLayer @@ -97,7 +98,6 @@ class TlsFeedback(TlsLayer): def _establish_tls_with_client(self): server_address = self.server_conn.address - tls_strategy = self.script_context.tls_strategy try: super(TlsFeedback, self)._establish_tls_with_client() @@ -110,15 +110,18 @@ class TlsFeedback(TlsLayer): # inline script hooks below. +tls_strategy = None -def start(context): + +def start(): + global tls_strategy if len(sys.argv) == 2: - context.tls_strategy = ProbabilisticStrategy(float(sys.argv[1])) + tls_strategy = ProbabilisticStrategy(float(sys.argv[1])) else: - context.tls_strategy = ConservativeStrategy() + tls_strategy = ConservativeStrategy() -def next_layer(context, next_layer): +def next_layer(next_layer): """ This hook does the actual magic - if the next layer is planned to be a TLS layer, we check if we want to enter pass-through mode instead. @@ -126,14 +129,13 @@ def next_layer(context, next_layer): if isinstance(next_layer, TlsLayer) and next_layer._client_tls: server_address = next_layer.server_conn.address - if context.tls_strategy.should_intercept(server_address): + if tls_strategy.should_intercept(server_address): # We try to intercept. # Monkey-Patch the layer to get feedback from the TLSLayer if interception worked. next_layer.__class__ = TlsFeedback - next_layer.script_context = context else: # We don't intercept - reply with a pass-through layer and add a "skipped" entry. - context.log("TLS passthrough for %s" % repr(next_layer.server_conn.address), "info") - next_layer_replacement = RawTCPLayer(next_layer.ctx, logging=False) + mitmproxy.ctx.log("TLS passthrough for %s" % repr(next_layer.server_conn.address), "info") + next_layer_replacement = RawTCPLayer(next_layer.ctx, ignore=True) next_layer.reply.send(next_layer_replacement) - context.tls_strategy.record_skipped(server_address) + tls_strategy.record_skipped(server_address) diff --git a/examples/upsidedownternet.py b/examples/upsidedownternet.py index 9aac9f05..fafdefce 100644 --- a/examples/upsidedownternet.py +++ b/examples/upsidedownternet.py @@ -3,7 +3,7 @@ from PIL import Image from mitmproxy.models import decoded -def response(context, flow): +def response(flow): if flow.response.headers.get("content-type", "").startswith("image"): with decoded(flow.response): # automatically decode gzipped responses. try: diff --git a/mitmproxy/console/master.py b/mitmproxy/console/master.py index 95c9704d..93b5766d 100644 --- a/mitmproxy/console/master.py +++ b/mitmproxy/console/master.py @@ -366,7 +366,7 @@ class ConsoleMaster(flow.FlowMaster): signals.add_event("Running script on flow: %s" % command, "debug") try: - s = script.Script(command, script.ScriptContext(self)) + s = script.Script(command) s.load() except script.ScriptException as e: signals.status_message.send( @@ -812,6 +812,6 @@ class ConsoleMaster(flow.FlowMaster): @controller.handler def script_change(self, script): if super(ConsoleMaster, self).script_change(script): - signals.status_message.send(message='"{}" reloaded.'.format(script.filename)) + signals.status_message.send(message='"{}" reloaded.'.format(script.path)) else: - signals.status_message.send(message='Error reloading "{}".'.format(script.filename)) + signals.status_message.send(message='Error reloading "{}".'.format(script.path)) diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py index a170d868..e2be3a53 100644 --- a/mitmproxy/controller.py +++ b/mitmproxy/controller.py @@ -2,11 +2,12 @@ from __future__ import absolute_import, print_function, division import functools import threading +import contextlib from six.moves import queue +from . import ctx as mitmproxy_ctx from netlib import basethread - from . import exceptions @@ -34,6 +35,16 @@ Events = frozenset([ ]) +class Log(object): + def __init__(self, master): + self.master = master + + def __call__(self, text, level="info"): + self.master.add_event(text, level) + + # We may want to add .log(), .warn() etc. here at a later point in time + + class Master(object): """ The master handles mitmproxy's main event loop. @@ -45,6 +56,20 @@ class Master(object): for i in servers: self.add_server(i) + @contextlib.contextmanager + def handlecontext(self): + # Handlecontexts also have to nest - leave cleanup to the outermost + if mitmproxy_ctx.master: + yield + return + mitmproxy_ctx.master = self + mitmproxy_ctx.log = Log(self) + try: + yield + finally: + mitmproxy_ctx.master = None + mitmproxy_ctx.log = None + def add_server(self, server): # We give a Channel to the server which can be used to communicate with the master channel = Channel(self.event_queue, self.should_exit) @@ -77,8 +102,8 @@ class Master(object): if mtype not in Events: raise exceptions.ControlException("Unknown event %s" % repr(mtype)) handle_func = getattr(self, mtype) - if not hasattr(handle_func, "__dict__"): - raise exceptions.ControlException("Handler %s not a function" % mtype) + if not callable(handle_func): + raise exceptions.ControlException("Handler %s not callable" % mtype) if not handle_func.__dict__.get("__handler"): raise exceptions.ControlException( "Handler function %s is not decorated with controller.handler" % ( @@ -151,15 +176,7 @@ class Channel(object): def handler(f): @functools.wraps(f) - def wrapper(*args, **kwargs): - # We can either be called as a method, or as a wrapped solo function - if len(args) == 1: - message = args[0] - elif len(args) == 2: - message = args[1] - else: - raise exceptions.ControlException("Handler takes one argument: a message") - + def wrapper(master, message): if not hasattr(message, "reply"): raise exceptions.ControlException("Message %s has no reply attribute" % message) @@ -172,7 +189,8 @@ def handler(f): handling = True message.reply.handled = True - ret = f(*args, **kwargs) + with master.handlecontext(): + ret = f(master, message) if handling and not message.reply.acked and not message.reply.taken: message.reply.ack() @@ -216,7 +234,7 @@ class Reply(object): def __del__(self): if not self.acked: # This will be ignored by the interpreter, but emit a warning - raise exceptions.ControlException("Un-acked message") + raise exceptions.ControlException("Un-acked message: %s" % self.obj) class DummyReply(object): diff --git a/mitmproxy/ctx.py b/mitmproxy/ctx.py new file mode 100644 index 00000000..fcfdfd0b --- /dev/null +++ b/mitmproxy/ctx.py @@ -0,0 +1,4 @@ +from typing import Callable # noqa + +master = None # type: "mitmproxy.flow.FlowMaster" +log = None # type: Callable[[str], None] diff --git a/mitmproxy/flow/master.py b/mitmproxy/flow/master.py index 520f82e9..7590a3fa 100644 --- a/mitmproxy/flow/master.py +++ b/mitmproxy/flow/master.py @@ -89,9 +89,10 @@ class FlowMaster(controller.Master): Raises: ScriptException """ - s = script.Script(command, script.ScriptContext(self)) + s = script.Script(command) s.load() if use_reloader: + s.reply = controller.DummyReply() script.reloader.watch(s, lambda: self.event_queue.put(("script_change", s))) self.scripts.append(s) @@ -234,8 +235,12 @@ class FlowMaster(controller.Master): return super(FlowMaster, self).tick(timeout) def duplicate_flow(self, f): + """ + Duplicate flow, and insert it into state without triggering any of + the normal flow events. + """ f2 = f.copy() - self.load_flow(f2) + self.state.add_flow(f2) return f2 def create_request(self, method, scheme, host, port, path): @@ -479,14 +484,14 @@ class FlowMaster(controller.Master): s.unload() except script.ScriptException as e: ok = False - self.add_event('Error reloading "{}":\n{}'.format(s.filename, e), 'error') + self.add_event('Error reloading "{}":\n{}'.format(s.path, e), 'error') try: s.load() except script.ScriptException as e: ok = False - self.add_event('Error reloading "{}":\n{}'.format(s.filename, e), 'error') + self.add_event('Error reloading "{}":\n{}'.format(s.path, e), 'error') else: - self.add_event('"{}" reloaded.'.format(s.filename), 'info') + self.add_event('"{}" reloaded.'.format(s.path), 'info') return ok @controller.handler diff --git a/mitmproxy/script/__init__.py b/mitmproxy/script/__init__.py index d6bff4c7..9a3985ab 100644 --- a/mitmproxy/script/__init__.py +++ b/mitmproxy/script/__init__.py @@ -1,12 +1,10 @@ from . import reloader from .concurrent import concurrent from .script import Script -from .script_context import ScriptContext from ..exceptions import ScriptException __all__ = [ "Script", - "ScriptContext", "concurrent", "ScriptException", "reloader" diff --git a/mitmproxy/script/concurrent.py b/mitmproxy/script/concurrent.py index 56d39d0b..010a5fa0 100644 --- a/mitmproxy/script/concurrent.py +++ b/mitmproxy/script/concurrent.py @@ -18,9 +18,9 @@ def concurrent(fn): "Concurrent decorator not supported for '%s' method." % fn.__name__ ) - def _concurrent(ctx, obj): + def _concurrent(obj): def run(): - fn(ctx, obj) + fn(obj) if not obj.reply.acked: obj.reply.ack() obj.reply.take() diff --git a/mitmproxy/script/reloader.py b/mitmproxy/script/reloader.py index 50401034..857d76cd 100644 --- a/mitmproxy/script/reloader.py +++ b/mitmproxy/script/reloader.py @@ -15,8 +15,8 @@ _observers = {} def watch(script, callback): if script in _observers: raise RuntimeError("Script already observed") - script_dir = os.path.dirname(os.path.abspath(script.filename)) - script_name = os.path.basename(script.filename) + script_dir = os.path.dirname(os.path.abspath(script.path)) + script_name = os.path.basename(script.path) event_handler = _ScriptModificationHandler(callback, filename=script_name) observer = Observer() observer.schedule(event_handler, script_dir) diff --git a/mitmproxy/script/script.py b/mitmproxy/script/script.py index 9ff79f52..db4909ca 100644 --- a/mitmproxy/script/script.py +++ b/mitmproxy/script/script.py @@ -6,38 +6,40 @@ by the mitmproxy-specific ScriptContext. # Do not import __future__ here, this would apply transitively to the inline scripts. from __future__ import absolute_import, print_function, division -import inspect import os import shlex import sys import contextlib -import warnings import six +from typing import List # noqa from mitmproxy import exceptions @contextlib.contextmanager -def setargs(args): +def scriptenv(path, args): + # type: (str, List[str]) -> None oldargs = sys.argv - sys.argv = args + script_dir = os.path.dirname(os.path.abspath(path)) + + sys.argv = [path] + args + sys.path.append(script_dir) try: yield finally: sys.argv = oldargs + sys.path.pop() class Script(object): - """ Script object representing an inline script. """ - def __init__(self, command, context): + def __init__(self, command): self.command = command - self.args = self.parse_command(command) - self.ctx = context + self.path, self.args = self.parse_command(command) self.ns = None def __enter__(self): @@ -46,15 +48,15 @@ class Script(object): def __exit__(self, exc_type, exc_val, exc_tb): if exc_val: - return False # reraise the exception + return False # re-raise the exception self.unload() - @property - def filename(self): - return self.args[0] - @staticmethod def parse_command(command): + # type: (str) -> Tuple[str,List[str]] + """ + Returns a (path, args) tuple. + """ if not command or not command.strip(): raise exceptions.ScriptException("Empty script command.") # Windows: escape all backslashes in the path. @@ -71,7 +73,7 @@ class Script(object): args[0]) elif os.path.isdir(args[0]): raise exceptions.ScriptException("Not a file: %s" % args[0]) - return args + return args[0], args[1:] def load(self): """ @@ -85,31 +87,19 @@ class Script(object): """ if self.ns is not None: raise exceptions.ScriptException("Script is already loaded") - script_dir = os.path.dirname(os.path.abspath(self.args[0])) - self.ns = {'__file__': os.path.abspath(self.args[0])} - sys.path.append(script_dir) - sys.path.append(os.path.join(script_dir, "..")) - try: - with open(self.filename) as f: - code = compile(f.read(), self.filename, 'exec') - exec(code, self.ns, self.ns) - except Exception: - six.reraise( - exceptions.ScriptException, - exceptions.ScriptException.from_exception_context(), - sys.exc_info()[2] - ) - finally: - sys.path.pop() - sys.path.pop() - - start_fn = self.ns.get("start") - if start_fn and len(inspect.getargspec(start_fn).args) == 2: - warnings.warn( - "The 'args' argument of the start() script hook is deprecated. " - "Please use sys.argv instead." - ) - return self.run("start", self.args) + self.ns = {'__file__': os.path.abspath(self.path)} + + with scriptenv(self.path, self.args): + try: + with open(self.path) as f: + code = compile(f.read(), self.path, 'exec') + exec(code, self.ns, self.ns) + except Exception: + six.reraise( + exceptions.ScriptException, + exceptions.ScriptException.from_exception_context(), + sys.exc_info()[2] + ) return self.run("start") def unload(self): @@ -134,8 +124,8 @@ class Script(object): f = self.ns.get(name) if f: try: - with setargs(self.args): - return f(self.ctx, *args, **kwargs) + with scriptenv(self.path, self.args): + return f(*args, **kwargs) except Exception: six.reraise( exceptions.ScriptException, diff --git a/mitmproxy/script/script_context.py b/mitmproxy/script/script_context.py deleted file mode 100644 index 44e2736b..00000000 --- a/mitmproxy/script/script_context.py +++ /dev/null @@ -1,61 +0,0 @@ -""" -The mitmproxy script context provides an API to inline scripts. -""" -from __future__ import absolute_import, print_function, division - -from mitmproxy import contentviews - - -class ScriptContext(object): - - """ - The script context should be used to interact with the global mitmproxy state from within a - script. - """ - - def __init__(self, master): - self._master = master - - def log(self, message, level="info"): - """ - Logs an event. - - By default, only events with level "error" get displayed. This can be controlled with the "-v" switch. - How log messages are handled depends on the front-end. mitmdump will print them to stdout, - mitmproxy sends output to the eventlog for display ("e" keyboard shortcut). - """ - self._master.add_event(message, level) - - def kill_flow(self, f): - """ - Kills a flow immediately. No further data will be sent to the client or the server. - """ - f.kill(self._master) - - def duplicate_flow(self, f): - """ - Returns a duplicate of the specified flow. The flow is also - injected into the current state, and is ready for editing, replay, - etc. - """ - self._master.pause_scripts = True - f = self._master.duplicate_flow(f) - self._master.pause_scripts = False - return f - - def replay_request(self, f): - """ - Replay the request on the current flow. The response will be added - to the flow object. - """ - return self._master.replay_request(f, block=True, run_scripthooks=False) - - @property - def app_registry(self): - return self._master.apps - - def add_contentview(self, view_obj): - contentviews.add(view_obj) - - def remove_contentview(self, view_obj): - contentviews.remove(view_obj) diff --git a/test/mitmproxy/data/scripts/a.py b/test/mitmproxy/data/scripts/a.py index 33dbaa64..ab0dbf96 100644 --- a/test/mitmproxy/data/scripts/a.py +++ b/test/mitmproxy/data/scripts/a.py @@ -5,12 +5,12 @@ from a_helper import parser var = 0 -def start(ctx): +def start(): global var var = parser.parse_args(sys.argv[1:]).var -def here(ctx): +def here(): global var var += 1 return var diff --git a/test/mitmproxy/data/scripts/all.py b/test/mitmproxy/data/scripts/all.py index dad2aade..bf8e93ec 100644 --- a/test/mitmproxy/data/scripts/all.py +++ b/test/mitmproxy/data/scripts/all.py @@ -1,36 +1,37 @@ +import mitmproxy log = [] -def clientconnect(ctx, cc): - ctx.log("XCLIENTCONNECT") +def clientconnect(cc): + mitmproxy.ctx.log("XCLIENTCONNECT") log.append("clientconnect") -def serverconnect(ctx, cc): - ctx.log("XSERVERCONNECT") +def serverconnect(cc): + mitmproxy.ctx.log("XSERVERCONNECT") log.append("serverconnect") -def request(ctx, f): - ctx.log("XREQUEST") +def request(f): + mitmproxy.ctx.log("XREQUEST") log.append("request") -def response(ctx, f): - ctx.log("XRESPONSE") +def response(f): + mitmproxy.ctx.log("XRESPONSE") log.append("response") -def responseheaders(ctx, f): - ctx.log("XRESPONSEHEADERS") +def responseheaders(f): + mitmproxy.ctx.log("XRESPONSEHEADERS") log.append("responseheaders") -def clientdisconnect(ctx, cc): - ctx.log("XCLIENTDISCONNECT") +def clientdisconnect(cc): + mitmproxy.ctx.log("XCLIENTDISCONNECT") log.append("clientdisconnect") -def error(ctx, cc): - ctx.log("XERROR") +def error(cc): + mitmproxy.ctx.log("XERROR") log.append("error") diff --git a/test/mitmproxy/data/scripts/concurrent_decorator.py b/test/mitmproxy/data/scripts/concurrent_decorator.py index e017f605..162c00f4 100644 --- a/test/mitmproxy/data/scripts/concurrent_decorator.py +++ b/test/mitmproxy/data/scripts/concurrent_decorator.py @@ -3,5 +3,5 @@ from mitmproxy.script import concurrent @concurrent -def request(context, flow): +def request(flow): time.sleep(0.1) diff --git a/test/mitmproxy/data/scripts/concurrent_decorator_err.py b/test/mitmproxy/data/scripts/concurrent_decorator_err.py index 349e5dd6..756869c8 100644 --- a/test/mitmproxy/data/scripts/concurrent_decorator_err.py +++ b/test/mitmproxy/data/scripts/concurrent_decorator_err.py @@ -2,5 +2,5 @@ from mitmproxy.script import concurrent @concurrent -def start(context): +def start(): pass diff --git a/test/mitmproxy/data/scripts/duplicate_flow.py b/test/mitmproxy/data/scripts/duplicate_flow.py index e13af786..565b1845 100644 --- a/test/mitmproxy/data/scripts/duplicate_flow.py +++ b/test/mitmproxy/data/scripts/duplicate_flow.py @@ -1,4 +1,6 @@ +import mitmproxy -def request(ctx, f): - f = ctx.duplicate_flow(f) - ctx.replay_request(f) + +def request(f): + f = mitmproxy.ctx.master.duplicate_flow(f) + mitmproxy.ctx.master.replay_request(f, block=True, run_scripthooks=False) diff --git a/test/mitmproxy/data/scripts/reqerr.py b/test/mitmproxy/data/scripts/reqerr.py index e7c503a8..7b419361 100644 --- a/test/mitmproxy/data/scripts/reqerr.py +++ b/test/mitmproxy/data/scripts/reqerr.py @@ -1,2 +1,2 @@ -def request(ctx, r): - raise ValueError +def request(r): + raise ValueError() diff --git a/test/mitmproxy/data/scripts/starterr.py b/test/mitmproxy/data/scripts/starterr.py index 82d773bd..28ba2ff1 100644 --- a/test/mitmproxy/data/scripts/starterr.py +++ b/test/mitmproxy/data/scripts/starterr.py @@ -1,3 +1,3 @@ -def start(ctx): +def start(): raise ValueError() diff --git a/test/mitmproxy/data/scripts/stream_modify.py b/test/mitmproxy/data/scripts/stream_modify.py index 8221b0dd..4fbf45c2 100644 --- a/test/mitmproxy/data/scripts/stream_modify.py +++ b/test/mitmproxy/data/scripts/stream_modify.py @@ -3,5 +3,5 @@ def modify(chunks): yield chunk.replace(b"foo", b"bar") -def responseheaders(context, flow): +def responseheaders(flow): flow.response.stream = modify diff --git a/test/mitmproxy/data/scripts/tcp_stream_modify.py b/test/mitmproxy/data/scripts/tcp_stream_modify.py index 0965beba..2281e6e6 100644 --- a/test/mitmproxy/data/scripts/tcp_stream_modify.py +++ b/test/mitmproxy/data/scripts/tcp_stream_modify.py @@ -1,4 +1,4 @@ -def tcp_message(ctx, flow): +def tcp_message(flow): message = flow.messages[-1] if not message.from_client: message.content = message.content.replace(b"foo", b"bar") diff --git a/test/mitmproxy/data/scripts/unloaderr.py b/test/mitmproxy/data/scripts/unloaderr.py index fba02734..6a48ab43 100644 --- a/test/mitmproxy/data/scripts/unloaderr.py +++ b/test/mitmproxy/data/scripts/unloaderr.py @@ -1,2 +1,2 @@ -def done(ctx): +def done(): raise RuntimeError() diff --git a/test/mitmproxy/script/test_concurrent.py b/test/mitmproxy/script/test_concurrent.py index 62541f3f..57eeca19 100644 --- a/test/mitmproxy/script/test_concurrent.py +++ b/test/mitmproxy/script/test_concurrent.py @@ -11,7 +11,7 @@ class Thing: @tutils.skip_appveyor def test_concurrent(): - with Script(tutils.test_data.path("data/scripts/concurrent_decorator.py"), None) as s: + with Script(tutils.test_data.path("data/scripts/concurrent_decorator.py")) as s: f1, f2 = Thing(), Thing() s.run("request", f1) s.run("request", f2) @@ -23,6 +23,6 @@ def test_concurrent(): def test_concurrent_err(): - s = Script(tutils.test_data.path("data/scripts/concurrent_decorator_err.py"), None) + s = Script(tutils.test_data.path("data/scripts/concurrent_decorator_err.py")) with tutils.raises("Concurrent decorator not supported for 'start' method"): s.load() diff --git a/test/mitmproxy/script/test_reloader.py b/test/mitmproxy/script/test_reloader.py index 0345f6ed..e33903b9 100644 --- a/test/mitmproxy/script/test_reloader.py +++ b/test/mitmproxy/script/test_reloader.py @@ -10,7 +10,7 @@ def test_simple(): pass script = mock.Mock() - script.filename = "foo.py" + script.path = "foo.py" e = Event() diff --git a/test/mitmproxy/script/test_script.py b/test/mitmproxy/script/test_script.py index fe98fab5..48fe65c9 100644 --- a/test/mitmproxy/script/test_script.py +++ b/test/mitmproxy/script/test_script.py @@ -21,21 +21,21 @@ class TestParseCommand: def test_parse_args(self): with tutils.chdir(tutils.test_data.dirname): - assert Script.parse_command("data/scripts/a.py") == ["data/scripts/a.py"] - assert Script.parse_command("data/scripts/a.py foo bar") == ["data/scripts/a.py", "foo", "bar"] - assert Script.parse_command("data/scripts/a.py 'foo bar'") == ["data/scripts/a.py", "foo bar"] + assert Script.parse_command("data/scripts/a.py") == ("data/scripts/a.py", []) + assert Script.parse_command("data/scripts/a.py foo bar") == ("data/scripts/a.py", ["foo", "bar"]) + assert Script.parse_command("data/scripts/a.py 'foo bar'") == ("data/scripts/a.py", ["foo bar"]) @tutils.skip_not_windows def test_parse_windows(self): with tutils.chdir(tutils.test_data.dirname): - assert Script.parse_command("data\\scripts\\a.py") == ["data\\scripts\\a.py"] - assert Script.parse_command("data\\scripts\\a.py 'foo \\ bar'") == ["data\\scripts\\a.py", 'foo \\ bar'] + assert Script.parse_command("data\\scripts\\a.py") == ("data\\scripts\\a.py", []) + assert Script.parse_command("data\\scripts\\a.py 'foo \\ bar'") == ("data\\scripts\\a.py", ['foo \\ bar']) def test_simple(): with tutils.chdir(tutils.test_data.path("data/scripts")): - s = Script("a.py --var 42", None) - assert s.filename == "a.py" + s = Script("a.py --var 42") + assert s.path == "a.py" assert s.ns is None s.load() @@ -50,34 +50,34 @@ def test_simple(): with tutils.raises(ScriptException): s.run("here") - with Script("a.py --var 42", None) as s: + with Script("a.py --var 42") as s: s.run("here") def test_script_exception(): with tutils.chdir(tutils.test_data.path("data/scripts")): - s = Script("syntaxerr.py", None) + s = Script("syntaxerr.py") with tutils.raises(ScriptException): s.load() - s = Script("starterr.py", None) + s = Script("starterr.py") with tutils.raises(ScriptException): s.load() - s = Script("a.py", None) + s = Script("a.py") s.load() with tutils.raises(ScriptException): s.load() - s = Script("a.py", None) + s = Script("a.py") with tutils.raises(ScriptException): s.run("here") with tutils.raises(ScriptException): - with Script("reqerr.py", None) as s: + with Script("reqerr.py") as s: s.run("request", None) - s = Script("unloaderr.py", None) + s = Script("unloaderr.py") s.load() with tutils.raises(ScriptException): s.unload() diff --git a/test/mitmproxy/test_examples.py b/test/mitmproxy/test_examples.py index f30973e7..bdadcd11 100644 --- a/test/mitmproxy/test_examples.py +++ b/test/mitmproxy/test_examples.py @@ -1,47 +1,31 @@ import glob import json +import mock import os import sys from contextlib import contextmanager from mitmproxy import script -from mitmproxy.proxy import config import netlib.utils from netlib import tutils as netutils from netlib.http import Headers -from . import tservers, tutils +from . import tutils example_dir = netlib.utils.Data(__name__).path("../../examples") -class DummyContext(object): - """Emulate script.ScriptContext() functionality.""" - - contentview = None - - def log(self, *args, **kwargs): - pass - - def add_contentview(self, view_obj): - self.contentview = view_obj - - def remove_contentview(self, view_obj): - self.contentview = None - - @contextmanager def example(command): command = os.path.join(example_dir, command) - ctx = DummyContext() - with script.Script(command, ctx) as s: + with script.Script(command) as s: yield s -def test_load_scripts(): +@mock.patch("mitmproxy.ctx.master") +@mock.patch("mitmproxy.ctx.log") +def test_load_scripts(log, master): scripts = glob.glob("%s/*.py" % example_dir) - tmaster = tservers.TestMaster(config.ProxyConfig()) - for f in scripts: if "har_extractor" in f: continue @@ -54,7 +38,7 @@ def test_load_scripts(): if "modify_response_body" in f: f += " foo bar" # two arguments required - s = script.Script(f, script.ScriptContext(tmaster)) + s = script.Script(f) try: s.load() except Exception as v: @@ -71,17 +55,21 @@ def test_add_header(): assert flow.response.headers["newheader"] == "foo" -def test_custom_contentviews(): - with example("custom_contentviews.py") as ex: - pig = ex.ctx.contentview +@mock.patch("mitmproxy.contentviews.remove") +@mock.patch("mitmproxy.contentviews.add") +def test_custom_contentviews(add, remove): + with example("custom_contentviews.py"): + assert add.called + pig = add.call_args[0][0] _, fmt = pig(b"<html>test!</html>") assert any(b'esttay!' in val[0][1] for val in fmt) assert not pig(b"gobbledygook") + assert remove.called def test_iframe_injector(): with tutils.raises(script.ScriptException): - with example("iframe_injector.py") as ex: + with example("iframe_injector.py"): pass flow = tutils.tflow(resp=netutils.tresp(content=b"<html>mitmproxy</html>")) @@ -121,7 +109,7 @@ def test_modify_response_body(): flow = tutils.tflow(resp=netutils.tresp(content=b"I <3 mitmproxy")) with example("modify_response_body.py mitmproxy rocks") as ex: - assert ex.ctx.old == b"mitmproxy" and ex.ctx.new == b"rocks" + assert ex.ns["state"]["old"] == b"mitmproxy" and ex.ns["state"]["new"] == b"rocks" ex.run("response", flow) assert flow.response.content == b"I <3 rocks" @@ -133,7 +121,8 @@ def test_redirect_requests(): assert flow.request.host == "mitmproxy.org" -def test_har_extractor(): +@mock.patch("mitmproxy.ctx.log") +def test_har_extractor(log): if sys.version_info >= (3, 0): with tutils.raises("does not work on Python 3"): with example("har_extractor.py -"): @@ -159,4 +148,4 @@ def test_har_extractor(): with open(tutils.test_data.path("data/har_extractor.har")) as fp: test_data = json.load(fp) - assert json.loads(ex.ctx.HARLog.json()) == test_data["test_response"] + assert json.loads(ex.ns["context"].HARLog.json()) == test_data["test_response"] diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py index 0ab7624e..9dd8b79c 100644 --- a/test/mitmproxy/test_server.py +++ b/test/mitmproxy/test_server.py @@ -1,6 +1,7 @@ import os import socket import time +import types from OpenSSL import SSL from netlib.exceptions import HttpReadDisconnect, HttpException from netlib.tcp import Address @@ -945,7 +946,7 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxyTest): f.reply.kill() return _func(f) - setattr(master, attr, handler) + setattr(master, attr, types.MethodType(handler, master)) kill_requests( self.chain[1].tmaster, |