aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMaximilian Hils <git@maximilianhils.com>2016-08-10 02:22:39 -0700
committerGitHub <noreply@github.com>2016-08-10 02:22:39 -0700
commitea2f23feffd95877fe4d59666c42dd45319f18b9 (patch)
tree9ac4359b3774b35b20acc81d4863574eda39e2b6
parent4f5e312fbcee608a116f5d8bc35e5334dafdb845 (diff)
parent5a22496ee8f6f72bc5f75f623ca0d68d7a0d7855 (diff)
downloadmitmproxy-ea2f23feffd95877fe4d59666c42dd45319f18b9.tar.gz
mitmproxy-ea2f23feffd95877fe4d59666c42dd45319f18b9.tar.bz2
mitmproxy-ea2f23feffd95877fe4d59666c42dd45319f18b9.zip
Merge pull request #1474 from mhils/reply-fix
Improve controller.Reply semantics
-rw-r--r--mitmproxy/builtins/replace.py4
-rw-r--r--mitmproxy/builtins/setheaders.py4
-rw-r--r--mitmproxy/console/common.py2
-rw-r--r--mitmproxy/console/flowlist.py4
-rw-r--r--mitmproxy/console/flowview.py8
-rw-r--r--mitmproxy/console/master.py1
-rw-r--r--mitmproxy/controller.py147
-rw-r--r--mitmproxy/flow/master.py10
-rw-r--r--mitmproxy/flow/state.py2
-rw-r--r--mitmproxy/models/flow.py13
-rw-r--r--mitmproxy/script/concurrent.py8
-rw-r--r--mitmproxy/web/app.py3
-rw-r--r--mitmproxy/web/master.py1
-rw-r--r--test/mitmproxy/builtins/test_anticache.py4
-rw-r--r--test/mitmproxy/builtins/test_anticomp.py4
-rw-r--r--test/mitmproxy/builtins/test_dumper.py2
-rw-r--r--test/mitmproxy/builtins/test_filestreamer.py6
-rw-r--r--test/mitmproxy/builtins/test_replace.py4
-rw-r--r--test/mitmproxy/builtins/test_script.py4
-rw-r--r--test/mitmproxy/builtins/test_setheaders.py8
-rw-r--r--test/mitmproxy/builtins/test_stickyauth.py4
-rw-r--r--test/mitmproxy/builtins/test_stickycookie.py18
-rw-r--r--test/mitmproxy/mastertest.py21
-rw-r--r--test/mitmproxy/script/test_concurrent.py6
-rw-r--r--test/mitmproxy/test_controller.py134
-rw-r--r--test/mitmproxy/test_examples.py22
-rw-r--r--test/mitmproxy/test_flow.py29
27 files changed, 319 insertions, 154 deletions
diff --git a/mitmproxy/builtins/replace.py b/mitmproxy/builtins/replace.py
index 2c94fbb5..c938d683 100644
--- a/mitmproxy/builtins/replace.py
+++ b/mitmproxy/builtins/replace.py
@@ -41,9 +41,9 @@ class Replace:
f.request.replace(rex, s)
def request(self, flow):
- if not flow.reply.acked:
+ if not flow.reply.has_message:
self.execute(flow)
def response(self, flow):
- if not flow.reply.acked:
+ if not flow.reply.has_message:
self.execute(flow)
diff --git a/mitmproxy/builtins/setheaders.py b/mitmproxy/builtins/setheaders.py
index 4a784a1d..4cb9905e 100644
--- a/mitmproxy/builtins/setheaders.py
+++ b/mitmproxy/builtins/setheaders.py
@@ -31,9 +31,9 @@ class SetHeaders:
hdrs.add(header, value)
def request(self, flow):
- if not flow.reply.acked:
+ if not flow.reply.has_message:
self.run(flow, flow.request.headers)
def response(self, flow):
- if not flow.reply.acked:
+ if not flow.reply.has_message:
self.run(flow, flow.response.headers)
diff --git a/mitmproxy/console/common.py b/mitmproxy/console/common.py
index 2eb6a7d9..5a24e789 100644
--- a/mitmproxy/console/common.py
+++ b/mitmproxy/console/common.py
@@ -413,7 +413,7 @@ def raw_format_flow(f, focus, extended):
def format_flow(f, focus, extended=False, hostheader=False):
d = dict(
intercepted = f.intercepted,
- acked = f.reply.acked,
+ acked = f.reply.state == "committed",
req_timestamp = f.request.timestamp_start,
req_is_replay = f.request.is_replay,
diff --git a/mitmproxy/console/flowlist.py b/mitmproxy/console/flowlist.py
index 12caf315..7e69e098 100644
--- a/mitmproxy/console/flowlist.py
+++ b/mitmproxy/console/flowlist.py
@@ -182,7 +182,7 @@ class ConnectionItem(urwid.WidgetWrap):
self.flow.accept_intercept(self.master)
signals.flowlist_change.send(self)
elif key == "d":
- if not self.flow.reply.acked:
+ if self.flow.killable:
self.flow.kill(self.master)
self.state.delete_flow(self.flow)
signals.flowlist_change.send(self)
@@ -246,7 +246,7 @@ class ConnectionItem(urwid.WidgetWrap):
callback = self.save_flows_prompt,
)
elif key == "X":
- if not self.flow.reply.acked:
+ if self.flow.killable:
self.flow.kill(self.master)
elif key == "enter":
if self.flow.request:
diff --git a/mitmproxy/console/flowview.py b/mitmproxy/console/flowview.py
index 1c3c4e98..6d74be65 100644
--- a/mitmproxy/console/flowview.py
+++ b/mitmproxy/console/flowview.py
@@ -8,7 +8,6 @@ import urwid
from typing import Optional, Union # noqa
from mitmproxy import contentviews
-from mitmproxy import controller
from mitmproxy import models
from mitmproxy import utils
from mitmproxy.console import common
@@ -148,13 +147,13 @@ class FlowView(tabs.Tabs):
signals.flow_change.connect(self.sig_flow_change)
def tab_request(self):
- if self.flow.intercepted and not self.flow.reply.acked and not self.flow.response:
+ if self.flow.intercepted and not self.flow.response:
return "Request intercepted"
else:
return "Request"
def tab_response(self):
- if self.flow.intercepted and not self.flow.reply.acked and self.flow.response:
+ if self.flow.intercepted and self.flow.response:
return "Response intercepted"
else:
return "Response"
@@ -379,7 +378,6 @@ class FlowView(tabs.Tabs):
self.flow.request.http_version,
200, b"OK", Headers(), b""
)
- self.flow.response.reply = controller.DummyReply()
message = self.flow.response
self.flow.backup()
@@ -538,7 +536,7 @@ class FlowView(tabs.Tabs):
else:
self.view_next_flow(self.flow)
f = self.flow
- if not f.reply.acked:
+ if f.killable:
f.kill(self.master)
self.state.delete_flow(f)
elif key == "D":
diff --git a/mitmproxy/console/master.py b/mitmproxy/console/master.py
index 18a4c1f0..a6942ca4 100644
--- a/mitmproxy/console/master.py
+++ b/mitmproxy/console/master.py
@@ -736,7 +736,6 @@ class ConsoleMaster(flow.FlowMaster):
)
if should_intercept:
f.intercept(self)
- f.reply.take()
signals.flowlist_change.send(self)
signals.flow_change.send(self, flow = f)
diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py
index 35817a85..72374f31 100644
--- a/mitmproxy/controller.py
+++ b/mitmproxy/controller.py
@@ -185,6 +185,7 @@ class Channel(object):
if g == exceptions.Kill:
raise exceptions.Kill()
return g
+ m.reply._state = "committed" # suppress error message in __del__
raise exceptions.Kill()
def tell(self, mtype, m):
@@ -202,34 +203,47 @@ def handler(f):
if not hasattr(message, "reply"):
raise exceptions.ControlException("Message %s has no reply attribute" % message)
+ # DummyReplys may be reused multiple times.
+ # We only clear them up on the next handler so that we can access value and
+ # state in the meantime.
+ if isinstance(message.reply, DummyReply):
+ message.reply.reset()
+
# The following ensures that inheritance with wrapped handlers in the
# base class works. If we're the first handler, then responsibility for
# acking is ours. If not, it's someone else's and we ignore it.
handling = False
# We're the first handler - ack responsibility is ours
- if not message.reply.handled:
+ if message.reply.state == "unhandled":
handling = True
- message.reply.handled = True
+ message.reply.handle()
with master.handlecontext():
ret = f(master, message)
if handling:
master.addons(f.__name__, message)
- if handling and not message.reply.acked and not message.reply.taken:
- message.reply.ack()
-
# Reset the handled flag - it's common for us to feed the same object
# through handlers repeatedly, so we don't want this to persist across
# calls.
- if message.reply.handled:
- message.reply.handled = False
+ if handling and message.reply.state == "handled":
+ message.reply.take()
+ if not message.reply.has_message:
+ message.reply.ack()
+ message.reply.commit()
+
+ # DummyReplys may be reused multiple times.
+ if isinstance(message.reply, DummyReply):
+ message.reply.mark_reset()
return ret
# Mark this function as a handler wrapper
wrapper.__dict__["__handler"] = True
return wrapper
+NO_REPLY = object() # special object we can distinguish from a valid "None" reply.
+
+
class Reply(object):
"""
Messages sent through a channel are decorated with a "reply" attribute.
@@ -238,53 +252,104 @@ class Reply(object):
"""
def __init__(self, obj):
self.obj = obj
- self.q = queue.Queue()
- # Has this message been acked?
- self.acked = False
- # Has the user taken responsibility for ack-ing?
- self.taken = False
- # Has a handler taken responsibility for ack-ing?
- self.handled = False
+ self.q = queue.Queue() # type: queue.Queue
+
+ self._state = "unhandled" # "unhandled" -> "handled" -> "taken" -> "committed"
+ self.value = NO_REPLY # holds the reply value. May change before things are actually commited.
+
+ @property
+ def state(self):
+ """
+ The state the reply is currently in. A normal reply object goes sequentially through the following lifecycle:
+
+ 1. unhandled: Initial State.
+ 2. handled: The reply object has been handled by the topmost handler function.
+ 3. taken: The reply object has been taken to be commited.
+ 4. committed: The reply has been sent back to the requesting party.
+
+ This attribute is read-only and can only be modified by calling one of state transition functions.
+ """
+ return self._state
- def ack(self):
- self.send(self.obj)
+ @property
+ def has_message(self):
+ return self.value != NO_REPLY
- def kill(self):
- self.send(exceptions.Kill)
+ @property
+ def done(self):
+ return self.state == "committed"
+
+ def handle(self):
+ """
+ Reply are handled by controller.handlers, which may be nested. The first handler takes
+ responsibility and handles the reply.
+ """
+ if self.state != "unhandled":
+ raise exceptions.ControlException("Reply is {}, but expected it to be unhandled.".format(self.state))
+ self._state = "handled"
def take(self):
- self.taken = True
+ """
+ Scripts or other parties make "take" a reply out of a normal flow.
+ For example, intercepted flows are taken out so that the connection thread does not proceed.
+ """
+ if self.state != "handled":
+ raise exceptions.ControlException("Reply is {}, but expected it to be handled.".format(self.state))
+ self._state = "taken"
- def send(self, msg):
- if self.acked:
- raise exceptions.ControlException("Message already acked.")
- self.acked = True
- self.q.put(msg)
+ def commit(self):
+ """
+ Ultimately, messages are commited. This is done either automatically by the handler
+ if the message is not taken or manually by the entity which called .take().
+ """
+ if self.state != "taken":
+ raise exceptions.ControlException("Reply is {}, but expected it to be taken.".format(self.state))
+ if not self.has_message:
+ raise exceptions.ControlException("There is no reply message.")
+ self._state = "committed"
+ self.q.put(self.value)
+
+ def ack(self, force=False):
+ self.send(self.obj, force)
+
+ def kill(self, force=False):
+ self.send(exceptions.Kill, force)
+
+ def send(self, msg, force=False):
+ if self.state not in ("handled", "taken"):
+ raise exceptions.ControlException(
+ "Reply is {}, did not expect a call to .send().".format(self.state)
+ )
+ if self.has_message and not force:
+ raise exceptions.ControlException("There is already a reply message.")
+ self.value = msg
def __del__(self):
- if not self.acked:
+ if self.state != "committed":
# This will be ignored by the interpreter, but emit a warning
- raise exceptions.ControlException("Un-acked message: %s" % self.obj)
+ raise exceptions.ControlException("Uncommitted reply: %s" % self.obj)
-class DummyReply(object):
+class DummyReply(Reply):
"""
- A reply object that does nothing. Useful when we need an object to seem
- like it has a channel, and during testing.
+ A reply object that is not connected to anything. In contrast to regular Reply objects,
+ DummyReply objects are reset to "unhandled" at the end of an handler so that they can be used
+ multiple times. Useful when we need an object to seem like it has a channel,
+ and during testing.
"""
def __init__(self):
- self.acked = False
- self.taken = False
- self.handled = False
+ super(DummyReply, self).__init__(None)
+ self._should_reset = False
- def kill(self):
- self.send(None)
+ def mark_reset(self):
+ if self.state != "committed":
+ raise exceptions.ControlException("Uncommitted reply: %s" % self.obj)
+ self._should_reset = True
- def ack(self):
- self.send(None)
+ def reset(self):
+ if self._should_reset:
+ self._state = "unhandled"
+ self.value = NO_REPLY
- def take(self):
- self.taken = True
-
- def send(self, msg):
- self.acked = True
+ def __del__(self):
+ pass
diff --git a/mitmproxy/flow/master.py b/mitmproxy/flow/master.py
index 65a95e44..0475ef4e 100644
--- a/mitmproxy/flow/master.py
+++ b/mitmproxy/flow/master.py
@@ -234,7 +234,7 @@ class FlowMaster(controller.Master):
pb = self.do_server_playback(f)
if not pb and self.kill_nonreplay:
self.add_log("Killed {}".format(f.request.url), "info")
- f.kill(self)
+ f.reply.kill()
def replay_request(self, f, block=False):
"""
@@ -314,8 +314,7 @@ class FlowMaster(controller.Master):
return
if f not in self.state.flows: # don't add again on replay
self.state.add_flow(f)
- if not f.reply.acked:
- self.process_new_request(f)
+ self.process_new_request(f)
return f
@controller.handler
@@ -331,9 +330,8 @@ class FlowMaster(controller.Master):
@controller.handler
def response(self, f):
self.state.update_flow(f)
- if not f.reply.acked:
- if self.client_playback:
- self.client_playback.clear(f)
+ if self.client_playback:
+ self.client_playback.clear(f)
return f
def handle_intercept(self, f):
diff --git a/mitmproxy/flow/state.py b/mitmproxy/flow/state.py
index efcb2d89..8576fadc 100644
--- a/mitmproxy/flow/state.py
+++ b/mitmproxy/flow/state.py
@@ -178,7 +178,7 @@ class FlowStore(FlowList):
def kill_all(self, master):
for f in self._list:
- if not f.reply.acked:
+ if f.killable:
f.kill(master)
diff --git a/mitmproxy/models/flow.py b/mitmproxy/models/flow.py
index f4a2b54b..fc673274 100644
--- a/mitmproxy/models/flow.py
+++ b/mitmproxy/models/flow.py
@@ -149,13 +149,22 @@ class Flow(stateobject.StateObject):
self.set_state(self._backup)
self._backup = None
+ @property
+ def killable(self):
+ return self.reply and self.reply.state in {"handled", "taken"}
+
def kill(self, master):
"""
Kill this request.
"""
self.error = Error("Connection killed")
self.intercepted = False
- self.reply.kill()
+ # reply.state should only be "handled" or "taken" here.
+ # if none of this is the case, .take() will raise an exception.
+ if self.reply.state != "taken":
+ self.reply.take()
+ self.reply.kill(force=True)
+ self.reply.commit()
master.error(self)
def intercept(self, master):
@@ -166,6 +175,7 @@ class Flow(stateobject.StateObject):
if self.intercepted:
return
self.intercepted = True
+ self.reply.take()
master.handle_intercept(self)
def accept_intercept(self, master):
@@ -176,6 +186,7 @@ class Flow(stateobject.StateObject):
return
self.intercepted = False
self.reply.ack()
+ self.reply.commit()
master.handle_accept_intercept(self)
def match(self, f):
diff --git a/mitmproxy/script/concurrent.py b/mitmproxy/script/concurrent.py
index 0cc0514e..9ed08065 100644
--- a/mitmproxy/script/concurrent.py
+++ b/mitmproxy/script/concurrent.py
@@ -13,7 +13,7 @@ class ScriptThread(basethread.BaseThread):
def concurrent(fn):
- if fn.__name__ not in controller.Events - set(["start", "configure", "tick"]):
+ if fn.__name__ not in controller.Events - {"start", "configure", "tick"}:
raise NotImplementedError(
"Concurrent decorator not supported for '%s' method." % fn.__name__
)
@@ -21,8 +21,10 @@ def concurrent(fn):
def _concurrent(obj):
def run():
fn(obj)
- if not obj.reply.acked:
- obj.reply.ack()
+ if obj.reply.state == "taken":
+ if not obj.reply.has_message:
+ obj.reply.ack()
+ obj.reply.commit()
obj.reply.take()
ScriptThread(
"script.concurrent (%s)" % fn.__name__,
diff --git a/mitmproxy/web/app.py b/mitmproxy/web/app.py
index f8f85f3d..5bd6f274 100644
--- a/mitmproxy/web/app.py
+++ b/mitmproxy/web/app.py
@@ -234,7 +234,7 @@ class AcceptFlow(RequestHandler):
class FlowHandler(RequestHandler):
def delete(self, flow_id):
- if not self.flow.reply.acked:
+ if self.flow.killable:
self.flow.kill(self.master)
self.state.delete_flow(self.flow)
@@ -438,6 +438,7 @@ class Application(tornado.web.Application):
xsrf_cookies=True,
cookie_secret=os.urandom(256),
debug=debug,
+ autoreload=False,
wauthenticator=wauthenticator,
)
super(Application, self).__init__(handlers, **settings)
diff --git a/mitmproxy/web/master.py b/mitmproxy/web/master.py
index 9ddb61d4..5751c9dd 100644
--- a/mitmproxy/web/master.py
+++ b/mitmproxy/web/master.py
@@ -183,7 +183,6 @@ class WebMaster(flow.FlowMaster):
if self.state.intercept and self.state.intercept(
f) and not f.request.is_replay:
f.intercept(self)
- f.reply.take()
return f
@controller.handler
diff --git a/test/mitmproxy/builtins/test_anticache.py b/test/mitmproxy/builtins/test_anticache.py
index ac321e26..8897de52 100644
--- a/test/mitmproxy/builtins/test_anticache.py
+++ b/test/mitmproxy/builtins/test_anticache.py
@@ -14,11 +14,11 @@ class TestAntiCache(mastertest.MasterTest):
m.addons.add(o, sa)
f = tutils.tflow(resp=True)
- self.invoke(m, "request", f)
+ m.request(f)
f = tutils.tflow(resp=True)
f.request.headers["if-modified-since"] = "test"
f.request.headers["if-none-match"] = "test"
- self.invoke(m, "request", f)
+ m.request(f)
assert "if-modified-since" not in f.request.headers
assert "if-none-match" not in f.request.headers
diff --git a/test/mitmproxy/builtins/test_anticomp.py b/test/mitmproxy/builtins/test_anticomp.py
index a5f5a270..af9e4a6a 100644
--- a/test/mitmproxy/builtins/test_anticomp.py
+++ b/test/mitmproxy/builtins/test_anticomp.py
@@ -14,10 +14,10 @@ class TestAntiComp(mastertest.MasterTest):
m.addons.add(o, sa)
f = tutils.tflow(resp=True)
- self.invoke(m, "request", f)
+ m.request(f)
f = tutils.tflow(resp=True)
f.request.headers["Accept-Encoding"] = "foobar"
- self.invoke(m, "request", f)
+ m.request(f)
assert f.request.headers["Accept-Encoding"] == "identity"
diff --git a/test/mitmproxy/builtins/test_dumper.py b/test/mitmproxy/builtins/test_dumper.py
index 1c7173e0..9b518b53 100644
--- a/test/mitmproxy/builtins/test_dumper.py
+++ b/test/mitmproxy/builtins/test_dumper.py
@@ -80,5 +80,5 @@ class TestContentView(mastertest.MasterTest):
m = mastertest.RecordingMaster(o, None, s)
d = dumper.Dumper()
m.addons.add(o, d)
- self.invoke(m, "response", tutils.tflow())
+ m.response(tutils.tflow())
assert "Content viewer failed" in m.event_log[0][1]
diff --git a/test/mitmproxy/builtins/test_filestreamer.py b/test/mitmproxy/builtins/test_filestreamer.py
index 0e69b340..94d68813 100644
--- a/test/mitmproxy/builtins/test_filestreamer.py
+++ b/test/mitmproxy/builtins/test_filestreamer.py
@@ -28,8 +28,8 @@ class TestStream(mastertest.MasterTest):
m.addons.add(o, sa)
f = tutils.tflow(resp=True)
- self.invoke(m, "request", f)
- self.invoke(m, "response", f)
+ m.request(f)
+ m.response(f)
m.addons.remove(sa)
assert r()[0].response
@@ -38,6 +38,6 @@ class TestStream(mastertest.MasterTest):
m.addons.add(o, sa)
f = tutils.tflow()
- self.invoke(m, "request", f)
+ m.request(f)
m.addons.remove(sa)
assert not r()[1].response
diff --git a/test/mitmproxy/builtins/test_replace.py b/test/mitmproxy/builtins/test_replace.py
index 5e70ce56..07abcda4 100644
--- a/test/mitmproxy/builtins/test_replace.py
+++ b/test/mitmproxy/builtins/test_replace.py
@@ -43,10 +43,10 @@ class TestReplace(mastertest.MasterTest):
f = tutils.tflow()
f.request.content = b"foo"
- self.invoke(m, "request", f)
+ m.request(f)
assert f.request.content == b"bar"
f = tutils.tflow(resp=True)
f.response.content = b"foo"
- self.invoke(m, "response", f)
+ m.response(f)
assert f.response.content == b"bar"
diff --git a/test/mitmproxy/builtins/test_script.py b/test/mitmproxy/builtins/test_script.py
index 2870fd17..0bac6ca0 100644
--- a/test/mitmproxy/builtins/test_script.py
+++ b/test/mitmproxy/builtins/test_script.py
@@ -69,7 +69,7 @@ class TestScript(mastertest.MasterTest):
sc.ns.call_log = []
f = tutils.tflow(resp=True)
- self.invoke(m, "request", f)
+ m.request(f)
recf = sc.ns.call_log[0]
assert recf[1] == "request"
@@ -102,7 +102,7 @@ class TestScript(mastertest.MasterTest):
)
m.addons.add(o, sc)
f = tutils.tflow(resp=True)
- self.invoke(m, "request", f)
+ m.request(f)
assert m.event_log[0][0] == "error"
def test_duplicate_flow(self):
diff --git a/test/mitmproxy/builtins/test_setheaders.py b/test/mitmproxy/builtins/test_setheaders.py
index 41c18360..63685177 100644
--- a/test/mitmproxy/builtins/test_setheaders.py
+++ b/test/mitmproxy/builtins/test_setheaders.py
@@ -33,12 +33,12 @@ class TestSetHeaders(mastertest.MasterTest):
)
f = tutils.tflow()
f.request.headers["one"] = "xxx"
- self.invoke(m, "request", f)
+ m.request(f)
assert f.request.headers["one"] == "two"
f = tutils.tflow(resp=True)
f.response.headers["one"] = "xxx"
- self.invoke(m, "response", f)
+ m.response(f)
assert f.response.headers["one"] == "three"
m, sh = self.mkmaster(
@@ -50,7 +50,7 @@ class TestSetHeaders(mastertest.MasterTest):
f = tutils.tflow(resp=True)
f.request.headers["one"] = "xxx"
f.response.headers["one"] = "xxx"
- self.invoke(m, "response", f)
+ m.response(f)
assert f.response.headers.get_all("one") == ["two", "three"]
m, sh = self.mkmaster(
@@ -61,5 +61,5 @@ class TestSetHeaders(mastertest.MasterTest):
)
f = tutils.tflow()
f.request.headers["one"] = "xxx"
- self.invoke(m, "request", f)
+ m.request(f)
assert f.request.headers.get_all("one") == ["two", "three"]
diff --git a/test/mitmproxy/builtins/test_stickyauth.py b/test/mitmproxy/builtins/test_stickyauth.py
index 5757fb2d..00b12072 100644
--- a/test/mitmproxy/builtins/test_stickyauth.py
+++ b/test/mitmproxy/builtins/test_stickyauth.py
@@ -15,10 +15,10 @@ class TestStickyAuth(mastertest.MasterTest):
f = tutils.tflow(resp=True)
f.request.headers["authorization"] = "foo"
- self.invoke(m, "request", f)
+ m.request(f)
assert "address" in sa.hosts
f = tutils.tflow(resp=True)
- self.invoke(m, "request", f)
+ m.request(f)
assert f.request.headers["authorization"] == "foo"
diff --git a/test/mitmproxy/builtins/test_stickycookie.py b/test/mitmproxy/builtins/test_stickycookie.py
index e9d92c83..13cea6a2 100644
--- a/test/mitmproxy/builtins/test_stickycookie.py
+++ b/test/mitmproxy/builtins/test_stickycookie.py
@@ -34,23 +34,23 @@ class TestStickyCookie(mastertest.MasterTest):
f = tutils.tflow(resp=True)
f.response.headers["set-cookie"] = "foo=bar"
- self.invoke(m, "request", f)
+ m.request(f)
f.reply.acked = False
- self.invoke(m, "response", f)
+ m.response(f)
assert sc.jar
assert "cookie" not in f.request.headers
f = f.copy()
f.reply.acked = False
- self.invoke(m, "request", f)
+ m.request(f)
assert f.request.headers["cookie"] == "foo=bar"
def _response(self, s, m, sc, cookie, host):
f = tutils.tflow(req=ntutils.treq(host=host, port=80), resp=True)
f.response.headers["Set-Cookie"] = cookie
- self.invoke(m, "response", f)
+ m.response(f)
return f
def test_response(self):
@@ -79,7 +79,7 @@ class TestStickyCookie(mastertest.MasterTest):
c2 = "othercookie=helloworld; Path=/"
f = self._response(s, m, sc, c1, "www.google.com")
f.response.headers["Set-Cookie"] = c2
- self.invoke(m, "response", f)
+ m.response(f)
googlekey = list(sc.jar.keys())[0]
assert len(sc.jar[googlekey].keys()) == 2
@@ -96,7 +96,7 @@ class TestStickyCookie(mastertest.MasterTest):
]
for c in cs:
f.response.headers["Set-Cookie"] = c
- self.invoke(m, "response", f)
+ m.response(f)
googlekey = list(sc.jar.keys())[0]
assert len(sc.jar[googlekey].keys()) == len(cs)
@@ -108,7 +108,7 @@ class TestStickyCookie(mastertest.MasterTest):
c2 = "somecookie=newvalue; Path=/"
f = self._response(s, m, sc, c1, "www.google.com")
f.response.headers["Set-Cookie"] = c2
- self.invoke(m, "response", f)
+ m.response(f)
googlekey = list(sc.jar.keys())[0]
assert len(sc.jar[googlekey].keys()) == 1
assert list(sc.jar[googlekey]["somecookie"].items())[0][1] == "newvalue"
@@ -120,7 +120,7 @@ class TestStickyCookie(mastertest.MasterTest):
# by setting the expire time in the past
f = self._response(s, m, sc, "duffer=zafar; Path=/", "www.google.com")
f.response.headers["Set-Cookie"] = "duffer=; Expires=Thu, 01-Jan-1970 00:00:00 GMT"
- self.invoke(m, "response", f)
+ m.response(f)
assert not sc.jar.keys()
def test_request(self):
@@ -128,5 +128,5 @@ class TestStickyCookie(mastertest.MasterTest):
f = self._response(s, m, sc, "SSID=mooo", "www.google.com")
assert "cookie" not in f.request.headers
- self.invoke(m, "request", f)
+ m.request(f)
assert "cookie" in f.request.headers
diff --git a/test/mitmproxy/mastertest.py b/test/mitmproxy/mastertest.py
index dcc0dc48..08659d19 100644
--- a/test/mitmproxy/mastertest.py
+++ b/test/mitmproxy/mastertest.py
@@ -1,5 +1,3 @@
-import mock
-
from . import tutils
import netlib.tutils
@@ -8,26 +6,19 @@ from mitmproxy import flow, proxy, models, controller
class MasterTest:
- def invoke(self, master, handler, *message):
- with master.handlecontext():
- func = getattr(master, handler)
- func(*message)
- if message:
- message[0].reply = controller.DummyReply()
def cycle(self, master, content):
f = tutils.tflow(req=netlib.tutils.treq(content=content))
l = proxy.Log("connect")
- l.reply = mock.MagicMock()
+ l.reply = controller.DummyReply()
master.log(l)
- self.invoke(master, "clientconnect", f.client_conn)
- self.invoke(master, "clientconnect", f.client_conn)
- self.invoke(master, "serverconnect", f.server_conn)
- self.invoke(master, "request", f)
+ master.clientconnect(f.client_conn)
+ master.serverconnect(f.server_conn)
+ master.request(f)
if not f.error:
f.response = models.HTTPResponse.wrap(netlib.tutils.tresp(content=content))
- self.invoke(master, "response", f)
- self.invoke(master, "clientdisconnect", f)
+ master.response(f)
+ master.clientdisconnect(f)
return f
def dummy_cycle(self, master, n, content):
diff --git a/test/mitmproxy/script/test_concurrent.py b/test/mitmproxy/script/test_concurrent.py
index a5f76994..c4f1e9ae 100644
--- a/test/mitmproxy/script/test_concurrent.py
+++ b/test/mitmproxy/script/test_concurrent.py
@@ -25,11 +25,11 @@ class TestConcurrent(mastertest.MasterTest):
)
m.addons.add(m.options, sc)
f1, f2 = tutils.tflow(), tutils.tflow()
- self.invoke(m, "request", f1)
- self.invoke(m, "request", f2)
+ m.request(f1)
+ m.request(f2)
start = time.time()
while time.time() - start < 5:
- if f1.reply.acked and f2.reply.acked:
+ if f1.reply.state == f2.reply.state == "committed":
return
raise ValueError("Script never acked")
diff --git a/test/mitmproxy/test_controller.py b/test/mitmproxy/test_controller.py
index 6d4b8fe6..abf66b6c 100644
--- a/test/mitmproxy/test_controller.py
+++ b/test/mitmproxy/test_controller.py
@@ -1,3 +1,4 @@
+from test.mitmproxy import tutils
from threading import Thread, Event
from mock import Mock
@@ -5,7 +6,7 @@ from mock import Mock
from mitmproxy import controller
from six.moves import queue
-from mitmproxy.exceptions import Kill
+from mitmproxy.exceptions import Kill, ControlException
from mitmproxy.proxy import DummyServer
from netlib.tutils import raises
@@ -55,7 +56,7 @@ class TestChannel(object):
def test_tell(self):
q = queue.Queue()
channel = controller.Channel(q, Event())
- m = Mock()
+ m = Mock(name="test_tell")
channel.tell("test", m)
assert q.get() == ("test", m)
assert m.reply
@@ -66,12 +67,15 @@ class TestChannel(object):
def reply():
m, obj = q.get()
assert m == "test"
+ obj.reply.handle()
obj.reply.send(42)
+ obj.reply.take()
+ obj.reply.commit()
Thread(target=reply).start()
channel = controller.Channel(q, Event())
- assert channel.ask("test", Mock()) == 42
+ assert channel.ask("test", Mock(name="test_ask_simple")) == 42
def test_ask_shutdown(self):
q = queue.Queue()
@@ -79,31 +83,125 @@ class TestChannel(object):
done.set()
channel = controller.Channel(q, done)
with raises(Kill):
- channel.ask("test", Mock())
-
-
-class TestDummyReply(object):
- def test_simple(self):
- reply = controller.DummyReply()
- assert not reply.acked
- reply.ack()
- assert reply.acked
+ channel.ask("test", Mock(name="test_ask_shutdown"))
class TestReply(object):
def test_simple(self):
reply = controller.Reply(42)
- assert not reply.acked
+ assert reply.state == "unhandled"
+
+ reply.handle()
+ assert reply.state == "handled"
+
reply.send("foo")
- assert reply.acked
+ assert reply.value == "foo"
+
+ reply.take()
+ assert reply.state == "taken"
+
+ with tutils.raises(queue.Empty):
+ reply.q.get_nowait()
+ reply.commit()
+ assert reply.state == "committed"
assert reply.q.get() == "foo"
- def test_default(self):
- reply = controller.Reply(42)
+ def test_kill(self):
+ reply = controller.Reply(43)
+ reply.handle()
+ reply.kill()
+ reply.take()
+ reply.commit()
+ assert reply.q.get() == Kill
+
+ def test_ack(self):
+ reply = controller.Reply(44)
+ reply.handle()
reply.ack()
- assert reply.q.get() == 42
+ reply.take()
+ reply.commit()
+ assert reply.q.get() == 44
def test_reply_none(self):
- reply = controller.Reply(42)
+ reply = controller.Reply(45)
+ reply.handle()
reply.send(None)
+ reply.take()
+ reply.commit()
assert reply.q.get() is None
+
+ def test_commit_no_reply(self):
+ reply = controller.Reply(46)
+ reply.handle()
+ reply.take()
+ with tutils.raises(ControlException):
+ reply.commit()
+ reply.ack()
+ reply.commit()
+
+ def test_double_send(self):
+ reply = controller.Reply(47)
+ reply.handle()
+ reply.send(1)
+ with tutils.raises(ControlException):
+ reply.send(2)
+ reply.take()
+ reply.commit()
+
+ def test_state_transitions(self):
+ states = {"unhandled", "handled", "taken", "committed"}
+ accept = {
+ "handle": {"unhandled"},
+ "take": {"handled"},
+ "commit": {"taken"},
+ "ack": {"handled", "taken"},
+ }
+ for fn, ok in accept.items():
+ for state in states:
+ r = controller.Reply(48)
+ r._state = state
+ if fn == "commit":
+ r.value = 49
+ if state in ok:
+ getattr(r, fn)()
+ else:
+ with tutils.raises(ControlException):
+ getattr(r, fn)()
+ r._state = "committed" # hide warnings on deletion
+
+ def test_del(self):
+ reply = controller.Reply(47)
+ with tutils.raises(ControlException):
+ reply.__del__()
+ reply.handle()
+ reply.ack()
+ reply.take()
+ reply.commit()
+
+
+class TestDummyReply(object):
+ def test_simple(self):
+ reply = controller.DummyReply()
+ for _ in range(2):
+ reply.handle()
+ reply.ack()
+ reply.take()
+ reply.commit()
+ reply.mark_reset()
+ reply.reset()
+ assert reply.state == "unhandled"
+
+ def test_reset(self):
+ reply = controller.DummyReply()
+ reply.handle()
+ reply.ack()
+ reply.take()
+ reply.commit()
+ reply.mark_reset()
+ assert reply.state == "committed"
+ reply.reset()
+ assert reply.state == "unhandled"
+
+ def test_del(self):
+ reply = controller.DummyReply()
+ reply.__del__()
diff --git a/test/mitmproxy/test_examples.py b/test/mitmproxy/test_examples.py
index 34fcc261..6c24ace5 100644
--- a/test/mitmproxy/test_examples.py
+++ b/test/mitmproxy/test_examples.py
@@ -39,7 +39,7 @@ class TestScripts(mastertest.MasterTest):
def test_add_header(self):
m, _ = tscript("add_header.py")
f = tutils.tflow(resp=netutils.tresp())
- self.invoke(m, "response", f)
+ m.response(f)
assert f.response.headers["newheader"] == "foo"
def test_custom_contentviews(self):
@@ -54,9 +54,9 @@ class TestScripts(mastertest.MasterTest):
tscript("iframe_injector.py")
m, sc = tscript("iframe_injector.py", "http://example.org/evil_iframe")
- flow = tutils.tflow(resp=netutils.tresp(content=b"<html>mitmproxy</html>"))
- self.invoke(m, "response", flow)
- content = flow.response.content
+ f = tutils.tflow(resp=netutils.tresp(content=b"<html>mitmproxy</html>"))
+ m.response(f)
+ content = f.response.content
assert b'iframe' in content and b'evil_iframe' in content
def test_modify_form(self):
@@ -64,23 +64,23 @@ class TestScripts(mastertest.MasterTest):
form_header = Headers(content_type="application/x-www-form-urlencoded")
f = tutils.tflow(req=netutils.treq(headers=form_header))
- self.invoke(m, "request", f)
+ m.request(f)
assert f.request.urlencoded_form[b"mitmproxy"] == b"rocks"
f.request.headers["content-type"] = ""
- self.invoke(m, "request", f)
+ m.request(f)
assert list(f.request.urlencoded_form.items()) == [(b"foo", b"bar")]
def test_modify_querystring(self):
m, sc = tscript("modify_querystring.py")
f = tutils.tflow(req=netutils.treq(path="/search?q=term"))
- self.invoke(m, "request", f)
+ m.request(f)
assert f.request.query["mitmproxy"] == "rocks"
f.request.path = "/"
- self.invoke(m, "request", f)
+ m.request(f)
assert f.request.query["mitmproxy"] == "rocks"
def test_modify_response_body(self):
@@ -89,13 +89,13 @@ class TestScripts(mastertest.MasterTest):
m, sc = tscript("modify_response_body.py", "mitmproxy rocks")
f = tutils.tflow(resp=netutils.tresp(content=b"I <3 mitmproxy"))
- self.invoke(m, "response", f)
+ m.response(f)
assert f.response.content == b"I <3 rocks"
def test_redirect_requests(self):
m, sc = tscript("redirect_requests.py")
f = tutils.tflow(req=netutils.treq(host="example.org"))
- self.invoke(m, "request", f)
+ m.request(f)
assert f.request.host == "mitmproxy.org"
def test_har_extractor(self):
@@ -119,7 +119,7 @@ class TestScripts(mastertest.MasterTest):
req=netutils.treq(**times),
resp=netutils.tresp(**times)
)
- self.invoke(m, "response", f)
+ m.response(f)
m.addons.remove(sc)
with open(path, "rb") as f:
diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py
index d4bf764c..1caeb100 100644
--- a/test/mitmproxy/test_flow.py
+++ b/test/mitmproxy/test_flow.py
@@ -3,9 +3,9 @@ import io
import netlib.utils
from netlib.http import Headers
-from mitmproxy import filt, controller, flow, options
+from mitmproxy import filt, flow, options
from mitmproxy.contrib import tnetstring
-from mitmproxy.exceptions import FlowReadException
+from mitmproxy.exceptions import FlowReadException, Kill
from mitmproxy.models import Error
from mitmproxy.models import Flow
from mitmproxy.models import HTTPFlow
@@ -372,19 +372,23 @@ class TestHTTPFlow(object):
assert f.get_state() == f2.get_state()
def test_kill(self):
- s = flow.State()
- fm = flow.FlowMaster(None, None, s)
+ fm = mock.Mock()
f = tutils.tflow()
- f.intercept(mock.Mock())
+ f.reply.handle()
+ f.intercept(fm)
+ assert fm.handle_intercept.called
+ assert f.killable
f.kill(fm)
- for i in s.view:
- assert "killed" in str(i.error)
+ assert not f.killable
+ assert fm.error.called
+ assert f.reply.value == Kill
def test_killall(self):
s = flow.State()
fm = flow.FlowMaster(None, None, s)
f = tutils.tflow()
+ f.reply.handle()
f.intercept(fm)
s.killall(fm)
@@ -393,11 +397,11 @@ class TestHTTPFlow(object):
def test_accept_intercept(self):
f = tutils.tflow()
-
+ f.reply.handle()
f.intercept(mock.Mock())
- assert not f.reply.acked
+ assert f.reply.state == "taken"
f.accept_intercept(mock.Mock())
- assert f.reply.acked
+ assert f.reply.state == "committed"
def test_replace_unicode(self):
f = tutils.tflow(resp=True)
@@ -735,7 +739,6 @@ class TestFlowMaster:
fm.clientdisconnect(f.client_conn)
f.error = Error("msg")
- f.error.reply = controller.DummyReply()
fm.error(f)
fm.shutdown()
@@ -834,8 +837,8 @@ class TestFlowMaster:
f = tutils.tflow()
f.request.host = "nonexistent"
- fm.process_new_request(f)
- assert "killed" in f.error.msg
+ fm.request(f)
+ assert f.reply.value == Kill
class TestRequest: