diff options
Diffstat (limited to 'mitmproxy/test/test_flow.py')
| -rw-r--r-- | mitmproxy/test/test_flow.py | 1333 | 
1 files changed, 1333 insertions, 0 deletions
diff --git a/mitmproxy/test/test_flow.py b/mitmproxy/test/test_flow.py new file mode 100644 index 00000000..b122489f --- /dev/null +++ b/mitmproxy/test/test_flow.py @@ -0,0 +1,1333 @@ +import Queue +import time +import os.path +from cStringIO import StringIO +import email.utils + +import mock + +import netlib.utils +from netlib import odict +from netlib.http import CONTENT_MISSING, Headers +from libmproxy import filt, controller, tnetstring, flow +from libmproxy.models import Error +from libmproxy.models import Flow +from libmproxy.models import HTTPFlow +from libmproxy.models import HTTPRequest +from libmproxy.models import HTTPResponse +from libmproxy.proxy.config import HostMatcher +from libmproxy.proxy import ProxyConfig +from libmproxy.proxy.server import DummyServer +from libmproxy.models.connections import ClientConnection +from . import tutils + + +def test_app_registry(): +    ar = flow.AppRegistry() +    ar.add("foo", "domain", 80) + +    r = HTTPRequest.wrap(netlib.tutils.treq()) +    r.host = "domain" +    r.port = 80 +    assert ar.get(r) + +    r.port = 81 +    assert not ar.get(r) + +    r = HTTPRequest.wrap(netlib.tutils.treq()) +    r.host = "domain2" +    r.port = 80 +    assert not ar.get(r) +    r.headers["host"] = "domain" +    assert ar.get(r) + + +class TestStickyCookieState: + +    def _response(self, cookie, host): +        s = flow.StickyCookieState(filt.parse(".*")) +        f = tutils.tflow(req=netlib.tutils.treq(host=host, port=80), resp=True) +        f.response.headers["Set-Cookie"] = cookie +        s.handle_response(f) +        return s, f + +    def test_domain_match(self): +        s = flow.StickyCookieState(filt.parse(".*")) +        assert s.domain_match("www.google.com", ".google.com") +        assert s.domain_match("google.com", ".google.com") + +    def test_handle_response(self): +        c = "SSID=mooo; domain=.google.com, FOO=bar; Domain=.google.com; Path=/; "\ +            "Expires=Wed, 13-Jan-2021 22:23:01 GMT; Secure; " + +        s, f = self._response(c, "host") +        assert not s.jar.keys() + +        s, f = self._response(c, "www.google.com") +        assert s.jar.keys() + +        s, f = self._response("SSID=mooo", "www.google.com") +        assert s.jar.keys()[0] == ('www.google.com', 80, '/') + +    def test_handle_request(self): +        s, f = self._response("SSID=mooo", "www.google.com") +        assert "cookie" not in f.request.headers +        s.handle_request(f) +        assert "cookie" in f.request.headers + + +class TestStickyAuthState: + +    def test_handle_response(self): +        s = flow.StickyAuthState(filt.parse(".*")) +        f = tutils.tflow(resp=True) +        f.request.headers["authorization"] = "foo" +        s.handle_request(f) +        assert "address" in s.hosts + +        f = tutils.tflow(resp=True) +        s.handle_request(f) +        assert f.request.headers["authorization"] == "foo" + + +class TestClientPlaybackState: + +    def test_tick(self): +        first = tutils.tflow() +        s = flow.State() +        fm = flow.FlowMaster(None, s) +        fm.start_client_playback([first, tutils.tflow()], True) +        c = fm.client_playback +        c.testing = True + +        assert not c.done() +        assert not s.flow_count() +        assert c.count() == 2 +        c.tick(fm) +        assert s.flow_count() +        assert c.count() == 1 + +        c.tick(fm) +        assert c.count() == 1 + +        c.clear(c.current) +        c.tick(fm) +        assert c.count() == 0 +        c.clear(c.current) +        assert c.done() + +        q = Queue.Queue() +        fm.state.clear() +        fm.tick(q, timeout=0) + +        fm.stop_client_playback() +        assert not fm.client_playback + + +class TestServerPlaybackState: + +    def test_hash(self): +        s = flow.ServerPlaybackState( +            None, +            [], +            False, +            False, +            None, +            False, +            None, +            False) +        r = tutils.tflow() +        r2 = tutils.tflow() + +        assert s._hash(r) +        assert s._hash(r) == s._hash(r2) +        r.request.headers["foo"] = "bar" +        assert s._hash(r) == s._hash(r2) +        r.request.path = "voing" +        assert s._hash(r) != s._hash(r2) + +        r.request.path = "path?blank_value" +        r2.request.path = "path?" +        assert s._hash(r) != s._hash(r2) + +    def test_headers(self): +        s = flow.ServerPlaybackState( +            ["foo"], +            [], +            False, +            False, +            None, +            False, +            None, +            False) +        r = tutils.tflow(resp=True) +        r.request.headers["foo"] = "bar" +        r2 = tutils.tflow(resp=True) +        assert not s._hash(r) == s._hash(r2) +        r2.request.headers["foo"] = "bar" +        assert s._hash(r) == s._hash(r2) +        r2.request.headers["oink"] = "bar" +        assert s._hash(r) == s._hash(r2) + +        r = tutils.tflow(resp=True) +        r2 = tutils.tflow(resp=True) +        assert s._hash(r) == s._hash(r2) + +    def test_load(self): +        r = tutils.tflow(resp=True) +        r.request.headers["key"] = "one" + +        r2 = tutils.tflow(resp=True) +        r2.request.headers["key"] = "two" + +        s = flow.ServerPlaybackState( +            None, [ +                r, r2], False, False, None, False, None, False) +        assert s.count() == 2 +        assert len(s.fmap.keys()) == 1 + +        n = s.next_flow(r) +        assert n.request.headers["key"] == "one" +        assert s.count() == 1 + +        n = s.next_flow(r) +        assert n.request.headers["key"] == "two" +        assert s.count() == 0 + +        assert not s.next_flow(r) + +    def test_load_with_nopop(self): +        r = tutils.tflow(resp=True) +        r.request.headers["key"] = "one" + +        r2 = tutils.tflow(resp=True) +        r2.request.headers["key"] = "two" + +        s = flow.ServerPlaybackState( +            None, [ +                r, r2], False, True, None, False, None, False) + +        assert s.count() == 2 +        s.next_flow(r) +        assert s.count() == 2 + +    def test_ignore_params(self): +        s = flow.ServerPlaybackState( +            None, [], False, False, [ +                "param1", "param2"], False, None, False) +        r = tutils.tflow(resp=True) +        r.request.path = "/test?param1=1" +        r2 = tutils.tflow(resp=True) +        r2.request.path = "/test" +        assert s._hash(r) == s._hash(r2) +        r2.request.path = "/test?param1=2" +        assert s._hash(r) == s._hash(r2) +        r2.request.path = "/test?param2=1" +        assert s._hash(r) == s._hash(r2) +        r2.request.path = "/test?param3=2" +        assert not s._hash(r) == s._hash(r2) + +    def test_ignore_payload_params(self): +        s = flow.ServerPlaybackState( +            None, [], False, False, None, False, [ +                "param1", "param2"], False) +        r = tutils.tflow(resp=True) +        r.request.headers["Content-Type"] = "application/x-www-form-urlencoded" +        r.request.content = "paramx=x¶m1=1" +        r2 = tutils.tflow(resp=True) +        r2.request.headers["Content-Type"] = "application/x-www-form-urlencoded" +        r2.request.content = "paramx=x¶m1=1" +        # same parameters +        assert s._hash(r) == s._hash(r2) +        # ignored parameters != +        r2.request.content = "paramx=x¶m1=2" +        assert s._hash(r) == s._hash(r2) +        # missing parameter +        r2.request.content = "paramx=x" +        assert s._hash(r) == s._hash(r2) +        # ignorable parameter added +        r2.request.content = "paramx=x¶m1=2" +        assert s._hash(r) == s._hash(r2) +        # not ignorable parameter changed +        r2.request.content = "paramx=y¶m1=1" +        assert not s._hash(r) == s._hash(r2) +        # not ignorable parameter missing +        r2.request.content = "param1=1" +        assert not s._hash(r) == s._hash(r2) + +    def test_ignore_payload_params_other_content_type(self): +        s = flow.ServerPlaybackState( +            None, [], False, False, None, False, [ +                "param1", "param2"], False) +        r = tutils.tflow(resp=True) +        r.request.headers["Content-Type"] = "application/json" +        r.request.content = '{"param1":"1"}' +        r2 = tutils.tflow(resp=True) +        r2.request.headers["Content-Type"] = "application/json" +        r2.request.content = '{"param1":"1"}' +        # same content +        assert s._hash(r) == s._hash(r2) +        # distint content (note only x-www-form-urlencoded payload is analysed) +        r2.request.content = '{"param1":"2"}' +        assert not s._hash(r) == s._hash(r2) + +    def test_ignore_payload_wins_over_params(self): +        # NOTE: parameters are mutually exclusive in options +        s = flow.ServerPlaybackState( +            None, [], False, False, None, True, [ +                "param1", "param2"], False) +        r = tutils.tflow(resp=True) +        r.request.headers["Content-Type"] = "application/x-www-form-urlencoded" +        r.request.content = "paramx=y" +        r2 = tutils.tflow(resp=True) +        r2.request.headers["Content-Type"] = "application/x-www-form-urlencoded" +        r2.request.content = "paramx=x" +        # same parameters +        assert s._hash(r) == s._hash(r2) + +    def test_ignore_content(self): +        s = flow.ServerPlaybackState( +            None, +            [], +            False, +            False, +            None, +            False, +            None, +            False) +        r = tutils.tflow(resp=True) +        r2 = tutils.tflow(resp=True) + +        r.request.content = "foo" +        r2.request.content = "foo" +        assert s._hash(r) == s._hash(r2) +        r2.request.content = "bar" +        assert not s._hash(r) == s._hash(r2) + +        # now ignoring content +        s = flow.ServerPlaybackState( +            None, +            [], +            False, +            False, +            None, +            True, +            None, +            False) +        r = tutils.tflow(resp=True) +        r2 = tutils.tflow(resp=True) +        r.request.content = "foo" +        r2.request.content = "foo" +        assert s._hash(r) == s._hash(r2) +        r2.request.content = "bar" +        assert s._hash(r) == s._hash(r2) +        r2.request.content = "" +        assert s._hash(r) == s._hash(r2) +        r2.request.content = None +        assert s._hash(r) == s._hash(r2) + +    def test_ignore_host(self): +        s = flow.ServerPlaybackState( +            None, +            [], +            False, +            False, +            None, +            False, +            None, +            True) +        r = tutils.tflow(resp=True) +        r2 = tutils.tflow(resp=True) + +        r.request.host = "address" +        r2.request.host = "address" +        assert s._hash(r) == s._hash(r2) +        r2.request.host = "wrong_address" +        assert s._hash(r) == s._hash(r2) + + +class TestFlow(object): + +    def test_copy(self): +        f = tutils.tflow(resp=True) +        f.get_state() +        f2 = f.copy() +        a = f.get_state() +        b = f2.get_state() +        del a["id"] +        del b["id"] +        assert a == b +        assert not f == f2 +        assert not f is f2 +        assert f.request.get_state() == f2.request.get_state() +        assert not f.request is f2.request +        assert f.request.headers == f2.request.headers +        assert not f.request.headers is f2.request.headers +        assert f.response.get_state() == f2.response.get_state() +        assert not f.response is f2.response + +        f = tutils.tflow(err=True) +        f2 = f.copy() +        assert not f is f2 +        assert not f.request is f2.request +        assert f.request.headers == f2.request.headers +        assert not f.request.headers is f2.request.headers +        assert f.error.get_state() == f2.error.get_state() +        assert not f.error is f2.error + +    def test_match(self): +        f = tutils.tflow(resp=True) +        assert not f.match("~b test") +        assert f.match(None) +        assert not f.match("~b test") + +        f = tutils.tflow(err=True) +        assert f.match("~e") + +        tutils.raises(ValueError, f.match, "~") + +    def test_backup(self): +        f = tutils.tflow() +        f.response = HTTPResponse.wrap(netlib.tutils.tresp()) +        f.request.content = "foo" +        assert not f.modified() +        f.backup() +        f.request.content = "bar" +        assert f.modified() +        f.revert() +        assert f.request.content == "foo" + +    def test_backup_idempotence(self): +        f = tutils.tflow(resp=True) +        f.backup() +        f.revert() +        f.backup() +        f.revert() + +    def test_getset_state(self): +        f = tutils.tflow(resp=True) +        state = f.get_state() +        assert f.get_state() == HTTPFlow.from_state( +            state).get_state() + +        f.response = None +        f.error = Error("error") +        state = f.get_state() +        assert f.get_state() == HTTPFlow.from_state( +            state).get_state() + +        f2 = f.copy() +        f2.id = f.id  # copy creates a different uuid +        assert f.get_state() == f2.get_state() +        assert not f == f2 +        f2.error = Error("e2") +        assert not f == f2 +        f.set_state(f2.get_state()) +        assert f.get_state() == f2.get_state() + +    def test_kill(self): +        s = flow.State() +        fm = flow.FlowMaster(None, s) +        f = tutils.tflow() +        f.intercept(mock.Mock()) +        assert not f.reply.acked +        f.kill(fm) +        assert f.reply.acked + +    def test_killall(self): +        s = flow.State() +        fm = flow.FlowMaster(None, s) + +        f = tutils.tflow() +        fm.handle_request(f) + +        f = tutils.tflow() +        fm.handle_request(f) + +        for i in s.view: +            assert not i.reply.acked +        s.killall(fm) +        for i in s.view: +            assert i.reply.acked + +    def test_accept_intercept(self): +        f = tutils.tflow() + +        f.intercept(mock.Mock()) +        assert not f.reply.acked +        f.accept_intercept(mock.Mock()) +        assert f.reply.acked + +    def test_replace_unicode(self): +        f = tutils.tflow(resp=True) +        f.response.content = "\xc2foo" +        f.replace("foo", u"bar") + +    def test_replace_no_content(self): +        f = tutils.tflow() +        f.request.content = CONTENT_MISSING +        assert f.replace("foo", "bar") == 0 + +    def test_replace(self): +        f = tutils.tflow(resp=True) +        f.request.headers["foo"] = "foo" +        f.request.content = "afoob" + +        f.response.headers["foo"] = "foo" +        f.response.content = "afoob" + +        assert f.replace("foo", "bar") == 6 + +        assert f.request.headers["bar"] == "bar" +        assert f.request.content == "abarb" +        assert f.response.headers["bar"] == "bar" +        assert f.response.content == "abarb" + +    def test_replace_encoded(self): +        f = tutils.tflow(resp=True) +        f.request.content = "afoob" +        f.request.encode("gzip") +        f.response.content = "afoob" +        f.response.encode("gzip") + +        f.replace("foo", "bar") + +        assert f.request.content != "abarb" +        f.request.decode() +        assert f.request.content == "abarb" + +        assert f.response.content != "abarb" +        f.response.decode() +        assert f.response.content == "abarb" + + +class TestState: + +    def test_backup(self): +        c = flow.State() +        f = tutils.tflow() +        c.add_flow(f) +        f.backup() +        c.revert(f) + +    def test_flow(self): +        """ +            normal flow: + +                connect -> request -> response +        """ +        c = flow.State() +        f = tutils.tflow() +        c.add_flow(f) +        assert f +        assert c.flow_count() == 1 +        assert c.active_flow_count() == 1 + +        newf = tutils.tflow() +        assert c.add_flow(newf) +        assert c.active_flow_count() == 2 + +        f.response = HTTPResponse.wrap(netlib.tutils.tresp()) +        assert c.update_flow(f) +        assert c.flow_count() == 2 +        assert c.active_flow_count() == 1 + +        assert not c.update_flow(None) +        assert c.active_flow_count() == 1 + +        newf.response = HTTPResponse.wrap(netlib.tutils.tresp()) +        assert c.update_flow(newf) +        assert c.active_flow_count() == 0 + +    def test_err(self): +        c = flow.State() +        f = tutils.tflow() +        c.add_flow(f) +        f.error = Error("message") +        assert c.update_flow(f) + +        c = flow.State() +        f = tutils.tflow() +        c.add_flow(f) +        c.set_limit("~e") +        assert not c.view +        f.error = tutils.terr() +        assert c.update_flow(f) +        assert c.view + +    def test_set_limit(self): +        c = flow.State() + +        f = tutils.tflow() +        assert len(c.view) == 0 + +        c.add_flow(f) +        assert len(c.view) == 1 + +        c.set_limit("~s") +        assert c.limit_txt == "~s" +        assert len(c.view) == 0 +        f.response = HTTPResponse.wrap(netlib.tutils.tresp()) +        c.update_flow(f) +        assert len(c.view) == 1 +        c.set_limit(None) +        assert len(c.view) == 1 + +        f = tutils.tflow() +        c.add_flow(f) +        assert len(c.view) == 2 +        c.set_limit("~q") +        assert len(c.view) == 1 +        c.set_limit("~s") +        assert len(c.view) == 1 + +        assert "Invalid" in c.set_limit("~") + +    def test_set_intercept(self): +        c = flow.State() +        assert not c.set_intercept("~q") +        assert c.intercept_txt == "~q" +        assert "Invalid" in c.set_intercept("~") +        assert not c.set_intercept(None) +        assert c.intercept_txt is None + +    def _add_request(self, state): +        f = tutils.tflow() +        state.add_flow(f) +        return f + +    def _add_response(self, state): +        f = tutils.tflow() +        state.add_flow(f) +        f.response = HTTPResponse.wrap(netlib.tutils.tresp()) +        state.update_flow(f) + +    def _add_error(self, state): +        f = tutils.tflow(err=True) +        state.add_flow(f) + +    def test_clear(self): +        c = flow.State() +        f = self._add_request(c) +        f.intercepted = True + +        c.clear() +        assert c.flow_count() == 0 + +    def test_dump_flows(self): +        c = flow.State() +        self._add_request(c) +        self._add_response(c) +        self._add_request(c) +        self._add_response(c) +        self._add_request(c) +        self._add_response(c) +        self._add_error(c) + +        flows = c.view[:] +        c.clear() + +        c.load_flows(flows) +        assert isinstance(c.flows[0], Flow) + +    def test_accept_all(self): +        c = flow.State() +        self._add_request(c) +        self._add_response(c) +        self._add_request(c) +        c.accept_all(mock.Mock()) + + +class TestSerialize: + +    def _treader(self): +        sio = StringIO() +        w = flow.FlowWriter(sio) +        for i in range(3): +            f = tutils.tflow(resp=True) +            w.add(f) +        for i in range(3): +            f = tutils.tflow(err=True) +            w.add(f) + +        sio.seek(0) +        return flow.FlowReader(sio) + +    def test_roundtrip(self): +        sio = StringIO() +        f = tutils.tflow() +        f.request.content = "".join(chr(i) for i in range(255)) +        w = flow.FlowWriter(sio) +        w.add(f) + +        sio.seek(0) +        r = flow.FlowReader(sio) +        l = list(r.stream()) +        assert len(l) == 1 + +        f2 = l[0] +        assert f2.get_state() == f.get_state() +        assert f2.request == f.request + +    def test_load_flows(self): +        r = self._treader() +        s = flow.State() +        fm = flow.FlowMaster(None, s) +        fm.load_flows(r) +        assert len(s.flows) == 6 + +    def test_load_flows_reverse(self): +        r = self._treader() +        s = flow.State() +        conf = ProxyConfig( +            mode="reverse", +            upstream_server=("https", ("use-this-domain", 80)) +        ) +        fm = flow.FlowMaster(DummyServer(conf), s) +        fm.load_flows(r) +        assert s.flows[0].request.host == "use-this-domain" + +    def test_filter(self): +        sio = StringIO() +        fl = filt.parse("~c 200") +        w = flow.FilteredFlowWriter(sio, fl) + +        f = tutils.tflow(resp=True) +        f.response.status_code = 200 +        w.add(f) + +        f = tutils.tflow(resp=True) +        f.response.status_code = 201 +        w.add(f) + +        sio.seek(0) +        r = flow.FlowReader(sio) +        assert len(list(r.stream())) + +    def test_error(self): +        sio = StringIO() +        sio.write("bogus") +        sio.seek(0) +        r = flow.FlowReader(sio) +        tutils.raises(flow.FlowReadError, list, r.stream()) + +        f = flow.FlowReadError("foo") +        assert f.strerror == "foo" + +    def test_versioncheck(self): +        f = tutils.tflow() +        d = f.get_state() +        d["version"] = (0, 0) +        sio = StringIO() +        tnetstring.dump(d, sio) +        sio.seek(0) + +        r = flow.FlowReader(sio) +        tutils.raises("version", list, r.stream()) + + +class TestFlowMaster: + +    def test_load_script(self): +        s = flow.State() +        fm = flow.FlowMaster(None, s) +        assert not fm.load_script(tutils.test_data.path("scripts/a.py")) +        assert not fm.load_script(tutils.test_data.path("scripts/a.py")) +        assert not fm.unload_scripts() +        assert fm.load_script("nonexistent") +        assert "ValueError" in fm.load_script( +            tutils.test_data.path("scripts/starterr.py")) +        assert len(fm.scripts) == 0 + +    def test_getset_ignore(self): +        p = mock.Mock() +        p.config.check_ignore = HostMatcher() +        fm = flow.FlowMaster(p, flow.State()) +        assert not fm.get_ignore_filter() +        fm.set_ignore_filter(["^apple\.com:", ":443$"]) +        assert fm.get_ignore_filter() + +    def test_replay(self): +        s = flow.State() +        fm = flow.FlowMaster(None, s) +        f = tutils.tflow(resp=True) +        f.request.content = CONTENT_MISSING +        assert "missing" in fm.replay_request(f) + +        f.intercepted = True +        assert "intercepting" in fm.replay_request(f) + +        f.live = True +        assert "live" in fm.replay_request(f, run_scripthooks=True) + +    def test_script_reqerr(self): +        s = flow.State() +        fm = flow.FlowMaster(None, s) +        assert not fm.load_script(tutils.test_data.path("scripts/reqerr.py")) +        f = tutils.tflow() +        fm.handle_clientconnect(f.client_conn) +        assert fm.handle_request(f) + +    def test_script(self): +        s = flow.State() +        fm = flow.FlowMaster(None, s) +        assert not fm.load_script(tutils.test_data.path("scripts/all.py")) +        f = tutils.tflow(resp=True) + +        fm.handle_clientconnect(f.client_conn) +        assert fm.scripts[0].ns["log"][-1] == "clientconnect" +        fm.handle_serverconnect(f.server_conn) +        assert fm.scripts[0].ns["log"][-1] == "serverconnect" +        fm.handle_request(f) +        assert fm.scripts[0].ns["log"][-1] == "request" +        fm.handle_response(f) +        assert fm.scripts[0].ns["log"][-1] == "response" +        # load second script +        assert not fm.load_script(tutils.test_data.path("scripts/all.py")) +        assert len(fm.scripts) == 2 +        fm.handle_clientdisconnect(f.server_conn) +        assert fm.scripts[0].ns["log"][-1] == "clientdisconnect" +        assert fm.scripts[1].ns["log"][-1] == "clientdisconnect" + +        # unload first script +        fm.unload_scripts() +        assert len(fm.scripts) == 0 +        assert not fm.load_script(tutils.test_data.path("scripts/all.py")) + +        f.error = tutils.terr() +        fm.handle_error(f) +        assert fm.scripts[0].ns["log"][-1] == "error" + +    def test_duplicate_flow(self): +        s = flow.State() +        fm = flow.FlowMaster(None, s) +        f = tutils.tflow(resp=True) +        f = fm.load_flow(f) +        assert s.flow_count() == 1 +        f2 = fm.duplicate_flow(f) +        assert f2.response +        assert s.flow_count() == 2 +        assert s.index(f2) == 1 + +    def test_all(self): +        s = flow.State() +        fm = flow.FlowMaster(None, s) +        fm.anticache = True +        fm.anticomp = True +        f = tutils.tflow(req=None) +        fm.handle_clientconnect(f.client_conn) +        f.request = HTTPRequest.wrap(netlib.tutils.treq()) +        fm.handle_request(f) +        assert s.flow_count() == 1 + +        f.response = HTTPResponse.wrap(netlib.tutils.tresp()) +        fm.handle_response(f) +        assert not fm.handle_response(None) +        assert s.flow_count() == 1 + +        fm.handle_clientdisconnect(f.client_conn) + +        f.error = Error("msg") +        f.error.reply = controller.DummyReply() +        fm.handle_error(f) + +        fm.load_script(tutils.test_data.path("scripts/a.py")) +        fm.shutdown() + +    def test_client_playback(self): +        s = flow.State() + +        f = tutils.tflow(resp=True) +        pb = [tutils.tflow(resp=True), f] +        fm = flow.FlowMaster(DummyServer(ProxyConfig()), s) +        assert not fm.start_server_playback( +            pb, +            False, +            [], +            False, +            False, +            None, +            False, +            None, +            False) +        assert not fm.start_client_playback(pb, False) +        fm.client_playback.testing = True + +        q = Queue.Queue() +        assert not fm.state.flow_count() +        fm.tick(q, 0) +        assert fm.state.flow_count() + +        f.error = Error("error") +        fm.handle_error(f) + +    def test_server_playback(self): +        s = flow.State() + +        f = tutils.tflow() +        f.response = HTTPResponse.wrap(netlib.tutils.tresp(content=f.request)) +        pb = [f] + +        fm = flow.FlowMaster(None, s) +        fm.refresh_server_playback = True +        assert not fm.do_server_playback(tutils.tflow()) + +        fm.start_server_playback( +            pb, +            False, +            [], +            False, +            False, +            None, +            False, +            None, +            False) +        assert fm.do_server_playback(tutils.tflow()) + +        fm.start_server_playback( +            pb, +            False, +            [], +            True, +            False, +            None, +            False, +            None, +            False) +        r = tutils.tflow() +        r.request.content = "gibble" +        assert not fm.do_server_playback(r) +        assert fm.do_server_playback(tutils.tflow()) + +        fm.start_server_playback( +            pb, +            False, +            [], +            True, +            False, +            None, +            False, +            None, +            False) +        q = Queue.Queue() +        fm.tick(q, 0) +        assert fm.should_exit.is_set() + +        fm.stop_server_playback() +        assert not fm.server_playback + +    def test_server_playback_kill(self): +        s = flow.State() +        f = tutils.tflow() +        f.response = HTTPResponse.wrap(netlib.tutils.tresp(content=f.request)) +        pb = [f] +        fm = flow.FlowMaster(None, s) +        fm.refresh_server_playback = True +        fm.start_server_playback( +            pb, +            True, +            [], +            False, +            False, +            None, +            False, +            None, +            False) + +        f = tutils.tflow() +        f.request.host = "nonexistent" +        fm.process_new_request(f) +        assert "killed" in f.error.msg + +    def test_stickycookie(self): +        s = flow.State() +        fm = flow.FlowMaster(None, s) +        assert "Invalid" in fm.set_stickycookie("~h") +        fm.set_stickycookie(".*") +        assert fm.stickycookie_state +        fm.set_stickycookie(None) +        assert not fm.stickycookie_state + +        fm.set_stickycookie(".*") +        f = tutils.tflow(resp=True) +        f.response.headers["set-cookie"] = "foo=bar" +        fm.handle_request(f) +        fm.handle_response(f) +        assert fm.stickycookie_state.jar +        assert not "cookie" in f.request.headers +        f = f.copy() +        fm.handle_request(f) +        assert f.request.headers["cookie"] == "foo=bar" + +    def test_stickyauth(self): +        s = flow.State() +        fm = flow.FlowMaster(None, s) +        assert "Invalid" in fm.set_stickyauth("~h") +        fm.set_stickyauth(".*") +        assert fm.stickyauth_state +        fm.set_stickyauth(None) +        assert not fm.stickyauth_state + +        fm.set_stickyauth(".*") +        f = tutils.tflow(resp=True) +        f.request.headers["authorization"] = "foo" +        fm.handle_request(f) + +        f = tutils.tflow(resp=True) +        assert fm.stickyauth_state.hosts +        assert not "authorization" in f.request.headers +        fm.handle_request(f) +        assert f.request.headers["authorization"] == "foo" + +    def test_stream(self): +        with tutils.tmpdir() as tdir: +            p = os.path.join(tdir, "foo") + +            def r(): +                r = flow.FlowReader(open(p, "rb")) +                return list(r.stream()) + +            s = flow.State() +            fm = flow.FlowMaster(None, s) +            f = tutils.tflow(resp=True) + +            fm.start_stream(file(p, "ab"), None) +            fm.handle_request(f) +            fm.handle_response(f) +            fm.stop_stream() + +            assert r()[0].response + +            f = tutils.tflow() +            fm.start_stream(file(p, "ab"), None) +            fm.handle_request(f) +            fm.shutdown() + +            assert not r()[1].response + + +class TestRequest: + +    def test_simple(self): +        f = tutils.tflow() +        r = f.request +        u = r.url +        r.url = u +        tutils.raises(ValueError, setattr, r, "url", "") +        assert r.url == u +        r2 = r.copy() +        assert r.get_state() == r2.get_state() + +    def test_get_url(self): +        r = HTTPRequest.wrap(netlib.tutils.treq()) + +        assert r.url == "http://address:22/path" + +        r.scheme = "https" +        assert r.url == "https://address:22/path" + +        r.host = "host" +        r.port = 42 +        assert r.url == "https://host:42/path" + +        r.host = "address" +        r.port = 22 +        assert r.url == "https://address:22/path" + +        assert r.pretty_url == "https://address:22/path" +        r.headers["Host"] = "foo.com" +        assert r.url == "https://address:22/path" +        assert r.pretty_url == "https://foo.com:22/path" + +    def test_path_components(self): +        r = HTTPRequest.wrap(netlib.tutils.treq()) +        r.path = "/" +        assert r.get_path_components() == [] +        r.path = "/foo/bar" +        assert r.get_path_components() == ["foo", "bar"] +        q = odict.ODict() +        q["test"] = ["123"] +        r.set_query(q) +        assert r.get_path_components() == ["foo", "bar"] + +        r.set_path_components([]) +        assert r.get_path_components() == [] +        r.set_path_components(["foo"]) +        assert r.get_path_components() == ["foo"] +        r.set_path_components(["/oo"]) +        assert r.get_path_components() == ["/oo"] +        assert "%2F" in r.path + +    def test_getset_form_urlencoded(self): +        d = odict.ODict([("one", "two"), ("three", "four")]) +        r = HTTPRequest.wrap(netlib.tutils.treq(content=netlib.utils.urlencode(d.lst))) +        r.headers["content-type"] = "application/x-www-form-urlencoded" +        assert r.get_form_urlencoded() == d + +        d = odict.ODict([("x", "y")]) +        r.set_form_urlencoded(d) +        assert r.get_form_urlencoded() == d + +        r.headers["content-type"] = "foo" +        assert not r.get_form_urlencoded() + +    def test_getset_query(self): +        r = HTTPRequest.wrap(netlib.tutils.treq()) +        r.path = "/foo?x=y&a=b" +        q = r.get_query() +        assert q.lst == [("x", "y"), ("a", "b")] + +        r.path = "/" +        q = r.get_query() +        assert not q + +        r.path = "/?adsfa" +        q = r.get_query() +        assert q.lst == [("adsfa", "")] + +        r.path = "/foo?x=y&a=b" +        assert r.get_query() +        r.set_query(odict.ODict([])) +        assert not r.get_query() +        qv = odict.ODict([("a", "b"), ("c", "d")]) +        r.set_query(qv) +        assert r.get_query() == qv + +    def test_anticache(self): +        r = HTTPRequest.wrap(netlib.tutils.treq()) +        r.headers = Headers() +        r.headers["if-modified-since"] = "test" +        r.headers["if-none-match"] = "test" +        r.anticache() +        assert not "if-modified-since" in r.headers +        assert not "if-none-match" in r.headers + +    def test_replace(self): +        r = HTTPRequest.wrap(netlib.tutils.treq()) +        r.path = "path/foo" +        r.headers["Foo"] = "fOo" +        r.content = "afoob" +        assert r.replace("foo(?i)", "boo") == 4 +        assert r.path == "path/boo" +        assert not "foo" in r.content +        assert r.headers["boo"] == "boo" + +    def test_constrain_encoding(self): +        r = HTTPRequest.wrap(netlib.tutils.treq()) +        r.headers["accept-encoding"] = "gzip, oink" +        r.constrain_encoding() +        assert "oink" not in r.headers["accept-encoding"] + +        r.headers.set_all("accept-encoding", ["gzip", "oink"]) +        r.constrain_encoding() +        assert "oink" not in r.headers["accept-encoding"] + +    def test_get_decoded_content(self): +        r = HTTPRequest.wrap(netlib.tutils.treq()) +        r.content = None +        r.headers["content-encoding"] = "identity" +        assert r.get_decoded_content() is None + +        r.content = "falafel" +        r.encode("gzip") +        assert r.get_decoded_content() == "falafel" + +    def test_get_content_type(self): +        resp = HTTPResponse.wrap(netlib.tutils.tresp()) +        resp.headers = Headers(content_type="text/plain") +        assert resp.headers["content-type"] == "text/plain" + + +class TestResponse: + +    def test_simple(self): +        f = tutils.tflow(resp=True) +        resp = f.response +        resp2 = resp.copy() +        assert resp2.get_state() == resp.get_state() + +    def test_refresh(self): +        r = HTTPResponse.wrap(netlib.tutils.tresp()) +        n = time.time() +        r.headers["date"] = email.utils.formatdate(n) +        pre = r.headers["date"] +        r.refresh(n) +        assert pre == r.headers["date"] +        r.refresh(n + 60) + +        d = email.utils.parsedate_tz(r.headers["date"]) +        d = email.utils.mktime_tz(d) +        # Weird that this is not exact... +        assert abs(60 - (d - n)) <= 1 + +        r.headers["set-cookie"] = "MOO=BAR; Expires=Tue, 08-Mar-2011 00:20:38 GMT; Path=foo.com; Secure" +        r.refresh() + +    def test_refresh_cookie(self): +        r = HTTPResponse.wrap(netlib.tutils.tresp()) + +        # Invalid expires format, sent to us by Reddit. +        c = "rfoo=bar; Domain=reddit.com; expires=Thu, 31 Dec 2037 23:59:59 GMT; Path=/" +        assert r._refresh_cookie(c, 60) + +        c = "MOO=BAR; Expires=Tue, 08-Mar-2011 00:20:38 GMT; Path=foo.com; Secure" +        assert "00:21:38" in r._refresh_cookie(c, 60) + +        # https://github.com/mitmproxy/mitmproxy/issues/773 +        c = ">=A" +        with tutils.raises(ValueError): +            r._refresh_cookie(c, 60) + +    def test_replace(self): +        r = HTTPResponse.wrap(netlib.tutils.tresp()) +        r.headers["Foo"] = "fOo" +        r.content = "afoob" +        assert r.replace("foo(?i)", "boo") == 3 +        assert not "foo" in r.content +        assert r.headers["boo"] == "boo" + +    def test_get_content_type(self): +        resp = HTTPResponse.wrap(netlib.tutils.tresp()) +        resp.headers = Headers(content_type="text/plain") +        assert resp.headers["content-type"] == "text/plain" + + +class TestError: + +    def test_getset_state(self): +        e = Error("Error") +        state = e.get_state() +        assert Error.from_state(state).get_state() == e.get_state() + +        assert e.copy() + +        e2 = Error("bar") +        assert not e == e2 +        e.set_state(e2.get_state()) +        assert e.get_state() == e2.get_state() + +        e3 = e.copy() +        assert e3.get_state() == e.get_state() + + +class TestClientConnection: + +    def test_state(self): + +        c = tutils.tclient_conn() +        assert ClientConnection.from_state(c.get_state()).get_state() ==\ +            c.get_state() + +        c2 = tutils.tclient_conn() +        c2.address.address = (c2.address.host, 4242) +        assert not c == c2 + +        c2.timestamp_start = 42 +        c.set_state(c2.get_state()) +        assert c.timestamp_start == 42 + +        c3 = c.copy() +        assert c3.get_state() == c.get_state() + +        assert str(c) + + +def test_replacehooks(): +    h = flow.ReplaceHooks() +    h.add("~q", "foo", "bar") +    assert h.lst + +    h.set( +        [ +            (".*", "one", "two"), +            (".*", "three", "four"), +        ] +    ) +    assert h.count() == 2 + +    h.clear() +    assert not h.lst + +    h.add("~q", "foo", "bar") +    h.add("~s", "foo", "bar") + +    v = h.get_specs() +    assert v == [('~q', 'foo', 'bar'), ('~s', 'foo', 'bar')] +    assert h.count() == 2 +    h.clear() +    assert h.count() == 0 + +    f = tutils.tflow() +    f.request.content = "foo" +    h.add("~s", "foo", "bar") +    h.run(f) +    assert f.request.content == "foo" + +    f = tutils.tflow(resp=True) +    f.request.content = "foo" +    f.response.content = "foo" +    h.run(f) +    assert f.response.content == "bar" +    assert f.request.content == "foo" + +    f = tutils.tflow() +    h.clear() +    h.add("~q", "foo", "bar") +    f.request.content = "foo" +    h.run(f) +    assert f.request.content == "bar" + +    assert not h.add("~", "foo", "bar") +    assert not h.add("foo", "*", "bar") + + +def test_setheaders(): +    h = flow.SetHeaders() +    h.add("~q", "foo", "bar") +    assert h.lst + +    h.set( +        [ +            (".*", "one", "two"), +            (".*", "three", "four"), +        ] +    ) +    assert h.count() == 2 + +    h.clear() +    assert not h.lst + +    h.add("~q", "foo", "bar") +    h.add("~s", "foo", "bar") + +    v = h.get_specs() +    assert v == [('~q', 'foo', 'bar'), ('~s', 'foo', 'bar')] +    assert h.count() == 2 +    h.clear() +    assert h.count() == 0 + +    f = tutils.tflow() +    f.request.content = "foo" +    h.add("~s", "foo", "bar") +    h.run(f) +    assert f.request.content == "foo" + +    h.clear() +    h.add("~s", "one", "two") +    h.add("~s", "one", "three") +    f = tutils.tflow(resp=True) +    f.request.headers["one"] = "xxx" +    f.response.headers["one"] = "xxx" +    h.run(f) +    assert f.request.headers["one"] == "xxx" +    assert f.response.headers.get_all("one") == ["two", "three"] + +    h.clear() +    h.add("~q", "one", "two") +    h.add("~q", "one", "three") +    f = tutils.tflow() +    f.request.headers["one"] = "xxx" +    h.run(f) +    assert f.request.headers.get_all("one") == ["two", "three"] + +    assert not h.add("~", "foo", "bar")  | 
