diff options
author | Aldo Cortesi <aldo@corte.si> | 2018-04-28 10:59:12 +1200 |
---|---|---|
committer | Aldo Cortesi <aldo@nullcube.com> | 2018-04-30 17:17:03 +1200 |
commit | 236a2fb6fde4ff8837f85cf0a217f915b0bfed79 (patch) | |
tree | 7342213c11fbfb93701976457d78a84cf2e6dd89 | |
parent | 28d53d5a245cbac19896bac30a41435024b17b78 (diff) | |
download | mitmproxy-236a2fb6fde4ff8837f85cf0a217f915b0bfed79.tar.gz mitmproxy-236a2fb6fde4ff8837f85cf0a217f915b0bfed79.tar.bz2 mitmproxy-236a2fb6fde4ff8837f85cf0a217f915b0bfed79.zip |
client replay: re-design
Re-design the way client replay works. Before, we would fire up a thread,
replay, wait for the thread to complete, get the next flow, and repeat the
procedure. Now, we have one replay thread that starts when the addon starts,
which pops flows off a thread-safe queue. This is much cleaner, removes the
need for busy tick, and sets the scene for optimisations like server connection
reuse down the track.
-rw-r--r-- | mitmproxy/addons/clientplayback.py | 242 | ||||
-rw-r--r-- | mitmproxy/test/taddons.py | 10 | ||||
-rw-r--r-- | test/mitmproxy/addons/test_check_ca.py | 4 | ||||
-rw-r--r-- | test/mitmproxy/addons/test_clientplayback.py | 72 | ||||
-rw-r--r-- | test/mitmproxy/addons/test_readfile.py | 2 | ||||
-rw-r--r-- | test/mitmproxy/addons/test_save.py | 8 | ||||
-rw-r--r-- | test/mitmproxy/addons/test_script.py | 17 | ||||
-rw-r--r-- | test/mitmproxy/addons/test_view.py | 8 |
8 files changed, 154 insertions, 209 deletions
diff --git a/mitmproxy/addons/clientplayback.py b/mitmproxy/addons/clientplayback.py index 8cd8e3a8..305920da 100644 --- a/mitmproxy/addons/clientplayback.py +++ b/mitmproxy/addons/clientplayback.py @@ -1,3 +1,6 @@ +import queue +import typing + from mitmproxy import log from mitmproxy import controller from mitmproxy import exceptions @@ -14,46 +17,47 @@ from mitmproxy import io from mitmproxy import command import mitmproxy.types -import typing - class RequestReplayThread(basethread.BaseThread): - name = "RequestReplayThread" + daemon = True def __init__( self, opts: options.Options, - f: http.HTTPFlow, channel: controller.Channel, + queue: queue.Queue, ) -> None: self.options = opts - self.f = f - f.live = True self.channel = channel - super().__init__( - "RequestReplay (%s)" % f.request.url - ) - self.daemon = True + self.queue = queue + super().__init__("RequestReplayThread") def run(self): - r = self.f.request + while True: + f = self.queue.get(block=True, timeout=None) + self.replay(f) + + def replay(self, f): + f.live = True + r = f.request bsl = human.parse_size(self.options.body_size_limit) first_line_format_backup = r.first_line_format server = None try: - self.f.response = None + f.response = None # If we have a channel, run script hooks. - if self.channel: - request_reply = self.channel.ask("request", self.f) - if isinstance(request_reply, http.HTTPResponse): - self.f.response = request_reply + request_reply = self.channel.ask("request", f) + if isinstance(request_reply, http.HTTPResponse): + f.response = request_reply - if not self.f.response: + if not f.response: # In all modes, we directly connect to the server displayed if self.options.mode.startswith("upstream:"): server_address = server_spec.parse_with_mode(self.options.mode)[1].address - server = connections.ServerConnection(server_address, (self.options.listen_host, 0)) + server = connections.ServerConnection( + server_address, (self.options.listen_host, 0) + ) server.connect() if r.scheme == "https": connect_request = http.make_connect_request((r.data.host, r.port)) @@ -65,9 +69,11 @@ class RequestReplayThread(basethread.BaseThread): body_size_limit=bsl ) if resp.status_code != 200: - raise exceptions.ReplayException("Upstream server refuses CONNECT request") + raise exceptions.ReplayException( + "Upstream server refuses CONNECT request" + ) server.establish_tls( - sni=self.f.server_conn.sni, + sni=f.server_conn.sni, **tls.client_arguments_from_options(self.options) ) r.first_line_format = "relative" @@ -82,7 +88,7 @@ class RequestReplayThread(basethread.BaseThread): server.connect() if r.scheme == "https": server.establish_tls( - sni=self.f.server_conn.sni, + sni=f.server_conn.sni, **tls.client_arguments_from_options(self.options) ) r.first_line_format = "relative" @@ -90,104 +96,44 @@ class RequestReplayThread(basethread.BaseThread): server.wfile.write(http1.assemble_request(r)) server.wfile.flush() - if self.f.server_conn: - self.f.server_conn.close() - self.f.server_conn = server + if f.server_conn: + f.server_conn.close() + f.server_conn = server - self.f.response = http.HTTPResponse.wrap( - http1.read_response( - server.rfile, - r, - body_size_limit=bsl - ) + f.response = http.HTTPResponse.wrap( + http1.read_response(server.rfile, r, body_size_limit=bsl) ) - if self.channel: - response_reply = self.channel.ask("response", self.f) - if response_reply == exceptions.Kill: - raise exceptions.Kill() + response_reply = self.channel.ask("response", f) + if response_reply == exceptions.Kill: + raise exceptions.Kill() except (exceptions.ReplayException, exceptions.NetlibException) as e: - self.f.error = flow.Error(str(e)) - if self.channel: - self.channel.ask("error", self.f) + f.error = flow.Error(str(e)) + self.channel.ask("error", f) except exceptions.Kill: - # Kill should only be raised if there's a channel in the - # first place. - self.channel.tell( - "log", - log.LogEntry("Connection killed", "info") - ) + self.channel.tell("log", log.LogEntry("Connection killed", "info")) except Exception as e: - self.channel.tell( - "log", - log.LogEntry(repr(e), "error") - ) + self.channel.tell("log", log.LogEntry(repr(e), "error")) finally: r.first_line_format = first_line_format_backup - self.f.live = False + f.live = False if server.connected(): server.finish() class ClientPlayback: def __init__(self): - self.flows: typing.List[flow.Flow] = [] - self.current_thread = None - self.configured = False - - def replay_request( - self, - f: http.HTTPFlow, - block: bool=False - ) -> RequestReplayThread: - """ - Replay a HTTP request to receive a new response from the server. - - Args: - f: The flow to replay. - block: If True, this function will wait for the replay to finish. - This causes a deadlock if activated in the main thread. - - Returns: - The thread object doing the replay. - - Raises: - exceptions.ReplayException, if the flow is in a state - where it is ineligible for replay. - """ + self.q: queue.Queue = queue.Queue() + self.thread: RequestReplayThread | None = None + def check(self, f: http.HTTPFlow): if f.live: - raise exceptions.ReplayException( - "Can't replay live flow." - ) + return "Can't replay live flow." if f.intercepted: - raise exceptions.ReplayException( - "Can't replay intercepted flow." - ) + return "Can't replay intercepted flow." if not f.request: - raise exceptions.ReplayException( - "Can't replay flow with missing request." - ) + return "Can't replay flow with missing request." if f.request.raw_content is None: - raise exceptions.ReplayException( - "Can't replay flow with missing content." - ) - - f.backup() - f.request.is_replay = True - - f.response = None - f.error = None - - if f.request.http_version == "HTTP/2.0": # https://github.com/mitmproxy/mitmproxy/issues/2197 - f.request.http_version = "HTTP/1.1" - host = f.request.headers.pop(":authority") - f.request.headers.insert(0, "host", host) - - rt = RequestReplayThread(ctx.master.options, f, ctx.master.channel) - rt.start() # pragma: no cover - if block: - rt.join() - return rt + return "Can't replay flow with missing content." def load(self, loader): loader.add_option( @@ -195,65 +141,73 @@ class ClientPlayback: "Replay client requests from a saved file." ) + def running(self): + self.thread = RequestReplayThread( + ctx.options, + ctx.master.channel, + self.q, + ) + self.thread.start() + + def configure(self, updated): + if "client_replay" in updated and ctx.options.client_replay: + try: + flows = io.read_flows_from_paths(ctx.options.client_replay) + except exceptions.FlowReadException as e: + raise exceptions.OptionsError(str(e)) + self.start_replay(flows) + + @command.command("replay.client.count") def count(self) -> int: - if self.current_thread: - current = 1 - else: - current = 0 - return current + len(self.flows) + """ + Approximate number of flows queued for replay. + """ + return self.q.qsize() @command.command("replay.client.stop") def stop_replay(self) -> None: """ - Stop client replay. + Clear the replay queue. """ - self.flows = [] - ctx.log.alert("Client replay stopped.") - ctx.master.addons.trigger("update", []) + with self.q.mutex: + lst = list(self.q.queue) + self.q.queue.clear() + ctx.master.addons.trigger("update", lst) + ctx.log.alert("Client replay queue cleared.") @command.command("replay.client") def start_replay(self, flows: typing.Sequence[flow.Flow]) -> None: """ - Replay requests from flows. + Add flows to the replay queue, skipping flows that can't be replayed. """ + lst = [] for f in flows: - if f.live: - raise exceptions.CommandError("Can't replay live flow.") - self.flows = list(flows) - ctx.log.alert("Replaying %s flows." % len(self.flows)) - ctx.master.addons.trigger("update", []) + err = self.check(f) + if err: + ctx.log.warn(err) + continue + + lst.append(f) + # Prepare the flow for replay + f.backup() + f.request.is_replay = True + f.response = None + f.error = None + # https://github.com/mitmproxy/mitmproxy/issues/2197 + if f.request.http_version == "HTTP/2.0": + f.request.http_version = "HTTP/1.1" + host = f.request.headers.pop(":authority") + f.request.headers.insert(0, "host", host) + self.q.put(f) + ctx.master.addons.trigger("update", lst) @command.command("replay.client.file") def load_file(self, path: mitmproxy.types.Path) -> None: + """ + Load flows from file, and add them to the replay queue. + """ try: flows = io.read_flows_from_paths([path]) except exceptions.FlowReadException as e: raise exceptions.CommandError(str(e)) - ctx.log.alert("Replaying %s flows." % len(self.flows)) - self.flows = flows - ctx.master.addons.trigger("update", []) - - def configure(self, updated): - if not self.configured and ctx.options.client_replay: - self.configured = True - ctx.log.info("Client Replay: {}".format(ctx.options.client_replay)) - try: - flows = io.read_flows_from_paths(ctx.options.client_replay) - except exceptions.FlowReadException as e: - raise exceptions.OptionsError(str(e)) - self.start_replay(flows) - - def tick(self): - current_is_done = self.current_thread and not self.current_thread.is_alive() - can_start_new = not self.current_thread or current_is_done - will_start_new = can_start_new and self.flows - - if current_is_done: - self.current_thread = None - ctx.master.addons.trigger("update", []) - if will_start_new: - f = self.flows.pop(0) - self.current_thread = self.replay_request(f) - ctx.master.addons.trigger("update", [f]) - if current_is_done and not will_start_new: - ctx.master.addons.trigger("processing_complete") + self.start_replay(flows) diff --git a/mitmproxy/test/taddons.py b/mitmproxy/test/taddons.py index 0505f9f7..67c15f75 100644 --- a/mitmproxy/test/taddons.py +++ b/mitmproxy/test/taddons.py @@ -112,12 +112,10 @@ class context: if addon not in self.master.addons: self.master.addons.register(addon) with self.options.rollback(kwargs.keys(), reraise=True): - self.options.update(**kwargs) - self.master.addons.invoke_addon( - addon, - "configure", - kwargs.keys() - ) + if kwargs: + self.options.update(**kwargs) + else: + self.master.addons.invoke_addon(addon, "configure", {}) def script(self, path): """ diff --git a/test/mitmproxy/addons/test_check_ca.py b/test/mitmproxy/addons/test_check_ca.py index 5e820b6d..27e6f7e6 100644 --- a/test/mitmproxy/addons/test_check_ca.py +++ b/test/mitmproxy/addons/test_check_ca.py @@ -12,11 +12,11 @@ class TestCheckCA: async def test_check_ca(self, expired): msg = 'The mitmproxy certificate authority has expired!' - with taddons.context() as tctx: + a = check_ca.CheckCA() + with taddons.context(a) as tctx: tctx.master.server = mock.MagicMock() tctx.master.server.config.certstore.default_ca.has_expired = mock.MagicMock( return_value = expired ) - a = check_ca.CheckCA() tctx.configure(a) assert await tctx.master.await_log(msg) == expired diff --git a/test/mitmproxy/addons/test_clientplayback.py b/test/mitmproxy/addons/test_clientplayback.py index 0bb24e87..a63bec53 100644 --- a/test/mitmproxy/addons/test_clientplayback.py +++ b/test/mitmproxy/addons/test_clientplayback.py @@ -92,47 +92,13 @@ class TestClientPlayback: # assert rt.f.request.http_version == "HTTP/1.1" # assert ":authority" not in rt.f.request.headers - def test_playback(self): - cp = clientplayback.ClientPlayback() - with taddons.context(cp) as tctx: - assert cp.count() == 0 - f = tflow.tflow(resp=True) - cp.start_replay([f]) - assert cp.count() == 1 - RP = "mitmproxy.addons.clientplayback.RequestReplayThread" - with mock.patch(RP) as rp: - assert not cp.current_thread - cp.tick() - assert rp.called - assert cp.current_thread - - cp.flows = [] - cp.current_thread.is_alive.return_value = False - assert cp.count() == 1 - cp.tick() - assert cp.count() == 0 - assert tctx.master.has_event("update") - assert tctx.master.has_event("processing_complete") - - cp.current_thread = MockThread() - cp.tick() - assert cp.current_thread is None - - cp.start_replay([f]) - cp.stop_replay() - assert not cp.flows - - df = tflow.DummyFlow(tflow.tclient_conn(), tflow.tserver_conn(), True) - with pytest.raises(exceptions.CommandError, match="Can't replay live flow."): - cp.start_replay([df]) - def test_load_file(self, tmpdir): cp = clientplayback.ClientPlayback() with taddons.context(cp): fpath = str(tmpdir.join("flows")) tdump(fpath, [tflow.tflow(resp=True)]) cp.load_file(fpath) - assert cp.flows + assert cp.count() == 1 with pytest.raises(exceptions.CommandError): cp.load_file("/nonexistent") @@ -141,11 +107,39 @@ class TestClientPlayback: with taddons.context(cp) as tctx: path = str(tmpdir.join("flows")) tdump(path, [tflow.tflow()]) + assert cp.count() == 0 tctx.configure(cp, client_replay=[path]) - cp.configured = False + assert cp.count() == 1 tctx.configure(cp, client_replay=[]) - cp.configured = False - tctx.configure(cp) - cp.configured = False with pytest.raises(exceptions.OptionsError): tctx.configure(cp, client_replay=["nonexistent"]) + + def test_check(self): + cp = clientplayback.ClientPlayback() + with taddons.context(cp): + f = tflow.tflow(resp=True) + f.live = True + assert "live flow" in cp.check(f) + + f = tflow.tflow(resp=True) + f.intercepted = True + assert "intercepted flow" in cp.check(f) + + f = tflow.tflow(resp=True) + f.request = None + assert "missing request" in cp.check(f) + + f = tflow.tflow(resp=True) + f.request.raw_content = None + assert "missing content" in cp.check(f) + + def test_playback(self): + cp = clientplayback.ClientPlayback() + with taddons.context(cp): + assert cp.count() == 0 + f = tflow.tflow(resp=True) + cp.start_replay([f]) + assert cp.count() == 1 + + cp.stop_replay() + assert cp.count() == 0
\ No newline at end of file diff --git a/test/mitmproxy/addons/test_readfile.py b/test/mitmproxy/addons/test_readfile.py index f7e0c5c5..d22382a8 100644 --- a/test/mitmproxy/addons/test_readfile.py +++ b/test/mitmproxy/addons/test_readfile.py @@ -42,7 +42,7 @@ def corrupt_data(): class TestReadFile: def test_configure(self): rf = readfile.ReadFile() - with taddons.context() as tctx: + with taddons.context(rf) as tctx: tctx.configure(rf, readfile_filter="~q") with pytest.raises(Exception, match="Invalid readfile filter"): tctx.configure(rf, readfile_filter="~~") diff --git a/test/mitmproxy/addons/test_save.py b/test/mitmproxy/addons/test_save.py index 4486ff78..616caf58 100644 --- a/test/mitmproxy/addons/test_save.py +++ b/test/mitmproxy/addons/test_save.py @@ -11,7 +11,7 @@ from mitmproxy.addons import view def test_configure(tmpdir): sa = save.Save() - with taddons.context() as tctx: + with taddons.context(sa) as tctx: with pytest.raises(exceptions.OptionsError): tctx.configure(sa, save_stream_file=str(tmpdir)) with pytest.raises(Exception, match="Invalid filter"): @@ -32,7 +32,7 @@ def rd(p): def test_tcp(tmpdir): sa = save.Save() - with taddons.context() as tctx: + with taddons.context(sa) as tctx: p = str(tmpdir.join("foo")) tctx.configure(sa, save_stream_file=p) @@ -45,7 +45,7 @@ def test_tcp(tmpdir): def test_websocket(tmpdir): sa = save.Save() - with taddons.context() as tctx: + with taddons.context(sa) as tctx: p = str(tmpdir.join("foo")) tctx.configure(sa, save_stream_file=p) @@ -78,7 +78,7 @@ def test_save_command(tmpdir): def test_simple(tmpdir): sa = save.Save() - with taddons.context() as tctx: + with taddons.context(sa) as tctx: p = str(tmpdir.join("foo")) tctx.configure(sa, save_stream_file=p) diff --git a/test/mitmproxy/addons/test_script.py b/test/mitmproxy/addons/test_script.py index c358f019..91637489 100644 --- a/test/mitmproxy/addons/test_script.py +++ b/test/mitmproxy/addons/test_script.py @@ -92,14 +92,13 @@ class TestScript: @pytest.mark.asyncio async def test_simple(self, tdata): - with taddons.context() as tctx: - sc = script.Script( - tdata.path( - "mitmproxy/data/addonscripts/recorder/recorder.py" - ), - True, - ) - tctx.master.addons.add(sc) + sc = script.Script( + tdata.path( + "mitmproxy/data/addonscripts/recorder/recorder.py" + ), + True, + ) + with taddons.context(sc) as tctx: tctx.configure(sc) await tctx.master.await_log("recorder running") rec = tctx.master.addons.get("recorder") @@ -284,7 +283,7 @@ class TestScriptLoader: rec = tdata.path("mitmproxy/data/addonscripts/recorder") sc = script.ScriptLoader() sc.is_running = True - with taddons.context() as tctx: + with taddons.context(sc) as tctx: tctx.configure( sc, scripts = [ diff --git a/test/mitmproxy/addons/test_view.py b/test/mitmproxy/addons/test_view.py index 62a6aeb0..bd724950 100644 --- a/test/mitmproxy/addons/test_view.py +++ b/test/mitmproxy/addons/test_view.py @@ -155,7 +155,7 @@ def test_create(): def test_orders(): v = view.View() - with taddons.context(): + with taddons.context(v): assert v.order_options() @@ -303,7 +303,7 @@ def test_setgetval(): def test_order(): v = view.View() - with taddons.context() as tctx: + with taddons.context(v) as tctx: v.request(tft(method="get", start=1)) v.request(tft(method="put", start=2)) v.request(tft(method="get", start=3)) @@ -434,7 +434,7 @@ def test_signals(): def test_focus_follow(): v = view.View() - with taddons.context() as tctx: + with taddons.context(v) as tctx: console_addon = consoleaddons.ConsoleAddon(tctx.master) tctx.configure(console_addon) tctx.configure(v, console_focus_follow=True, view_filter="~m get") @@ -553,7 +553,7 @@ def test_settings(): def test_configure(): v = view.View() - with taddons.context() as tctx: + with taddons.context(v) as tctx: tctx.configure(v, view_filter="~q") with pytest.raises(Exception, match="Invalid interception filter"): tctx.configure(v, view_filter="~~") |