diff options
author | Maximilian Hils <git@maximilianhils.com> | 2016-08-10 02:22:39 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-08-10 02:22:39 -0700 |
commit | ea2f23feffd95877fe4d59666c42dd45319f18b9 (patch) | |
tree | 9ac4359b3774b35b20acc81d4863574eda39e2b6 | |
parent | 4f5e312fbcee608a116f5d8bc35e5334dafdb845 (diff) | |
parent | 5a22496ee8f6f72bc5f75f623ca0d68d7a0d7855 (diff) | |
download | mitmproxy-ea2f23feffd95877fe4d59666c42dd45319f18b9.tar.gz mitmproxy-ea2f23feffd95877fe4d59666c42dd45319f18b9.tar.bz2 mitmproxy-ea2f23feffd95877fe4d59666c42dd45319f18b9.zip |
Merge pull request #1474 from mhils/reply-fix
Improve controller.Reply semantics
27 files changed, 319 insertions, 154 deletions
diff --git a/mitmproxy/builtins/replace.py b/mitmproxy/builtins/replace.py index 2c94fbb5..c938d683 100644 --- a/mitmproxy/builtins/replace.py +++ b/mitmproxy/builtins/replace.py @@ -41,9 +41,9 @@ class Replace: f.request.replace(rex, s) def request(self, flow): - if not flow.reply.acked: + if not flow.reply.has_message: self.execute(flow) def response(self, flow): - if not flow.reply.acked: + if not flow.reply.has_message: self.execute(flow) diff --git a/mitmproxy/builtins/setheaders.py b/mitmproxy/builtins/setheaders.py index 4a784a1d..4cb9905e 100644 --- a/mitmproxy/builtins/setheaders.py +++ b/mitmproxy/builtins/setheaders.py @@ -31,9 +31,9 @@ class SetHeaders: hdrs.add(header, value) def request(self, flow): - if not flow.reply.acked: + if not flow.reply.has_message: self.run(flow, flow.request.headers) def response(self, flow): - if not flow.reply.acked: + if not flow.reply.has_message: self.run(flow, flow.response.headers) diff --git a/mitmproxy/console/common.py b/mitmproxy/console/common.py index 2eb6a7d9..5a24e789 100644 --- a/mitmproxy/console/common.py +++ b/mitmproxy/console/common.py @@ -413,7 +413,7 @@ def raw_format_flow(f, focus, extended): def format_flow(f, focus, extended=False, hostheader=False): d = dict( intercepted = f.intercepted, - acked = f.reply.acked, + acked = f.reply.state == "committed", req_timestamp = f.request.timestamp_start, req_is_replay = f.request.is_replay, diff --git a/mitmproxy/console/flowlist.py b/mitmproxy/console/flowlist.py index 12caf315..7e69e098 100644 --- a/mitmproxy/console/flowlist.py +++ b/mitmproxy/console/flowlist.py @@ -182,7 +182,7 @@ class ConnectionItem(urwid.WidgetWrap): self.flow.accept_intercept(self.master) signals.flowlist_change.send(self) elif key == "d": - if not self.flow.reply.acked: + if self.flow.killable: self.flow.kill(self.master) self.state.delete_flow(self.flow) signals.flowlist_change.send(self) @@ -246,7 +246,7 @@ class ConnectionItem(urwid.WidgetWrap): callback = self.save_flows_prompt, ) elif key == "X": - if not self.flow.reply.acked: + if self.flow.killable: self.flow.kill(self.master) elif key == "enter": if self.flow.request: diff --git a/mitmproxy/console/flowview.py b/mitmproxy/console/flowview.py index 1c3c4e98..6d74be65 100644 --- a/mitmproxy/console/flowview.py +++ b/mitmproxy/console/flowview.py @@ -8,7 +8,6 @@ import urwid from typing import Optional, Union # noqa from mitmproxy import contentviews -from mitmproxy import controller from mitmproxy import models from mitmproxy import utils from mitmproxy.console import common @@ -148,13 +147,13 @@ class FlowView(tabs.Tabs): signals.flow_change.connect(self.sig_flow_change) def tab_request(self): - if self.flow.intercepted and not self.flow.reply.acked and not self.flow.response: + if self.flow.intercepted and not self.flow.response: return "Request intercepted" else: return "Request" def tab_response(self): - if self.flow.intercepted and not self.flow.reply.acked and self.flow.response: + if self.flow.intercepted and self.flow.response: return "Response intercepted" else: return "Response" @@ -379,7 +378,6 @@ class FlowView(tabs.Tabs): self.flow.request.http_version, 200, b"OK", Headers(), b"" ) - self.flow.response.reply = controller.DummyReply() message = self.flow.response self.flow.backup() @@ -538,7 +536,7 @@ class FlowView(tabs.Tabs): else: self.view_next_flow(self.flow) f = self.flow - if not f.reply.acked: + if f.killable: f.kill(self.master) self.state.delete_flow(f) elif key == "D": diff --git a/mitmproxy/console/master.py b/mitmproxy/console/master.py index 18a4c1f0..a6942ca4 100644 --- a/mitmproxy/console/master.py +++ b/mitmproxy/console/master.py @@ -736,7 +736,6 @@ class ConsoleMaster(flow.FlowMaster): ) if should_intercept: f.intercept(self) - f.reply.take() signals.flowlist_change.send(self) signals.flow_change.send(self, flow = f) diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py index 35817a85..72374f31 100644 --- a/mitmproxy/controller.py +++ b/mitmproxy/controller.py @@ -185,6 +185,7 @@ class Channel(object): if g == exceptions.Kill: raise exceptions.Kill() return g + m.reply._state = "committed" # suppress error message in __del__ raise exceptions.Kill() def tell(self, mtype, m): @@ -202,34 +203,47 @@ def handler(f): if not hasattr(message, "reply"): raise exceptions.ControlException("Message %s has no reply attribute" % message) + # DummyReplys may be reused multiple times. + # We only clear them up on the next handler so that we can access value and + # state in the meantime. + if isinstance(message.reply, DummyReply): + message.reply.reset() + # The following ensures that inheritance with wrapped handlers in the # base class works. If we're the first handler, then responsibility for # acking is ours. If not, it's someone else's and we ignore it. handling = False # We're the first handler - ack responsibility is ours - if not message.reply.handled: + if message.reply.state == "unhandled": handling = True - message.reply.handled = True + message.reply.handle() with master.handlecontext(): ret = f(master, message) if handling: master.addons(f.__name__, message) - if handling and not message.reply.acked and not message.reply.taken: - message.reply.ack() - # Reset the handled flag - it's common for us to feed the same object # through handlers repeatedly, so we don't want this to persist across # calls. - if message.reply.handled: - message.reply.handled = False + if handling and message.reply.state == "handled": + message.reply.take() + if not message.reply.has_message: + message.reply.ack() + message.reply.commit() + + # DummyReplys may be reused multiple times. + if isinstance(message.reply, DummyReply): + message.reply.mark_reset() return ret # Mark this function as a handler wrapper wrapper.__dict__["__handler"] = True return wrapper +NO_REPLY = object() # special object we can distinguish from a valid "None" reply. + + class Reply(object): """ Messages sent through a channel are decorated with a "reply" attribute. @@ -238,53 +252,104 @@ class Reply(object): """ def __init__(self, obj): self.obj = obj - self.q = queue.Queue() - # Has this message been acked? - self.acked = False - # Has the user taken responsibility for ack-ing? - self.taken = False - # Has a handler taken responsibility for ack-ing? - self.handled = False + self.q = queue.Queue() # type: queue.Queue + + self._state = "unhandled" # "unhandled" -> "handled" -> "taken" -> "committed" + self.value = NO_REPLY # holds the reply value. May change before things are actually commited. + + @property + def state(self): + """ + The state the reply is currently in. A normal reply object goes sequentially through the following lifecycle: + + 1. unhandled: Initial State. + 2. handled: The reply object has been handled by the topmost handler function. + 3. taken: The reply object has been taken to be commited. + 4. committed: The reply has been sent back to the requesting party. + + This attribute is read-only and can only be modified by calling one of state transition functions. + """ + return self._state - def ack(self): - self.send(self.obj) + @property + def has_message(self): + return self.value != NO_REPLY - def kill(self): - self.send(exceptions.Kill) + @property + def done(self): + return self.state == "committed" + + def handle(self): + """ + Reply are handled by controller.handlers, which may be nested. The first handler takes + responsibility and handles the reply. + """ + if self.state != "unhandled": + raise exceptions.ControlException("Reply is {}, but expected it to be unhandled.".format(self.state)) + self._state = "handled" def take(self): - self.taken = True + """ + Scripts or other parties make "take" a reply out of a normal flow. + For example, intercepted flows are taken out so that the connection thread does not proceed. + """ + if self.state != "handled": + raise exceptions.ControlException("Reply is {}, but expected it to be handled.".format(self.state)) + self._state = "taken" - def send(self, msg): - if self.acked: - raise exceptions.ControlException("Message already acked.") - self.acked = True - self.q.put(msg) + def commit(self): + """ + Ultimately, messages are commited. This is done either automatically by the handler + if the message is not taken or manually by the entity which called .take(). + """ + if self.state != "taken": + raise exceptions.ControlException("Reply is {}, but expected it to be taken.".format(self.state)) + if not self.has_message: + raise exceptions.ControlException("There is no reply message.") + self._state = "committed" + self.q.put(self.value) + + def ack(self, force=False): + self.send(self.obj, force) + + def kill(self, force=False): + self.send(exceptions.Kill, force) + + def send(self, msg, force=False): + if self.state not in ("handled", "taken"): + raise exceptions.ControlException( + "Reply is {}, did not expect a call to .send().".format(self.state) + ) + if self.has_message and not force: + raise exceptions.ControlException("There is already a reply message.") + self.value = msg def __del__(self): - if not self.acked: + if self.state != "committed": # This will be ignored by the interpreter, but emit a warning - raise exceptions.ControlException("Un-acked message: %s" % self.obj) + raise exceptions.ControlException("Uncommitted reply: %s" % self.obj) -class DummyReply(object): +class DummyReply(Reply): """ - A reply object that does nothing. Useful when we need an object to seem - like it has a channel, and during testing. + A reply object that is not connected to anything. In contrast to regular Reply objects, + DummyReply objects are reset to "unhandled" at the end of an handler so that they can be used + multiple times. Useful when we need an object to seem like it has a channel, + and during testing. """ def __init__(self): - self.acked = False - self.taken = False - self.handled = False + super(DummyReply, self).__init__(None) + self._should_reset = False - def kill(self): - self.send(None) + def mark_reset(self): + if self.state != "committed": + raise exceptions.ControlException("Uncommitted reply: %s" % self.obj) + self._should_reset = True - def ack(self): - self.send(None) + def reset(self): + if self._should_reset: + self._state = "unhandled" + self.value = NO_REPLY - def take(self): - self.taken = True - - def send(self, msg): - self.acked = True + def __del__(self): + pass diff --git a/mitmproxy/flow/master.py b/mitmproxy/flow/master.py index 65a95e44..0475ef4e 100644 --- a/mitmproxy/flow/master.py +++ b/mitmproxy/flow/master.py @@ -234,7 +234,7 @@ class FlowMaster(controller.Master): pb = self.do_server_playback(f) if not pb and self.kill_nonreplay: self.add_log("Killed {}".format(f.request.url), "info") - f.kill(self) + f.reply.kill() def replay_request(self, f, block=False): """ @@ -314,8 +314,7 @@ class FlowMaster(controller.Master): return if f not in self.state.flows: # don't add again on replay self.state.add_flow(f) - if not f.reply.acked: - self.process_new_request(f) + self.process_new_request(f) return f @controller.handler @@ -331,9 +330,8 @@ class FlowMaster(controller.Master): @controller.handler def response(self, f): self.state.update_flow(f) - if not f.reply.acked: - if self.client_playback: - self.client_playback.clear(f) + if self.client_playback: + self.client_playback.clear(f) return f def handle_intercept(self, f): diff --git a/mitmproxy/flow/state.py b/mitmproxy/flow/state.py index efcb2d89..8576fadc 100644 --- a/mitmproxy/flow/state.py +++ b/mitmproxy/flow/state.py @@ -178,7 +178,7 @@ class FlowStore(FlowList): def kill_all(self, master): for f in self._list: - if not f.reply.acked: + if f.killable: f.kill(master) diff --git a/mitmproxy/models/flow.py b/mitmproxy/models/flow.py index f4a2b54b..fc673274 100644 --- a/mitmproxy/models/flow.py +++ b/mitmproxy/models/flow.py @@ -149,13 +149,22 @@ class Flow(stateobject.StateObject): self.set_state(self._backup) self._backup = None + @property + def killable(self): + return self.reply and self.reply.state in {"handled", "taken"} + def kill(self, master): """ Kill this request. """ self.error = Error("Connection killed") self.intercepted = False - self.reply.kill() + # reply.state should only be "handled" or "taken" here. + # if none of this is the case, .take() will raise an exception. + if self.reply.state != "taken": + self.reply.take() + self.reply.kill(force=True) + self.reply.commit() master.error(self) def intercept(self, master): @@ -166,6 +175,7 @@ class Flow(stateobject.StateObject): if self.intercepted: return self.intercepted = True + self.reply.take() master.handle_intercept(self) def accept_intercept(self, master): @@ -176,6 +186,7 @@ class Flow(stateobject.StateObject): return self.intercepted = False self.reply.ack() + self.reply.commit() master.handle_accept_intercept(self) def match(self, f): diff --git a/mitmproxy/script/concurrent.py b/mitmproxy/script/concurrent.py index 0cc0514e..9ed08065 100644 --- a/mitmproxy/script/concurrent.py +++ b/mitmproxy/script/concurrent.py @@ -13,7 +13,7 @@ class ScriptThread(basethread.BaseThread): def concurrent(fn): - if fn.__name__ not in controller.Events - set(["start", "configure", "tick"]): + if fn.__name__ not in controller.Events - {"start", "configure", "tick"}: raise NotImplementedError( "Concurrent decorator not supported for '%s' method." % fn.__name__ ) @@ -21,8 +21,10 @@ def concurrent(fn): def _concurrent(obj): def run(): fn(obj) - if not obj.reply.acked: - obj.reply.ack() + if obj.reply.state == "taken": + if not obj.reply.has_message: + obj.reply.ack() + obj.reply.commit() obj.reply.take() ScriptThread( "script.concurrent (%s)" % fn.__name__, diff --git a/mitmproxy/web/app.py b/mitmproxy/web/app.py index f8f85f3d..5bd6f274 100644 --- a/mitmproxy/web/app.py +++ b/mitmproxy/web/app.py @@ -234,7 +234,7 @@ class AcceptFlow(RequestHandler): class FlowHandler(RequestHandler): def delete(self, flow_id): - if not self.flow.reply.acked: + if self.flow.killable: self.flow.kill(self.master) self.state.delete_flow(self.flow) @@ -438,6 +438,7 @@ class Application(tornado.web.Application): xsrf_cookies=True, cookie_secret=os.urandom(256), debug=debug, + autoreload=False, wauthenticator=wauthenticator, ) super(Application, self).__init__(handlers, **settings) diff --git a/mitmproxy/web/master.py b/mitmproxy/web/master.py index 9ddb61d4..5751c9dd 100644 --- a/mitmproxy/web/master.py +++ b/mitmproxy/web/master.py @@ -183,7 +183,6 @@ class WebMaster(flow.FlowMaster): if self.state.intercept and self.state.intercept( f) and not f.request.is_replay: f.intercept(self) - f.reply.take() return f @controller.handler diff --git a/test/mitmproxy/builtins/test_anticache.py b/test/mitmproxy/builtins/test_anticache.py index ac321e26..8897de52 100644 --- a/test/mitmproxy/builtins/test_anticache.py +++ b/test/mitmproxy/builtins/test_anticache.py @@ -14,11 +14,11 @@ class TestAntiCache(mastertest.MasterTest): m.addons.add(o, sa) f = tutils.tflow(resp=True) - self.invoke(m, "request", f) + m.request(f) f = tutils.tflow(resp=True) f.request.headers["if-modified-since"] = "test" f.request.headers["if-none-match"] = "test" - self.invoke(m, "request", f) + m.request(f) assert "if-modified-since" not in f.request.headers assert "if-none-match" not in f.request.headers diff --git a/test/mitmproxy/builtins/test_anticomp.py b/test/mitmproxy/builtins/test_anticomp.py index a5f5a270..af9e4a6a 100644 --- a/test/mitmproxy/builtins/test_anticomp.py +++ b/test/mitmproxy/builtins/test_anticomp.py @@ -14,10 +14,10 @@ class TestAntiComp(mastertest.MasterTest): m.addons.add(o, sa) f = tutils.tflow(resp=True) - self.invoke(m, "request", f) + m.request(f) f = tutils.tflow(resp=True) f.request.headers["Accept-Encoding"] = "foobar" - self.invoke(m, "request", f) + m.request(f) assert f.request.headers["Accept-Encoding"] == "identity" diff --git a/test/mitmproxy/builtins/test_dumper.py b/test/mitmproxy/builtins/test_dumper.py index 1c7173e0..9b518b53 100644 --- a/test/mitmproxy/builtins/test_dumper.py +++ b/test/mitmproxy/builtins/test_dumper.py @@ -80,5 +80,5 @@ class TestContentView(mastertest.MasterTest): m = mastertest.RecordingMaster(o, None, s) d = dumper.Dumper() m.addons.add(o, d) - self.invoke(m, "response", tutils.tflow()) + m.response(tutils.tflow()) assert "Content viewer failed" in m.event_log[0][1] diff --git a/test/mitmproxy/builtins/test_filestreamer.py b/test/mitmproxy/builtins/test_filestreamer.py index 0e69b340..94d68813 100644 --- a/test/mitmproxy/builtins/test_filestreamer.py +++ b/test/mitmproxy/builtins/test_filestreamer.py @@ -28,8 +28,8 @@ class TestStream(mastertest.MasterTest): m.addons.add(o, sa) f = tutils.tflow(resp=True) - self.invoke(m, "request", f) - self.invoke(m, "response", f) + m.request(f) + m.response(f) m.addons.remove(sa) assert r()[0].response @@ -38,6 +38,6 @@ class TestStream(mastertest.MasterTest): m.addons.add(o, sa) f = tutils.tflow() - self.invoke(m, "request", f) + m.request(f) m.addons.remove(sa) assert not r()[1].response diff --git a/test/mitmproxy/builtins/test_replace.py b/test/mitmproxy/builtins/test_replace.py index 5e70ce56..07abcda4 100644 --- a/test/mitmproxy/builtins/test_replace.py +++ b/test/mitmproxy/builtins/test_replace.py @@ -43,10 +43,10 @@ class TestReplace(mastertest.MasterTest): f = tutils.tflow() f.request.content = b"foo" - self.invoke(m, "request", f) + m.request(f) assert f.request.content == b"bar" f = tutils.tflow(resp=True) f.response.content = b"foo" - self.invoke(m, "response", f) + m.response(f) assert f.response.content == b"bar" diff --git a/test/mitmproxy/builtins/test_script.py b/test/mitmproxy/builtins/test_script.py index 2870fd17..0bac6ca0 100644 --- a/test/mitmproxy/builtins/test_script.py +++ b/test/mitmproxy/builtins/test_script.py @@ -69,7 +69,7 @@ class TestScript(mastertest.MasterTest): sc.ns.call_log = [] f = tutils.tflow(resp=True) - self.invoke(m, "request", f) + m.request(f) recf = sc.ns.call_log[0] assert recf[1] == "request" @@ -102,7 +102,7 @@ class TestScript(mastertest.MasterTest): ) m.addons.add(o, sc) f = tutils.tflow(resp=True) - self.invoke(m, "request", f) + m.request(f) assert m.event_log[0][0] == "error" def test_duplicate_flow(self): diff --git a/test/mitmproxy/builtins/test_setheaders.py b/test/mitmproxy/builtins/test_setheaders.py index 41c18360..63685177 100644 --- a/test/mitmproxy/builtins/test_setheaders.py +++ b/test/mitmproxy/builtins/test_setheaders.py @@ -33,12 +33,12 @@ class TestSetHeaders(mastertest.MasterTest): ) f = tutils.tflow() f.request.headers["one"] = "xxx" - self.invoke(m, "request", f) + m.request(f) assert f.request.headers["one"] == "two" f = tutils.tflow(resp=True) f.response.headers["one"] = "xxx" - self.invoke(m, "response", f) + m.response(f) assert f.response.headers["one"] == "three" m, sh = self.mkmaster( @@ -50,7 +50,7 @@ class TestSetHeaders(mastertest.MasterTest): f = tutils.tflow(resp=True) f.request.headers["one"] = "xxx" f.response.headers["one"] = "xxx" - self.invoke(m, "response", f) + m.response(f) assert f.response.headers.get_all("one") == ["two", "three"] m, sh = self.mkmaster( @@ -61,5 +61,5 @@ class TestSetHeaders(mastertest.MasterTest): ) f = tutils.tflow() f.request.headers["one"] = "xxx" - self.invoke(m, "request", f) + m.request(f) assert f.request.headers.get_all("one") == ["two", "three"] diff --git a/test/mitmproxy/builtins/test_stickyauth.py b/test/mitmproxy/builtins/test_stickyauth.py index 5757fb2d..00b12072 100644 --- a/test/mitmproxy/builtins/test_stickyauth.py +++ b/test/mitmproxy/builtins/test_stickyauth.py @@ -15,10 +15,10 @@ class TestStickyAuth(mastertest.MasterTest): f = tutils.tflow(resp=True) f.request.headers["authorization"] = "foo" - self.invoke(m, "request", f) + m.request(f) assert "address" in sa.hosts f = tutils.tflow(resp=True) - self.invoke(m, "request", f) + m.request(f) assert f.request.headers["authorization"] == "foo" diff --git a/test/mitmproxy/builtins/test_stickycookie.py b/test/mitmproxy/builtins/test_stickycookie.py index e9d92c83..13cea6a2 100644 --- a/test/mitmproxy/builtins/test_stickycookie.py +++ b/test/mitmproxy/builtins/test_stickycookie.py @@ -34,23 +34,23 @@ class TestStickyCookie(mastertest.MasterTest): f = tutils.tflow(resp=True) f.response.headers["set-cookie"] = "foo=bar" - self.invoke(m, "request", f) + m.request(f) f.reply.acked = False - self.invoke(m, "response", f) + m.response(f) assert sc.jar assert "cookie" not in f.request.headers f = f.copy() f.reply.acked = False - self.invoke(m, "request", f) + m.request(f) assert f.request.headers["cookie"] == "foo=bar" def _response(self, s, m, sc, cookie, host): f = tutils.tflow(req=ntutils.treq(host=host, port=80), resp=True) f.response.headers["Set-Cookie"] = cookie - self.invoke(m, "response", f) + m.response(f) return f def test_response(self): @@ -79,7 +79,7 @@ class TestStickyCookie(mastertest.MasterTest): c2 = "othercookie=helloworld; Path=/" f = self._response(s, m, sc, c1, "www.google.com") f.response.headers["Set-Cookie"] = c2 - self.invoke(m, "response", f) + m.response(f) googlekey = list(sc.jar.keys())[0] assert len(sc.jar[googlekey].keys()) == 2 @@ -96,7 +96,7 @@ class TestStickyCookie(mastertest.MasterTest): ] for c in cs: f.response.headers["Set-Cookie"] = c - self.invoke(m, "response", f) + m.response(f) googlekey = list(sc.jar.keys())[0] assert len(sc.jar[googlekey].keys()) == len(cs) @@ -108,7 +108,7 @@ class TestStickyCookie(mastertest.MasterTest): c2 = "somecookie=newvalue; Path=/" f = self._response(s, m, sc, c1, "www.google.com") f.response.headers["Set-Cookie"] = c2 - self.invoke(m, "response", f) + m.response(f) googlekey = list(sc.jar.keys())[0] assert len(sc.jar[googlekey].keys()) == 1 assert list(sc.jar[googlekey]["somecookie"].items())[0][1] == "newvalue" @@ -120,7 +120,7 @@ class TestStickyCookie(mastertest.MasterTest): # by setting the expire time in the past f = self._response(s, m, sc, "duffer=zafar; Path=/", "www.google.com") f.response.headers["Set-Cookie"] = "duffer=; Expires=Thu, 01-Jan-1970 00:00:00 GMT" - self.invoke(m, "response", f) + m.response(f) assert not sc.jar.keys() def test_request(self): @@ -128,5 +128,5 @@ class TestStickyCookie(mastertest.MasterTest): f = self._response(s, m, sc, "SSID=mooo", "www.google.com") assert "cookie" not in f.request.headers - self.invoke(m, "request", f) + m.request(f) assert "cookie" in f.request.headers diff --git a/test/mitmproxy/mastertest.py b/test/mitmproxy/mastertest.py index dcc0dc48..08659d19 100644 --- a/test/mitmproxy/mastertest.py +++ b/test/mitmproxy/mastertest.py @@ -1,5 +1,3 @@ -import mock - from . import tutils import netlib.tutils @@ -8,26 +6,19 @@ from mitmproxy import flow, proxy, models, controller class MasterTest: - def invoke(self, master, handler, *message): - with master.handlecontext(): - func = getattr(master, handler) - func(*message) - if message: - message[0].reply = controller.DummyReply() def cycle(self, master, content): f = tutils.tflow(req=netlib.tutils.treq(content=content)) l = proxy.Log("connect") - l.reply = mock.MagicMock() + l.reply = controller.DummyReply() master.log(l) - self.invoke(master, "clientconnect", f.client_conn) - self.invoke(master, "clientconnect", f.client_conn) - self.invoke(master, "serverconnect", f.server_conn) - self.invoke(master, "request", f) + master.clientconnect(f.client_conn) + master.serverconnect(f.server_conn) + master.request(f) if not f.error: f.response = models.HTTPResponse.wrap(netlib.tutils.tresp(content=content)) - self.invoke(master, "response", f) - self.invoke(master, "clientdisconnect", f) + master.response(f) + master.clientdisconnect(f) return f def dummy_cycle(self, master, n, content): diff --git a/test/mitmproxy/script/test_concurrent.py b/test/mitmproxy/script/test_concurrent.py index a5f76994..c4f1e9ae 100644 --- a/test/mitmproxy/script/test_concurrent.py +++ b/test/mitmproxy/script/test_concurrent.py @@ -25,11 +25,11 @@ class TestConcurrent(mastertest.MasterTest): ) m.addons.add(m.options, sc) f1, f2 = tutils.tflow(), tutils.tflow() - self.invoke(m, "request", f1) - self.invoke(m, "request", f2) + m.request(f1) + m.request(f2) start = time.time() while time.time() - start < 5: - if f1.reply.acked and f2.reply.acked: + if f1.reply.state == f2.reply.state == "committed": return raise ValueError("Script never acked") diff --git a/test/mitmproxy/test_controller.py b/test/mitmproxy/test_controller.py index 6d4b8fe6..abf66b6c 100644 --- a/test/mitmproxy/test_controller.py +++ b/test/mitmproxy/test_controller.py @@ -1,3 +1,4 @@ +from test.mitmproxy import tutils from threading import Thread, Event from mock import Mock @@ -5,7 +6,7 @@ from mock import Mock from mitmproxy import controller from six.moves import queue -from mitmproxy.exceptions import Kill +from mitmproxy.exceptions import Kill, ControlException from mitmproxy.proxy import DummyServer from netlib.tutils import raises @@ -55,7 +56,7 @@ class TestChannel(object): def test_tell(self): q = queue.Queue() channel = controller.Channel(q, Event()) - m = Mock() + m = Mock(name="test_tell") channel.tell("test", m) assert q.get() == ("test", m) assert m.reply @@ -66,12 +67,15 @@ class TestChannel(object): def reply(): m, obj = q.get() assert m == "test" + obj.reply.handle() obj.reply.send(42) + obj.reply.take() + obj.reply.commit() Thread(target=reply).start() channel = controller.Channel(q, Event()) - assert channel.ask("test", Mock()) == 42 + assert channel.ask("test", Mock(name="test_ask_simple")) == 42 def test_ask_shutdown(self): q = queue.Queue() @@ -79,31 +83,125 @@ class TestChannel(object): done.set() channel = controller.Channel(q, done) with raises(Kill): - channel.ask("test", Mock()) - - -class TestDummyReply(object): - def test_simple(self): - reply = controller.DummyReply() - assert not reply.acked - reply.ack() - assert reply.acked + channel.ask("test", Mock(name="test_ask_shutdown")) class TestReply(object): def test_simple(self): reply = controller.Reply(42) - assert not reply.acked + assert reply.state == "unhandled" + + reply.handle() + assert reply.state == "handled" + reply.send("foo") - assert reply.acked + assert reply.value == "foo" + + reply.take() + assert reply.state == "taken" + + with tutils.raises(queue.Empty): + reply.q.get_nowait() + reply.commit() + assert reply.state == "committed" assert reply.q.get() == "foo" - def test_default(self): - reply = controller.Reply(42) + def test_kill(self): + reply = controller.Reply(43) + reply.handle() + reply.kill() + reply.take() + reply.commit() + assert reply.q.get() == Kill + + def test_ack(self): + reply = controller.Reply(44) + reply.handle() reply.ack() - assert reply.q.get() == 42 + reply.take() + reply.commit() + assert reply.q.get() == 44 def test_reply_none(self): - reply = controller.Reply(42) + reply = controller.Reply(45) + reply.handle() reply.send(None) + reply.take() + reply.commit() assert reply.q.get() is None + + def test_commit_no_reply(self): + reply = controller.Reply(46) + reply.handle() + reply.take() + with tutils.raises(ControlException): + reply.commit() + reply.ack() + reply.commit() + + def test_double_send(self): + reply = controller.Reply(47) + reply.handle() + reply.send(1) + with tutils.raises(ControlException): + reply.send(2) + reply.take() + reply.commit() + + def test_state_transitions(self): + states = {"unhandled", "handled", "taken", "committed"} + accept = { + "handle": {"unhandled"}, + "take": {"handled"}, + "commit": {"taken"}, + "ack": {"handled", "taken"}, + } + for fn, ok in accept.items(): + for state in states: + r = controller.Reply(48) + r._state = state + if fn == "commit": + r.value = 49 + if state in ok: + getattr(r, fn)() + else: + with tutils.raises(ControlException): + getattr(r, fn)() + r._state = "committed" # hide warnings on deletion + + def test_del(self): + reply = controller.Reply(47) + with tutils.raises(ControlException): + reply.__del__() + reply.handle() + reply.ack() + reply.take() + reply.commit() + + +class TestDummyReply(object): + def test_simple(self): + reply = controller.DummyReply() + for _ in range(2): + reply.handle() + reply.ack() + reply.take() + reply.commit() + reply.mark_reset() + reply.reset() + assert reply.state == "unhandled" + + def test_reset(self): + reply = controller.DummyReply() + reply.handle() + reply.ack() + reply.take() + reply.commit() + reply.mark_reset() + assert reply.state == "committed" + reply.reset() + assert reply.state == "unhandled" + + def test_del(self): + reply = controller.DummyReply() + reply.__del__() diff --git a/test/mitmproxy/test_examples.py b/test/mitmproxy/test_examples.py index 34fcc261..6c24ace5 100644 --- a/test/mitmproxy/test_examples.py +++ b/test/mitmproxy/test_examples.py @@ -39,7 +39,7 @@ class TestScripts(mastertest.MasterTest): def test_add_header(self): m, _ = tscript("add_header.py") f = tutils.tflow(resp=netutils.tresp()) - self.invoke(m, "response", f) + m.response(f) assert f.response.headers["newheader"] == "foo" def test_custom_contentviews(self): @@ -54,9 +54,9 @@ class TestScripts(mastertest.MasterTest): tscript("iframe_injector.py") m, sc = tscript("iframe_injector.py", "http://example.org/evil_iframe") - flow = tutils.tflow(resp=netutils.tresp(content=b"<html>mitmproxy</html>")) - self.invoke(m, "response", flow) - content = flow.response.content + f = tutils.tflow(resp=netutils.tresp(content=b"<html>mitmproxy</html>")) + m.response(f) + content = f.response.content assert b'iframe' in content and b'evil_iframe' in content def test_modify_form(self): @@ -64,23 +64,23 @@ class TestScripts(mastertest.MasterTest): form_header = Headers(content_type="application/x-www-form-urlencoded") f = tutils.tflow(req=netutils.treq(headers=form_header)) - self.invoke(m, "request", f) + m.request(f) assert f.request.urlencoded_form[b"mitmproxy"] == b"rocks" f.request.headers["content-type"] = "" - self.invoke(m, "request", f) + m.request(f) assert list(f.request.urlencoded_form.items()) == [(b"foo", b"bar")] def test_modify_querystring(self): m, sc = tscript("modify_querystring.py") f = tutils.tflow(req=netutils.treq(path="/search?q=term")) - self.invoke(m, "request", f) + m.request(f) assert f.request.query["mitmproxy"] == "rocks" f.request.path = "/" - self.invoke(m, "request", f) + m.request(f) assert f.request.query["mitmproxy"] == "rocks" def test_modify_response_body(self): @@ -89,13 +89,13 @@ class TestScripts(mastertest.MasterTest): m, sc = tscript("modify_response_body.py", "mitmproxy rocks") f = tutils.tflow(resp=netutils.tresp(content=b"I <3 mitmproxy")) - self.invoke(m, "response", f) + m.response(f) assert f.response.content == b"I <3 rocks" def test_redirect_requests(self): m, sc = tscript("redirect_requests.py") f = tutils.tflow(req=netutils.treq(host="example.org")) - self.invoke(m, "request", f) + m.request(f) assert f.request.host == "mitmproxy.org" def test_har_extractor(self): @@ -119,7 +119,7 @@ class TestScripts(mastertest.MasterTest): req=netutils.treq(**times), resp=netutils.tresp(**times) ) - self.invoke(m, "response", f) + m.response(f) m.addons.remove(sc) with open(path, "rb") as f: diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index d4bf764c..1caeb100 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -3,9 +3,9 @@ import io import netlib.utils from netlib.http import Headers -from mitmproxy import filt, controller, flow, options +from mitmproxy import filt, flow, options from mitmproxy.contrib import tnetstring -from mitmproxy.exceptions import FlowReadException +from mitmproxy.exceptions import FlowReadException, Kill from mitmproxy.models import Error from mitmproxy.models import Flow from mitmproxy.models import HTTPFlow @@ -372,19 +372,23 @@ class TestHTTPFlow(object): assert f.get_state() == f2.get_state() def test_kill(self): - s = flow.State() - fm = flow.FlowMaster(None, None, s) + fm = mock.Mock() f = tutils.tflow() - f.intercept(mock.Mock()) + f.reply.handle() + f.intercept(fm) + assert fm.handle_intercept.called + assert f.killable f.kill(fm) - for i in s.view: - assert "killed" in str(i.error) + assert not f.killable + assert fm.error.called + assert f.reply.value == Kill def test_killall(self): s = flow.State() fm = flow.FlowMaster(None, None, s) f = tutils.tflow() + f.reply.handle() f.intercept(fm) s.killall(fm) @@ -393,11 +397,11 @@ class TestHTTPFlow(object): def test_accept_intercept(self): f = tutils.tflow() - + f.reply.handle() f.intercept(mock.Mock()) - assert not f.reply.acked + assert f.reply.state == "taken" f.accept_intercept(mock.Mock()) - assert f.reply.acked + assert f.reply.state == "committed" def test_replace_unicode(self): f = tutils.tflow(resp=True) @@ -735,7 +739,6 @@ class TestFlowMaster: fm.clientdisconnect(f.client_conn) f.error = Error("msg") - f.error.reply = controller.DummyReply() fm.error(f) fm.shutdown() @@ -834,8 +837,8 @@ class TestFlowMaster: f = tutils.tflow() f.request.host = "nonexistent" - fm.process_new_request(f) - assert "killed" in f.error.msg + fm.request(f) + assert f.reply.value == Kill class TestRequest: |