aboutsummaryrefslogtreecommitdiffstats
path: root/test/mitmproxy/test_controller.py
blob: 6d4b8fe637fe7eb96591935a2d30f0b3cedf31cb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from threading import Thread, Event

from mock import Mock

from mitmproxy import controller
from six.moves import queue

from mitmproxy.exceptions import Kill
from mitmproxy.proxy import DummyServer
from netlib.tutils import raises


class TMsg:
    """Bare message stand-in; tests attach attributes such as ``reply`` to it."""


class TestMaster(object):
    """Tests for ``controller.Master``'s event loop and server lifecycle."""

    def test_simple(self):
        """A queued "log" event reaches its handler, which stops ``run()``."""
        class DummyMaster(controller.Master):
            @controller.handler
            def log(self, _):
                # Use ``self`` instead of closing over the outer ``m``: the
                # closure only worked because ``m`` happened to be bound
                # before run() dispatched the event.
                self.should_exit.set()

            def tick(self, timeout):
                # Ignore the requested timeout and poll with 0 to keep the
                # test fast.
                super(DummyMaster, self).tick(0)

        m = DummyMaster(None)
        assert not m.should_exit.is_set()
        msg = TMsg()
        # Queued messages carry a reply object; presumably the
        # @controller.handler wrapper relies on it (TODO confirm).
        msg.reply = controller.DummyReply()
        m.event_queue.put(("log", msg))
        m.run()
        assert m.should_exit.is_set()

    def test_server_simple(self):
        """A master with an attached server can be started and stopped twice."""
        m = controller.Master(None)
        s = DummyServer(None)
        m.add_server(s)
        m.start()
        m.shutdown()
        # Second cycle: restarting after shutdown must not fail.
        m.start()
        m.shutdown()


class TestServerThread(object):
    """Tests for ``controller.ServerThread``."""

    def test_simple(self):
        """Running the thread body invokes the server's ``serve_forever``."""
        server = Mock()
        controller.ServerThread(server).run()
        assert server.serve_forever.called


class TestChannel(object):
    """Tests for ``controller.Channel``, the queue-based message channel."""

    def test_tell(self):
        """tell() enqueues ``(mtype, message)`` and attaches a reply."""
        q = queue.Queue()
        channel = controller.Channel(q, Event())
        m = Mock()
        channel.tell("test", m)
        assert q.get() == ("test", m)
        # A bare ``assert m.reply`` is vacuous: Mock auto-creates any
        # attribute as a truthy child Mock. Check that tell() actually
        # assigned a DummyReply.
        assert isinstance(m.reply, controller.DummyReply)

    def test_ask_simple(self):
        """ask() waits for the consumer's reply and returns the sent value."""
        q = queue.Queue()
        received = []

        def reply():
            mtype, obj = q.get()
            # Record instead of asserting here: an AssertionError raised in
            # a worker thread is not reported as a test failure.
            received.append(mtype)
            obj.reply.send(42)

        worker = Thread(target=reply)
        # Daemon so a stuck consumer cannot keep the interpreter alive.
        worker.daemon = True
        worker.start()

        channel = controller.Channel(q, Event())
        assert channel.ask("test", Mock()) == 42
        worker.join(timeout=5)
        assert received == ["test"]

    def test_ask_shutdown(self):
        """ask() raises Kill when the should-exit event is already set."""
        q = queue.Queue()
        done = Event()
        done.set()
        channel = controller.Channel(q, done)
        with raises(Kill):
            channel.ask("test", Mock())


class TestDummyReply(object):
    """Tests for ``controller.DummyReply``."""

    def test_simple(self):
        """A fresh DummyReply is not acked until ``ack()`` is called."""
        dummy = controller.DummyReply()
        assert not dummy.acked
        dummy.ack()
        assert dummy.acked


class TestReply(object):
    """Tests for ``controller.Reply``, whose results are delivered via ``.q``."""

    def test_simple(self):
        """send() marks the reply acked and enqueues the sent value."""
        r = controller.Reply(42)
        assert not r.acked
        r.send("foo")
        assert r.acked
        delivered = r.q.get()
        assert delivered == "foo"

    def test_default(self):
        """ack() enqueues the object the Reply was constructed with."""
        r = controller.Reply(42)
        r.ack()
        delivered = r.q.get()
        assert delivered == 42

    def test_reply_none(self):
        """Explicitly sending None enqueues None, not the default object."""
        r = controller.Reply(42)
        r.send(None)
        delivered = r.q.get()
        assert delivered is None