diff options
Diffstat (limited to 'test/netlib')
| -rw-r--r-- | test/netlib/websockets/test_frame.py | 164 | ||||
| -rw-r--r-- | test/netlib/websockets/test_masker.py | 23 | ||||
| -rw-r--r-- | test/netlib/websockets/test_utils.py | 105 | ||||
| -rw-r--r-- | test/netlib/websockets/test_websockets.py | 269 | 
4 files changed, 292 insertions, 269 deletions
| diff --git a/test/netlib/websockets/test_frame.py b/test/netlib/websockets/test_frame.py new file mode 100644 index 00000000..cce39454 --- /dev/null +++ b/test/netlib/websockets/test_frame.py @@ -0,0 +1,164 @@ +import os +import codecs +import pytest + +from netlib import websockets +from netlib import tutils + + +class TestFrameHeader(object): + +    @pytest.mark.parametrize("input,expected", [ +        (0, '0100'), +        (125, '017D'), +        (126, '017E007E'), +        (127, '017E007F'), +        (142, '017E008E'), +        (65534, '017EFFFE'), +        (65535, '017EFFFF'), +        (65536, '017F0000000000010000'), +        (8589934591, '017F00000001FFFFFFFF'), +        (2 ** 64 - 1, '017FFFFFFFFFFFFFFFFF'), +    ]) +    def test_serialization_length(self, input, expected): +        h = websockets.FrameHeader( +            opcode=websockets.OPCODE.TEXT, +            payload_length=input, +        ) +        assert bytes(h) == codecs.decode(expected, 'hex') + +    def test_serialization_too_large(self): +        h = websockets.FrameHeader( +            payload_length=2 ** 64 + 1, +        ) +        with pytest.raises(ValueError): +            bytes(h) + +    @pytest.mark.parametrize("input,expected", [ +        ('0100', 0), +        ('017D', 125), +        ('017E007E', 126), +        ('017E007F', 127), +        ('017E008E', 142), +        ('017EFFFE', 65534), +        ('017EFFFF', 65535), +        ('017F0000000000010000', 65536), +        ('017F00000001FFFFFFFF', 8589934591), +        ('017FFFFFFFFFFFFFFFFF', 2 ** 64 - 1), +    ]) +    def test_deserialization_length(self, input, expected): +        h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex'))) +        assert h.payload_length == expected + +    @pytest.mark.parametrize("input,expected", [ +        ('0100', (False, None)), +        ('018000000000', (True, '00000000')), +        ('018012345678', (True, '12345678')), +    ]) +    def test_deserialization_masking(self, input, expected): +        h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex'))) +        assert h.mask == expected[0] +        if h.mask: +            assert h.masking_key == codecs.decode(expected[1], 'hex') + +    def test_equality(self): +        h = websockets.FrameHeader(mask=True, masking_key=b'1234') +        h2 = websockets.FrameHeader(mask=True, masking_key=b'1234') +        assert h == h2 + +        h = websockets.FrameHeader(fin=True) +        h2 = websockets.FrameHeader(fin=False) +        assert h != h2 + +        assert h != 'foobar' + +    def test_roundtrip(self): +        def round(*args, **kwargs): +            h = websockets.FrameHeader(*args, **kwargs) +            h2 = websockets.FrameHeader.from_file(tutils.treader(bytes(h))) +            assert h == h2 + +        round() +        round(fin=True) +        round(rsv1=True) +        round(rsv2=True) +        round(rsv3=True) +        round(payload_length=1) +        round(payload_length=100) +        round(payload_length=1000) +        round(payload_length=10000) +        round(opcode=websockets.OPCODE.PING) +        round(masking_key=b"test") + +    def test_human_readable(self): +        f = websockets.FrameHeader( +            masking_key=b"test", +            fin=True, +            payload_length=10 +        ) +        assert repr(f) + +        f = websockets.FrameHeader() +        assert repr(f) + +    def test_funky(self): +        f = websockets.FrameHeader(masking_key=b"test", mask=False) +        raw = bytes(f) +        f2 = websockets.FrameHeader.from_file(tutils.treader(raw)) +        assert not f2.mask + +    def test_violations(self): +        tutils.raises("opcode", websockets.FrameHeader, opcode=17) +        tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x") + +    def test_automask(self): +        f = websockets.FrameHeader(mask=True) +        assert f.masking_key + +        f = websockets.FrameHeader(masking_key=b"foob") +        assert f.mask + +        f = websockets.FrameHeader(masking_key=b"foob", mask=0) +        assert not f.mask +        assert f.masking_key + + +class TestFrame(object): +    def test_equality(self): +        f = websockets.Frame(payload=b'1234') +        f2 = websockets.Frame(payload=b'1234') +        assert f == f2 + +        assert f != b'1234' + +    def test_roundtrip(self): +        def round(*args, **kwargs): +            f = websockets.Frame(*args, **kwargs) +            raw = bytes(f) +            f2 = websockets.Frame.from_file(tutils.treader(raw)) +            assert f == f2 +        round(b"test") +        round(b"test", fin=1) +        round(b"test", rsv1=1) +        round(b"test", opcode=websockets.OPCODE.PING) +        round(b"test", masking_key=b"test") + +    def test_human_readable(self): +        f = websockets.Frame() +        assert repr(f) + +        f = websockets.Frame(b"foobar") +        assert "foobar" in repr(f) + +    @pytest.mark.parametrize("masked", [True, False]) +    @pytest.mark.parametrize("length", [100, 50000, 150000]) +    def test_serialization_bijection(self, masked, length): +        frame = websockets.Frame( +            os.urandom(length), +            fin=True, +            opcode=websockets.OPCODE.TEXT, +            mask=int(masked), +            masking_key=(os.urandom(4) if masked else None) +        ) +        serialized = bytes(frame) +        assert frame == websockets.Frame.from_bytes(serialized) diff --git a/test/netlib/websockets/test_masker.py b/test/netlib/websockets/test_masker.py new file mode 100644 index 00000000..528fce71 --- /dev/null +++ b/test/netlib/websockets/test_masker.py @@ -0,0 +1,23 @@ +import codecs +import pytest + +from netlib import websockets + + +class TestMasker(object): + +    @pytest.mark.parametrize("input,expected", [ +        ([b"a"], '00'), +        ([b"four"], '070d1616'), +        ([b"fourf"], '070d161607'), +        ([b"fourfive"], '070d1616070b1501'), +        ([b"a", b"aasdfasdfa", b"asdf"], '000302170504021705040205120605'), +        ([b"a" * 50, b"aasdfasdfa", b"asdf"], '00030205000302050003020500030205000302050003020500030205000302050003020500030205000302050003020500030205120605051206050500110702'),  # noqa +    ]) +    def test_masker(self, input, expected): +        m = websockets.Masker(b"abcd") +        data = b"".join([m(t) for t in input]) +        assert data == codecs.decode(expected, 'hex') + +        data = websockets.Masker(b"abcd")(data) +        assert data == b"".join(input) diff --git a/test/netlib/websockets/test_utils.py b/test/netlib/websockets/test_utils.py new file mode 100644 index 00000000..34765e04 --- /dev/null +++ b/test/netlib/websockets/test_utils.py @@ -0,0 +1,105 @@ +import pytest + +from netlib import http +from netlib import websockets + + +class TestUtils(object): + +    def test_client_handshake_headers(self): +        h = websockets.client_handshake_headers(version='42') +        assert h['sec-websocket-version'] == '42' + +        h = websockets.client_handshake_headers(key='some-key') +        assert h['sec-websocket-key'] == 'some-key' + +        h = websockets.client_handshake_headers(protocol='foobar') +        assert h['sec-websocket-protocol'] == 'foobar' + +        h = websockets.client_handshake_headers(extensions='foo; bar') +        assert h['sec-websocket-extensions'] == 'foo; bar' + +    def test_server_handshake_headers(self): +        h = websockets.server_handshake_headers('some-key') +        assert h['sec-websocket-accept'] == '8iILEZtcVdtFD7MDlPKip9ec9nw=' +        assert 'sec-websocket-protocol' not in h +        assert 'sec-websocket-extensions' not in h + +        h = websockets.server_handshake_headers('some-key', 'foobar', 'foo; bar') +        assert h['sec-websocket-accept'] == '8iILEZtcVdtFD7MDlPKip9ec9nw=' +        assert h['sec-websocket-protocol'] == 'foobar' +        assert h['sec-websocket-extensions'] == 'foo; bar' + +    @pytest.mark.parametrize("input,expected", [ +        ([(b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-key', b'foobar')], True), +        ([(b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-accept', b'foobar')], True), +        ([(b'Connection', b'UpgRaDe'), (b'Upgrade', b'WebSocKeT'), (b'Sec-WebSockeT-KeY', b'foobar')], True), +        ([(b'Connection', b'UpgRaDe'), (b'Upgrade', b'WebSocKeT'), (b'Sec-WebSockeT-AccePt', b'foobar')], True), +        ([(b'connection', b'foo'), (b'upgrade', b'bar'), (b'sec-websocket-key', b'foobar')], False), +        ([(b'connection', b'upgrade'), (b'upgrade', b'websocket')], False), +        ([(b'connection', b'upgrade'), (b'sec-websocket-key', b'foobar')], False), +        ([(b'upgrade', b'websocket'), (b'sec-websocket-key', b'foobar')], False), +        ([], False), +    ]) +    def test_check_handshake(self, input, expected): +        h = http.Headers(input) +        assert websockets.check_handshake(h) == expected + +    @pytest.mark.parametrize("input,expected", [ +        ([(b'sec-websocket-version', b'13')], True), +        ([(b'Sec-WebSockeT-VerSion', b'13')], True), +        ([(b'sec-websocket-version', b'9')], False), +        ([(b'sec-websocket-version', b'42')], False), +        ([(b'sec-websocket-version', b'')], False), +        ([], False), +    ]) +    def test_check_client_version(self, input, expected): +        h = http.Headers(input) +        assert websockets.check_client_version(h) == expected + +    @pytest.mark.parametrize("input,expected", [ +        ('foobar', b'AzhRPA4TNwR6I/riJheN0TfR7+I='), +        (b'foobar', b'AzhRPA4TNwR6I/riJheN0TfR7+I='), +    ]) +    def test_create_server_nonce(self, input, expected): +        assert websockets.create_server_nonce(input) == expected + +    @pytest.mark.parametrize("input,expected", [ +        ([(b'sec-websocket-extensions', b'foo; bar')], 'foo; bar'), +        ([(b'Sec-WebSockeT-ExteNsionS', b'foo; bar')], 'foo; bar'), +        ([(b'sec-websocket-extensions', b'')], ''), +        ([], None), +    ]) +    def test_get_extensions(self, input, expected): +        h = http.Headers(input) +        assert websockets.get_extensions(h) == expected + +    @pytest.mark.parametrize("input,expected", [ +        ([(b'sec-websocket-protocol', b'foobar')], 'foobar'), +        ([(b'Sec-WebSockeT-ProTocoL', b'foobar')], 'foobar'), +        ([(b'sec-websocket-protocol', b'')], ''), +        ([], None), +    ]) +    def test_get_protocol(self, input, expected): +        h = http.Headers(input) +        assert websockets.get_protocol(h) == expected + +    @pytest.mark.parametrize("input,expected", [ +        ([(b'sec-websocket-key', b'foobar')], 'foobar'), +        ([(b'Sec-WebSockeT-KeY', b'foobar')], 'foobar'), +        ([(b'sec-websocket-key', b'')], ''), +        ([], None), +    ]) +    def test_get_client_key(self, input, expected): +        h = http.Headers(input) +        assert websockets.get_client_key(h) == expected + +    @pytest.mark.parametrize("input,expected", [ +        ([(b'sec-websocket-accept', b'foobar')], 'foobar'), +        ([(b'Sec-WebSockeT-AccepT', b'foobar')], 'foobar'), +        ([(b'sec-websocket-accept', b'')], ''), +        ([], None), +    ]) +    def test_get_server_accept(self, input, expected): +        h = http.Headers(input) +        assert websockets.get_server_accept(h) == expected diff --git a/test/netlib/websockets/test_websockets.py b/test/netlib/websockets/test_websockets.py deleted file mode 100644 index 50fa26e6..00000000 --- a/test/netlib/websockets/test_websockets.py +++ /dev/null @@ -1,269 +0,0 @@ -import os - -from netlib.http.http1 import read_response, read_request - -from netlib import tcp -from netlib import tutils -from netlib import websockets -from netlib.http import status_codes -from netlib.tutils import treq -from netlib import exceptions - -from .. import tservers - - -class WebSocketsEchoHandler(tcp.BaseHandler): - -    def __init__(self, connection, address, server): -        super(WebSocketsEchoHandler, self).__init__( -            connection, address, server -        ) -        self.protocol = websockets.WebsocketsProtocol() -        self.handshake_done = False - -    def handle(self): -        while True: -            if not self.handshake_done: -                self.handshake() -            else: -                self.read_next_message() - -    def read_next_message(self): -        frame = websockets.Frame.from_file(self.rfile) -        self.on_message(frame.payload) - -    def send_message(self, message): -        frame = websockets.Frame.default(message, from_client=False) -        frame.to_file(self.wfile) - -    def handshake(self): - -        req = read_request(self.rfile) -        key = self.protocol.check_client_handshake(req.headers) - -        preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) -        self.wfile.write(preamble.encode() + b"\r\n") -        headers = self.protocol.server_handshake_headers(key) -        self.wfile.write(str(headers) + "\r\n") -        self.wfile.flush() -        self.handshake_done = True - -    def on_message(self, message): -        if message is not None: -            self.send_message(message) - - -class WebSocketsClient(tcp.TCPClient): - -    def __init__(self, address, source_address=None): -        super(WebSocketsClient, self).__init__(address, source_address) -        self.protocol = websockets.WebsocketsProtocol() -        self.client_nonce = None - -    def connect(self): -        super(WebSocketsClient, self).connect() - -        preamble = b'GET / HTTP/1.1' -        self.wfile.write(preamble + b"\r\n") -        headers = self.protocol.client_handshake_headers() -        self.client_nonce = headers["sec-websocket-key"].encode("ascii") -        self.wfile.write(bytes(headers) + b"\r\n") -        self.wfile.flush() - -        resp = read_response(self.rfile, treq(method=b"GET")) -        server_nonce = self.protocol.check_server_handshake(resp.headers) - -        if not server_nonce == self.protocol.create_server_nonce(self.client_nonce): -            self.close() - -    def read_next_message(self): -        return websockets.Frame.from_file(self.rfile).payload - -    def send_message(self, message): -        frame = websockets.Frame.default(message, from_client=True) -        frame.to_file(self.wfile) - - -class TestWebSockets(tservers.ServerTestBase): -    handler = WebSocketsEchoHandler - -    def __init__(self): -        self.protocol = websockets.WebsocketsProtocol() - -    def random_bytes(self, n=100): -        return os.urandom(n) - -    def echo(self, msg): -        client = WebSocketsClient(("127.0.0.1", self.port)) -        client.connect() -        client.send_message(msg) -        response = client.read_next_message() -        assert response == msg - -    def test_simple_echo(self): -        self.echo(b"hello I'm the client") - -    def test_frame_sizes(self): -        # length can fit in the the 7 bit payload length -        small_msg = self.random_bytes(100) -        # 50kb, sligthly larger than can fit in a 7 bit int -        medium_msg = self.random_bytes(50000) -        # 150kb, slightly larger than can fit in a 16 bit int -        large_msg = self.random_bytes(150000) - -        self.echo(small_msg) -        self.echo(medium_msg) -        self.echo(large_msg) - -    def test_default_builder(self): -        """ -          default builder should always generate valid frames -        """ -        msg = self.random_bytes() -        assert websockets.Frame.default(msg, from_client=True) -        assert websockets.Frame.default(msg, from_client=False) - -    def test_serialization_bijection(self): -        """ -          Ensure that various frame types can be serialized/deserialized back -          and forth between to_bytes() and from_bytes() -        """ -        for is_client in [True, False]: -            for num_bytes in [100, 50000, 150000]: -                frame = websockets.Frame.default( -                    self.random_bytes(num_bytes), is_client -                ) -                frame2 = websockets.Frame.from_bytes( -                    frame.to_bytes() -                ) -                assert frame == frame2 - -        bytes = b'\x81\x03cba' -        assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes - -    def test_check_server_handshake(self): -        headers = self.protocol.server_handshake_headers("key") -        assert self.protocol.check_server_handshake(headers) -        headers["Upgrade"] = "not_websocket" -        assert not self.protocol.check_server_handshake(headers) - -    def test_check_client_handshake(self): -        headers = self.protocol.client_handshake_headers("key") -        assert self.protocol.check_client_handshake(headers) == "key" -        headers["Upgrade"] = "not_websocket" -        assert not self.protocol.check_client_handshake(headers) - - -class BadHandshakeHandler(WebSocketsEchoHandler): - -    def handshake(self): - -        client_hs = read_request(self.rfile) -        self.protocol.check_client_handshake(client_hs.headers) - -        preamble = 'HTTP/1.1 101 %s\r\n' % status_codes.RESPONSES.get(101) -        self.wfile.write(preamble.encode()) -        headers = self.protocol.server_handshake_headers(b"malformed key") -        self.wfile.write(bytes(headers) + b"\r\n") -        self.wfile.flush() -        self.handshake_done = True - - -class TestBadHandshake(tservers.ServerTestBase): - -    """ -      Ensure that the client disconnects if the server handshake is malformed -    """ -    handler = BadHandshakeHandler - -    def test(self): -        with tutils.raises(exceptions.TcpDisconnect): -            client = WebSocketsClient(("127.0.0.1", self.port)) -            client.connect() -            client.send_message(b"hello") - - -class TestFrameHeader: - -    def test_roundtrip(self): -        def round(*args, **kwargs): -            f = websockets.FrameHeader(*args, **kwargs) -            f2 = websockets.FrameHeader.from_file(tutils.treader(bytes(f))) -            assert f == f2 -        round() -        round(fin=1) -        round(rsv1=1) -        round(rsv2=1) -        round(rsv3=1) -        round(payload_length=1) -        round(payload_length=100) -        round(payload_length=1000) -        round(payload_length=10000) -        round(opcode=websockets.OPCODE.PING) -        round(masking_key=b"test") - -    def test_human_readable(self): -        f = websockets.FrameHeader( -            masking_key=b"test", -            fin=True, -            payload_length=10 -        ) -        assert repr(f) -        f = websockets.FrameHeader() -        assert repr(f) - -    def test_funky(self): -        f = websockets.FrameHeader(masking_key=b"test", mask=False) -        raw = bytes(f) -        f2 = websockets.FrameHeader.from_file(tutils.treader(raw)) -        assert not f2.mask - -    def test_violations(self): -        tutils.raises("opcode", websockets.FrameHeader, opcode=17) -        tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x") - -    def test_automask(self): -        f = websockets.FrameHeader(mask=True) -        assert f.masking_key - -        f = websockets.FrameHeader(masking_key=b"foob") -        assert f.mask - -        f = websockets.FrameHeader(masking_key=b"foob", mask=0) -        assert not f.mask -        assert f.masking_key - - -class TestFrame: - -    def test_roundtrip(self): -        def round(*args, **kwargs): -            f = websockets.Frame(*args, **kwargs) -            raw = bytes(f) -            f2 = websockets.Frame.from_file(tutils.treader(raw)) -            assert f == f2 -        round(b"test") -        round(b"test", fin=1) -        round(b"test", rsv1=1) -        round(b"test", opcode=websockets.OPCODE.PING) -        round(b"test", masking_key=b"test") - -    def test_human_readable(self): -        f = websockets.Frame() -        assert repr(f) - - -def test_masker(): -    tests = [ -        [b"a"], -        [b"four"], -        [b"fourf"], -        [b"fourfive"], -        [b"a", b"aasdfasdfa", b"asdf"], -        [b"a" * 50, b"aasdfasdfa", b"asdf"], -    ] -    for i in tests: -        m = websockets.Masker(b"abcd") -        data = b"".join([m(t) for t in i]) -        data2 = websockets.Masker(b"abcd")(data) -        assert data2 == b"".join(i) | 
