aboutsummaryrefslogtreecommitdiffstats
path: root/test/netlib/websockets/test_frame.py
blob: 3b7c9ed4a53431375eb61fb5d446054b2b8fbea8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import os
import codecs
import pytest

from netlib import websockets
from mitmproxy.test import tutils


class TestFrameHeader:
    """Tests for constructing, serializing and parsing ``websockets.FrameHeader``."""

    @pytest.mark.parametrize("input,expected", [
        # Length carried directly in the second header byte.
        (0, '0100'),
        (125, '017D'),
        # Second byte 0x7E followed by a two-byte length.
        (126, '017E007E'),
        (127, '017E007F'),
        (142, '017E008E'),
        (65534, '017EFFFE'),
        (65535, '017EFFFF'),
        # Second byte 0x7F followed by an eight-byte length.
        (65536, '017F0000000000010000'),
        (8589934591, '017F00000001FFFFFFFF'),
        (2 ** 64 - 1, '017FFFFFFFFFFFFFFFFF'),
    ])
    def test_serialization_length(self, input, expected):
        """Serialize a TEXT header with payload length ``input``; compare with hex ``expected``."""
        header = websockets.FrameHeader(
            payload_length=input,
            opcode=websockets.OPCODE.TEXT,
        )
        serialized = bytes(header)
        assert serialized == codecs.decode(expected, 'hex')

    def test_serialization_too_large(self):
        """Serializing a header whose payload length is 2**64 + 1 raises ValueError."""
        oversized = websockets.FrameHeader(payload_length=2 ** 64 + 1)
        with pytest.raises(ValueError):
            bytes(oversized)

    @pytest.mark.parametrize("input,expected", [
        # Length carried directly in the second header byte.
        ('0100', 0),
        ('017D', 125),
        # Second byte 0x7E followed by a two-byte length.
        ('017E007E', 126),
        ('017E007F', 127),
        ('017E008E', 142),
        ('017EFFFE', 65534),
        ('017EFFFF', 65535),
        # Second byte 0x7F followed by an eight-byte length.
        ('017F0000000000010000', 65536),
        ('017F00000001FFFFFFFF', 8589934591),
        ('017FFFFFFFFFFFFFFFFF', 2 ** 64 - 1),
    ])
    def test_deserialization_length(self, input, expected):
        """Parse the hex-encoded header ``input``; its payload length must equal ``expected``."""
        raw = codecs.decode(input, 'hex')
        header = websockets.FrameHeader.from_file(tutils.treader(raw))
        assert header.payload_length == expected

    @pytest.mark.parametrize("input,expected", [
        ('0100', (False, None)),
        ('018000000000', (True, '00000000')),
        ('018012345678', (True, '12345678')),
    ])
    def test_deserialization_masking(self, input, expected):
        """Parse the hex-encoded header ``input`` and check its mask flag and key.

        ``expected`` is a ``(mask, key_hex)`` pair; the key is only compared
        when the parsed header reports that it is masked.
        """
        mask_flag, key_hex = expected
        raw = codecs.decode(input, 'hex')
        header = websockets.FrameHeader.from_file(tutils.treader(raw))
        assert header.mask == mask_flag
        if header.mask:
            assert header.masking_key == codecs.decode(key_hex, 'hex')

    def test_equality(self):
        """Equal arguments compare equal; differing ``fin`` or a foreign type do not."""
        first = websockets.FrameHeader(mask=True, masking_key=b'1234')
        second = websockets.FrameHeader(mask=True, masking_key=b'1234')
        assert first == second

        with_fin = websockets.FrameHeader(fin=True)
        without_fin = websockets.FrameHeader(fin=False)
        assert with_fin != without_fin

        # Comparison against an unrelated type must not report equality.
        assert with_fin != 'foobar'

    def test_roundtrip(self):
        """Headers survive a serialize-then-parse cycle for a range of settings."""
        def check(**kwargs):
            # Build a header, push it through bytes() and from_file(), compare.
            original = websockets.FrameHeader(**kwargs)
            reader = tutils.treader(bytes(original))
            reparsed = websockets.FrameHeader.from_file(reader)
            assert original == reparsed

        cases = (
            {},
            {"fin": True},
            {"rsv1": True},
            {"rsv2": True},
            {"rsv3": True},
            {"payload_length": 1},
            {"payload_length": 100},
            {"payload_length": 1000},
            {"payload_length": 10000},
            {"opcode": websockets.OPCODE.PING},
            {"masking_key": b"test"},
        )
        for kwargs in cases:
            check(**kwargs)

    def test_human_readable(self):
        """``repr`` yields a non-empty string for populated and default headers."""
        populated = websockets.FrameHeader(
            payload_length=10,
            fin=True,
            masking_key=b"test",
        )
        assert repr(populated)

        default = websockets.FrameHeader()
        assert repr(default)

    def test_funky(self):
        """A key supplied together with ``mask=False`` parses back as unmasked."""
        header = websockets.FrameHeader(mask=False, masking_key=b"test")
        reader = tutils.treader(bytes(header))
        reparsed = websockets.FrameHeader.from_file(reader)
        assert not reparsed.mask

    def test_violations(self):
        """Constructing a header with opcode 17 or a one-byte masking key fails."""
        # The first argument to tutils.raises is presumably the text expected
        # in the raised error -- confirm against tutils.
        bad_arguments = (
            ("opcode", {"opcode": 17}),
            ("masking key", {"masking_key": b"x"}),
        )
        for fragment, kwargs in bad_arguments:
            tutils.raises(fragment, websockets.FrameHeader, **kwargs)

    def test_automask(self):
        """Check how ``mask`` and ``masking_key`` are filled in from one another."""
        # mask=True without a key: a key is present afterwards.
        only_mask = websockets.FrameHeader(mask=True)
        assert only_mask.masking_key

        # A key without an explicit mask: the header reports itself as masked.
        only_key = websockets.FrameHeader(masking_key=b"foob")
        assert only_key.mask

        # An explicit falsy mask is respected, and the key is still kept.
        key_unmasked = websockets.FrameHeader(mask=0, masking_key=b"foob")
        assert not key_unmasked.mask
        assert key_unmasked.masking_key


class TestFrame:
    """Tests for constructing, serializing and parsing ``websockets.Frame``."""

    def test_equality(self):
        """Frames with equal payloads compare equal; a frame never equals raw bytes."""
        first = websockets.Frame(payload=b'1234')
        second = websockets.Frame(payload=b'1234')
        assert first == second

        # Comparison against the bare payload must not report equality.
        assert first != b'1234'

    def test_roundtrip(self):
        """Frames survive a serialize-then-parse cycle for a range of header settings."""
        def check(payload, **kwargs):
            # Build a frame, push it through bytes() and from_file(), compare.
            original = websockets.Frame(payload, **kwargs)
            reader = tutils.treader(bytes(original))
            reparsed = websockets.Frame.from_file(reader)
            assert original == reparsed

        header_options = (
            {},
            {"fin": 1},
            {"rsv1": 1},
            {"opcode": websockets.OPCODE.PING},
            {"masking_key": b"test"},
        )
        for kwargs in header_options:
            check(b"test", **kwargs)

    def test_human_readable(self):
        """``repr`` is non-empty for a default frame and shows a given payload."""
        empty = websockets.Frame()
        assert repr(empty)

        with_payload = websockets.Frame(b"foobar")
        text = repr(with_payload)
        assert "foobar" in text

    @pytest.mark.parametrize("masked", [True, False])
    @pytest.mark.parametrize("length", [100, 50000, 150000])
    def test_serialization_bijection(self, masked, length):
        """A random-payload TEXT frame equals the frame parsed from its own bytes.

        Runs for every combination of ``masked`` and payload ``length``; a
        random four-byte masking key is supplied only in the masked case.
        """
        payload = os.urandom(length)
        if masked:
            key = os.urandom(4)
        else:
            key = None

        original = websockets.Frame(
            payload,
            fin=True,
            opcode=websockets.OPCODE.TEXT,
            mask=int(masked),
            masking_key=key,
        )
        wire = bytes(original)
        assert original == websockets.Frame.from_bytes(wire)