aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAldo Cortesi <aldo@corte.si>2018-04-28 10:59:12 +1200
committerAldo Cortesi <aldo@nullcube.com>2018-04-30 17:17:03 +1200
commit236a2fb6fde4ff8837f85cf0a217f915b0bfed79 (patch)
tree7342213c11fbfb93701976457d78a84cf2e6dd89
parent28d53d5a245cbac19896bac30a41435024b17b78 (diff)
downloadmitmproxy-236a2fb6fde4ff8837f85cf0a217f915b0bfed79.tar.gz
mitmproxy-236a2fb6fde4ff8837f85cf0a217f915b0bfed79.tar.bz2
mitmproxy-236a2fb6fde4ff8837f85cf0a217f915b0bfed79.zip
client replay: re-design
Re-design the way client replay works. Before, we would fire up a thread, replay, wait for the thread to complete, get the next flow, and repeat the procedure. Now, we have one replay thread that starts when the addon starts, which pops flows off a thread-safe queue. This is much cleaner, removes the need for busy tick, and sets the scene for optimisations like server connection reuse down the track.
-rw-r--r--mitmproxy/addons/clientplayback.py242
-rw-r--r--mitmproxy/test/taddons.py10
-rw-r--r--test/mitmproxy/addons/test_check_ca.py4
-rw-r--r--test/mitmproxy/addons/test_clientplayback.py72
-rw-r--r--test/mitmproxy/addons/test_readfile.py2
-rw-r--r--test/mitmproxy/addons/test_save.py8
-rw-r--r--test/mitmproxy/addons/test_script.py17
-rw-r--r--test/mitmproxy/addons/test_view.py8
8 files changed, 154 insertions, 209 deletions
diff --git a/mitmproxy/addons/clientplayback.py b/mitmproxy/addons/clientplayback.py
index 8cd8e3a8..305920da 100644
--- a/mitmproxy/addons/clientplayback.py
+++ b/mitmproxy/addons/clientplayback.py
@@ -1,3 +1,6 @@
+import queue
+import typing
+
from mitmproxy import log
from mitmproxy import controller
from mitmproxy import exceptions
@@ -14,46 +17,47 @@ from mitmproxy import io
from mitmproxy import command
import mitmproxy.types
-import typing
-
class RequestReplayThread(basethread.BaseThread):
- name = "RequestReplayThread"
+ daemon = True
def __init__(
self,
opts: options.Options,
- f: http.HTTPFlow,
channel: controller.Channel,
+ queue: queue.Queue,
) -> None:
self.options = opts
- self.f = f
- f.live = True
self.channel = channel
- super().__init__(
- "RequestReplay (%s)" % f.request.url
- )
- self.daemon = True
+ self.queue = queue
+ super().__init__("RequestReplayThread")
def run(self):
- r = self.f.request
+ while True:
+ f = self.queue.get(block=True, timeout=None)
+ self.replay(f)
+
+ def replay(self, f):
+ f.live = True
+ r = f.request
bsl = human.parse_size(self.options.body_size_limit)
first_line_format_backup = r.first_line_format
server = None
try:
- self.f.response = None
+ f.response = None
# If we have a channel, run script hooks.
- if self.channel:
- request_reply = self.channel.ask("request", self.f)
- if isinstance(request_reply, http.HTTPResponse):
- self.f.response = request_reply
+ request_reply = self.channel.ask("request", f)
+ if isinstance(request_reply, http.HTTPResponse):
+ f.response = request_reply
- if not self.f.response:
+ if not f.response:
# In all modes, we directly connect to the server displayed
if self.options.mode.startswith("upstream:"):
server_address = server_spec.parse_with_mode(self.options.mode)[1].address
- server = connections.ServerConnection(server_address, (self.options.listen_host, 0))
+ server = connections.ServerConnection(
+ server_address, (self.options.listen_host, 0)
+ )
server.connect()
if r.scheme == "https":
connect_request = http.make_connect_request((r.data.host, r.port))
@@ -65,9 +69,11 @@ class RequestReplayThread(basethread.BaseThread):
body_size_limit=bsl
)
if resp.status_code != 200:
- raise exceptions.ReplayException("Upstream server refuses CONNECT request")
+ raise exceptions.ReplayException(
+ "Upstream server refuses CONNECT request"
+ )
server.establish_tls(
- sni=self.f.server_conn.sni,
+ sni=f.server_conn.sni,
**tls.client_arguments_from_options(self.options)
)
r.first_line_format = "relative"
@@ -82,7 +88,7 @@ class RequestReplayThread(basethread.BaseThread):
server.connect()
if r.scheme == "https":
server.establish_tls(
- sni=self.f.server_conn.sni,
+ sni=f.server_conn.sni,
**tls.client_arguments_from_options(self.options)
)
r.first_line_format = "relative"
@@ -90,104 +96,44 @@ class RequestReplayThread(basethread.BaseThread):
server.wfile.write(http1.assemble_request(r))
server.wfile.flush()
- if self.f.server_conn:
- self.f.server_conn.close()
- self.f.server_conn = server
+ if f.server_conn:
+ f.server_conn.close()
+ f.server_conn = server
- self.f.response = http.HTTPResponse.wrap(
- http1.read_response(
- server.rfile,
- r,
- body_size_limit=bsl
- )
+ f.response = http.HTTPResponse.wrap(
+ http1.read_response(server.rfile, r, body_size_limit=bsl)
)
- if self.channel:
- response_reply = self.channel.ask("response", self.f)
- if response_reply == exceptions.Kill:
- raise exceptions.Kill()
+ response_reply = self.channel.ask("response", f)
+ if response_reply == exceptions.Kill:
+ raise exceptions.Kill()
except (exceptions.ReplayException, exceptions.NetlibException) as e:
- self.f.error = flow.Error(str(e))
- if self.channel:
- self.channel.ask("error", self.f)
+ f.error = flow.Error(str(e))
+ self.channel.ask("error", f)
except exceptions.Kill:
- # Kill should only be raised if there's a channel in the
- # first place.
- self.channel.tell(
- "log",
- log.LogEntry("Connection killed", "info")
- )
+ self.channel.tell("log", log.LogEntry("Connection killed", "info"))
except Exception as e:
- self.channel.tell(
- "log",
- log.LogEntry(repr(e), "error")
- )
+ self.channel.tell("log", log.LogEntry(repr(e), "error"))
finally:
r.first_line_format = first_line_format_backup
- self.f.live = False
+ f.live = False
if server.connected():
server.finish()
class ClientPlayback:
def __init__(self):
- self.flows: typing.List[flow.Flow] = []
- self.current_thread = None
- self.configured = False
-
- def replay_request(
- self,
- f: http.HTTPFlow,
- block: bool=False
- ) -> RequestReplayThread:
- """
- Replay a HTTP request to receive a new response from the server.
-
- Args:
- f: The flow to replay.
- block: If True, this function will wait for the replay to finish.
- This causes a deadlock if activated in the main thread.
-
- Returns:
- The thread object doing the replay.
-
- Raises:
- exceptions.ReplayException, if the flow is in a state
- where it is ineligible for replay.
- """
+ self.q: queue.Queue = queue.Queue()
+ self.thread: RequestReplayThread | None = None
+ def check(self, f: http.HTTPFlow):
if f.live:
- raise exceptions.ReplayException(
- "Can't replay live flow."
- )
+ return "Can't replay live flow."
if f.intercepted:
- raise exceptions.ReplayException(
- "Can't replay intercepted flow."
- )
+ return "Can't replay intercepted flow."
if not f.request:
- raise exceptions.ReplayException(
- "Can't replay flow with missing request."
- )
+ return "Can't replay flow with missing request."
if f.request.raw_content is None:
- raise exceptions.ReplayException(
- "Can't replay flow with missing content."
- )
-
- f.backup()
- f.request.is_replay = True
-
- f.response = None
- f.error = None
-
- if f.request.http_version == "HTTP/2.0": # https://github.com/mitmproxy/mitmproxy/issues/2197
- f.request.http_version = "HTTP/1.1"
- host = f.request.headers.pop(":authority")
- f.request.headers.insert(0, "host", host)
-
- rt = RequestReplayThread(ctx.master.options, f, ctx.master.channel)
- rt.start() # pragma: no cover
- if block:
- rt.join()
- return rt
+ return "Can't replay flow with missing content."
def load(self, loader):
loader.add_option(
@@ -195,65 +141,73 @@ class ClientPlayback:
"Replay client requests from a saved file."
)
+ def running(self):
+ self.thread = RequestReplayThread(
+ ctx.options,
+ ctx.master.channel,
+ self.q,
+ )
+ self.thread.start()
+
+ def configure(self, updated):
+ if "client_replay" in updated and ctx.options.client_replay:
+ try:
+ flows = io.read_flows_from_paths(ctx.options.client_replay)
+ except exceptions.FlowReadException as e:
+ raise exceptions.OptionsError(str(e))
+ self.start_replay(flows)
+
+ @command.command("replay.client.count")
def count(self) -> int:
- if self.current_thread:
- current = 1
- else:
- current = 0
- return current + len(self.flows)
+ """
+ Approximate number of flows queued for replay.
+ """
+ return self.q.qsize()
@command.command("replay.client.stop")
def stop_replay(self) -> None:
"""
- Stop client replay.
+ Clear the replay queue.
"""
- self.flows = []
- ctx.log.alert("Client replay stopped.")
- ctx.master.addons.trigger("update", [])
+ with self.q.mutex:
+ lst = list(self.q.queue)
+ self.q.queue.clear()
+ ctx.master.addons.trigger("update", lst)
+ ctx.log.alert("Client replay queue cleared.")
@command.command("replay.client")
def start_replay(self, flows: typing.Sequence[flow.Flow]) -> None:
"""
- Replay requests from flows.
+ Add flows to the replay queue, skipping flows that can't be replayed.
"""
+ lst = []
for f in flows:
- if f.live:
- raise exceptions.CommandError("Can't replay live flow.")
- self.flows = list(flows)
- ctx.log.alert("Replaying %s flows." % len(self.flows))
- ctx.master.addons.trigger("update", [])
+ err = self.check(f)
+ if err:
+ ctx.log.warn(err)
+ continue
+
+ lst.append(f)
+ # Prepare the flow for replay
+ f.backup()
+ f.request.is_replay = True
+ f.response = None
+ f.error = None
+ # https://github.com/mitmproxy/mitmproxy/issues/2197
+ if f.request.http_version == "HTTP/2.0":
+ f.request.http_version = "HTTP/1.1"
+ host = f.request.headers.pop(":authority")
+ f.request.headers.insert(0, "host", host)
+ self.q.put(f)
+ ctx.master.addons.trigger("update", lst)
@command.command("replay.client.file")
def load_file(self, path: mitmproxy.types.Path) -> None:
+ """
+ Load flows from file, and add them to the replay queue.
+ """
try:
flows = io.read_flows_from_paths([path])
except exceptions.FlowReadException as e:
raise exceptions.CommandError(str(e))
- ctx.log.alert("Replaying %s flows." % len(self.flows))
- self.flows = flows
- ctx.master.addons.trigger("update", [])
-
- def configure(self, updated):
- if not self.configured and ctx.options.client_replay:
- self.configured = True
- ctx.log.info("Client Replay: {}".format(ctx.options.client_replay))
- try:
- flows = io.read_flows_from_paths(ctx.options.client_replay)
- except exceptions.FlowReadException as e:
- raise exceptions.OptionsError(str(e))
- self.start_replay(flows)
-
- def tick(self):
- current_is_done = self.current_thread and not self.current_thread.is_alive()
- can_start_new = not self.current_thread or current_is_done
- will_start_new = can_start_new and self.flows
-
- if current_is_done:
- self.current_thread = None
- ctx.master.addons.trigger("update", [])
- if will_start_new:
- f = self.flows.pop(0)
- self.current_thread = self.replay_request(f)
- ctx.master.addons.trigger("update", [f])
- if current_is_done and not will_start_new:
- ctx.master.addons.trigger("processing_complete")
+ self.start_replay(flows)
diff --git a/mitmproxy/test/taddons.py b/mitmproxy/test/taddons.py
index 0505f9f7..67c15f75 100644
--- a/mitmproxy/test/taddons.py
+++ b/mitmproxy/test/taddons.py
@@ -112,12 +112,10 @@ class context:
if addon not in self.master.addons:
self.master.addons.register(addon)
with self.options.rollback(kwargs.keys(), reraise=True):
- self.options.update(**kwargs)
- self.master.addons.invoke_addon(
- addon,
- "configure",
- kwargs.keys()
- )
+ if kwargs:
+ self.options.update(**kwargs)
+ else:
+ self.master.addons.invoke_addon(addon, "configure", {})
def script(self, path):
"""
diff --git a/test/mitmproxy/addons/test_check_ca.py b/test/mitmproxy/addons/test_check_ca.py
index 5e820b6d..27e6f7e6 100644
--- a/test/mitmproxy/addons/test_check_ca.py
+++ b/test/mitmproxy/addons/test_check_ca.py
@@ -12,11 +12,11 @@ class TestCheckCA:
async def test_check_ca(self, expired):
msg = 'The mitmproxy certificate authority has expired!'
- with taddons.context() as tctx:
+ a = check_ca.CheckCA()
+ with taddons.context(a) as tctx:
tctx.master.server = mock.MagicMock()
tctx.master.server.config.certstore.default_ca.has_expired = mock.MagicMock(
return_value = expired
)
- a = check_ca.CheckCA()
tctx.configure(a)
assert await tctx.master.await_log(msg) == expired
diff --git a/test/mitmproxy/addons/test_clientplayback.py b/test/mitmproxy/addons/test_clientplayback.py
index 0bb24e87..a63bec53 100644
--- a/test/mitmproxy/addons/test_clientplayback.py
+++ b/test/mitmproxy/addons/test_clientplayback.py
@@ -92,47 +92,13 @@ class TestClientPlayback:
# assert rt.f.request.http_version == "HTTP/1.1"
# assert ":authority" not in rt.f.request.headers
- def test_playback(self):
- cp = clientplayback.ClientPlayback()
- with taddons.context(cp) as tctx:
- assert cp.count() == 0
- f = tflow.tflow(resp=True)
- cp.start_replay([f])
- assert cp.count() == 1
- RP = "mitmproxy.addons.clientplayback.RequestReplayThread"
- with mock.patch(RP) as rp:
- assert not cp.current_thread
- cp.tick()
- assert rp.called
- assert cp.current_thread
-
- cp.flows = []
- cp.current_thread.is_alive.return_value = False
- assert cp.count() == 1
- cp.tick()
- assert cp.count() == 0
- assert tctx.master.has_event("update")
- assert tctx.master.has_event("processing_complete")
-
- cp.current_thread = MockThread()
- cp.tick()
- assert cp.current_thread is None
-
- cp.start_replay([f])
- cp.stop_replay()
- assert not cp.flows
-
- df = tflow.DummyFlow(tflow.tclient_conn(), tflow.tserver_conn(), True)
- with pytest.raises(exceptions.CommandError, match="Can't replay live flow."):
- cp.start_replay([df])
-
def test_load_file(self, tmpdir):
cp = clientplayback.ClientPlayback()
with taddons.context(cp):
fpath = str(tmpdir.join("flows"))
tdump(fpath, [tflow.tflow(resp=True)])
cp.load_file(fpath)
- assert cp.flows
+ assert cp.count() == 1
with pytest.raises(exceptions.CommandError):
cp.load_file("/nonexistent")
@@ -141,11 +107,39 @@ class TestClientPlayback:
with taddons.context(cp) as tctx:
path = str(tmpdir.join("flows"))
tdump(path, [tflow.tflow()])
+ assert cp.count() == 0
tctx.configure(cp, client_replay=[path])
- cp.configured = False
+ assert cp.count() == 1
tctx.configure(cp, client_replay=[])
- cp.configured = False
- tctx.configure(cp)
- cp.configured = False
with pytest.raises(exceptions.OptionsError):
tctx.configure(cp, client_replay=["nonexistent"])
+
+ def test_check(self):
+ cp = clientplayback.ClientPlayback()
+ with taddons.context(cp):
+ f = tflow.tflow(resp=True)
+ f.live = True
+ assert "live flow" in cp.check(f)
+
+ f = tflow.tflow(resp=True)
+ f.intercepted = True
+ assert "intercepted flow" in cp.check(f)
+
+ f = tflow.tflow(resp=True)
+ f.request = None
+ assert "missing request" in cp.check(f)
+
+ f = tflow.tflow(resp=True)
+ f.request.raw_content = None
+ assert "missing content" in cp.check(f)
+
+ def test_playback(self):
+ cp = clientplayback.ClientPlayback()
+ with taddons.context(cp):
+ assert cp.count() == 0
+ f = tflow.tflow(resp=True)
+ cp.start_replay([f])
+ assert cp.count() == 1
+
+ cp.stop_replay()
+ assert cp.count() == 0 \ No newline at end of file
diff --git a/test/mitmproxy/addons/test_readfile.py b/test/mitmproxy/addons/test_readfile.py
index f7e0c5c5..d22382a8 100644
--- a/test/mitmproxy/addons/test_readfile.py
+++ b/test/mitmproxy/addons/test_readfile.py
@@ -42,7 +42,7 @@ def corrupt_data():
class TestReadFile:
def test_configure(self):
rf = readfile.ReadFile()
- with taddons.context() as tctx:
+ with taddons.context(rf) as tctx:
tctx.configure(rf, readfile_filter="~q")
with pytest.raises(Exception, match="Invalid readfile filter"):
tctx.configure(rf, readfile_filter="~~")
diff --git a/test/mitmproxy/addons/test_save.py b/test/mitmproxy/addons/test_save.py
index 4486ff78..616caf58 100644
--- a/test/mitmproxy/addons/test_save.py
+++ b/test/mitmproxy/addons/test_save.py
@@ -11,7 +11,7 @@ from mitmproxy.addons import view
def test_configure(tmpdir):
sa = save.Save()
- with taddons.context() as tctx:
+ with taddons.context(sa) as tctx:
with pytest.raises(exceptions.OptionsError):
tctx.configure(sa, save_stream_file=str(tmpdir))
with pytest.raises(Exception, match="Invalid filter"):
@@ -32,7 +32,7 @@ def rd(p):
def test_tcp(tmpdir):
sa = save.Save()
- with taddons.context() as tctx:
+ with taddons.context(sa) as tctx:
p = str(tmpdir.join("foo"))
tctx.configure(sa, save_stream_file=p)
@@ -45,7 +45,7 @@ def test_tcp(tmpdir):
def test_websocket(tmpdir):
sa = save.Save()
- with taddons.context() as tctx:
+ with taddons.context(sa) as tctx:
p = str(tmpdir.join("foo"))
tctx.configure(sa, save_stream_file=p)
@@ -78,7 +78,7 @@ def test_save_command(tmpdir):
def test_simple(tmpdir):
sa = save.Save()
- with taddons.context() as tctx:
+ with taddons.context(sa) as tctx:
p = str(tmpdir.join("foo"))
tctx.configure(sa, save_stream_file=p)
diff --git a/test/mitmproxy/addons/test_script.py b/test/mitmproxy/addons/test_script.py
index c358f019..91637489 100644
--- a/test/mitmproxy/addons/test_script.py
+++ b/test/mitmproxy/addons/test_script.py
@@ -92,14 +92,13 @@ class TestScript:
@pytest.mark.asyncio
async def test_simple(self, tdata):
- with taddons.context() as tctx:
- sc = script.Script(
- tdata.path(
- "mitmproxy/data/addonscripts/recorder/recorder.py"
- ),
- True,
- )
- tctx.master.addons.add(sc)
+ sc = script.Script(
+ tdata.path(
+ "mitmproxy/data/addonscripts/recorder/recorder.py"
+ ),
+ True,
+ )
+ with taddons.context(sc) as tctx:
tctx.configure(sc)
await tctx.master.await_log("recorder running")
rec = tctx.master.addons.get("recorder")
@@ -284,7 +283,7 @@ class TestScriptLoader:
rec = tdata.path("mitmproxy/data/addonscripts/recorder")
sc = script.ScriptLoader()
sc.is_running = True
- with taddons.context() as tctx:
+ with taddons.context(sc) as tctx:
tctx.configure(
sc,
scripts = [
diff --git a/test/mitmproxy/addons/test_view.py b/test/mitmproxy/addons/test_view.py
index 62a6aeb0..bd724950 100644
--- a/test/mitmproxy/addons/test_view.py
+++ b/test/mitmproxy/addons/test_view.py
@@ -155,7 +155,7 @@ def test_create():
def test_orders():
v = view.View()
- with taddons.context():
+ with taddons.context(v):
assert v.order_options()
@@ -303,7 +303,7 @@ def test_setgetval():
def test_order():
v = view.View()
- with taddons.context() as tctx:
+ with taddons.context(v) as tctx:
v.request(tft(method="get", start=1))
v.request(tft(method="put", start=2))
v.request(tft(method="get", start=3))
@@ -434,7 +434,7 @@ def test_signals():
def test_focus_follow():
v = view.View()
- with taddons.context() as tctx:
+ with taddons.context(v) as tctx:
console_addon = consoleaddons.ConsoleAddon(tctx.master)
tctx.configure(console_addon)
tctx.configure(v, console_focus_follow=True, view_filter="~m get")
@@ -553,7 +553,7 @@ def test_settings():
def test_configure():
v = view.View()
- with taddons.context() as tctx:
+ with taddons.context(v) as tctx:
tctx.configure(v, view_filter="~q")
with pytest.raises(Exception, match="Invalid interception filter"):
tctx.configure(v, view_filter="~~")