From d12515f84b32b3157fa99ac3c3a7a7318f9626ba Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 18 Aug 2016 17:31:43 +0200 Subject: websockets: refactor implementation and add tests --- test/netlib/websockets/test_frame.py | 164 +++++++++++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 test/netlib/websockets/test_frame.py (limited to 'test/netlib/websockets/test_frame.py') diff --git a/test/netlib/websockets/test_frame.py b/test/netlib/websockets/test_frame.py new file mode 100644 index 00000000..cce39454 --- /dev/null +++ b/test/netlib/websockets/test_frame.py @@ -0,0 +1,164 @@ +import os +import codecs +import pytest + +from netlib import websockets +from netlib import tutils + + +class TestFrameHeader(object): + + @pytest.mark.parametrize("input,expected", [ + (0, '0100'), + (125, '017D'), + (126, '017E007E'), + (127, '017E007F'), + (142, '017E008E'), + (65534, '017EFFFE'), + (65535, '017EFFFF'), + (65536, '017F0000000000010000'), + (8589934591, '017F00000001FFFFFFFF'), + (2 ** 64 - 1, '017FFFFFFFFFFFFFFFFF'), + ]) + def test_serialization_length(self, input, expected): + h = websockets.FrameHeader( + opcode=websockets.OPCODE.TEXT, + payload_length=input, + ) + assert bytes(h) == codecs.decode(expected, 'hex') + + def test_serialization_too_large(self): + h = websockets.FrameHeader( + payload_length=2 ** 64 + 1, + ) + with pytest.raises(ValueError): + bytes(h) + + @pytest.mark.parametrize("input,expected", [ + ('0100', 0), + ('017D', 125), + ('017E007E', 126), + ('017E007F', 127), + ('017E008E', 142), + ('017EFFFE', 65534), + ('017EFFFF', 65535), + ('017F0000000000010000', 65536), + ('017F00000001FFFFFFFF', 8589934591), + ('017FFFFFFFFFFFFFFFFF', 2 ** 64 - 1), + ]) + def test_deserialization_length(self, input, expected): + h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex'))) + assert h.payload_length == expected + + @pytest.mark.parametrize("input,expected", [ + ('0100', (False, None)), + ('018000000000', (True, '00000000')), + ('018012345678', (True, '12345678')), + ]) + def test_deserialization_masking(self, input, expected): + h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex'))) + assert h.mask == expected[0] + if h.mask: + assert h.masking_key == codecs.decode(expected[1], 'hex') + + def test_equality(self): + h = websockets.FrameHeader(mask=True, masking_key=b'1234') + h2 = websockets.FrameHeader(mask=True, masking_key=b'1234') + assert h == h2 + + h = websockets.FrameHeader(fin=True) + h2 = websockets.FrameHeader(fin=False) + assert h != h2 + + assert h != 'foobar' + + def test_roundtrip(self): + def round(*args, **kwargs): + h = websockets.FrameHeader(*args, **kwargs) + h2 = websockets.FrameHeader.from_file(tutils.treader(bytes(h))) + assert h == h2 + + round() + round(fin=True) + round(rsv1=True) + round(rsv2=True) + round(rsv3=True) + round(payload_length=1) + round(payload_length=100) + round(payload_length=1000) + round(payload_length=10000) + round(opcode=websockets.OPCODE.PING) + round(masking_key=b"test") + + def test_human_readable(self): + f = websockets.FrameHeader( + masking_key=b"test", + fin=True, + payload_length=10 + ) + assert repr(f) + + f = websockets.FrameHeader() + assert repr(f) + + def test_funky(self): + f = websockets.FrameHeader(masking_key=b"test", mask=False) + raw = bytes(f) + f2 = websockets.FrameHeader.from_file(tutils.treader(raw)) + assert not f2.mask + + def test_violations(self): + tutils.raises("opcode", websockets.FrameHeader, opcode=17) + tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x") + + def test_automask(self): + f = websockets.FrameHeader(mask=True) + assert f.masking_key + + f = websockets.FrameHeader(masking_key=b"foob") + assert f.mask + + f = websockets.FrameHeader(masking_key=b"foob", mask=0) + assert not f.mask + assert f.masking_key + + +class TestFrame(object): + def test_equality(self): + f = websockets.Frame(payload=b'1234') + f2 = websockets.Frame(payload=b'1234') + assert f == f2 + + assert f != b'1234' + + def test_roundtrip(self): + def round(*args, **kwargs): + f = websockets.Frame(*args, **kwargs) + raw = bytes(f) + f2 = websockets.Frame.from_file(tutils.treader(raw)) + assert f == f2 + round(b"test") + round(b"test", fin=1) + round(b"test", rsv1=1) + round(b"test", opcode=websockets.OPCODE.PING) + round(b"test", masking_key=b"test") + + def test_human_readable(self): + f = websockets.Frame() + assert repr(f) + + f = websockets.Frame(b"foobar") + assert "foobar" in repr(f) + + @pytest.mark.parametrize("masked", [True, False]) + @pytest.mark.parametrize("length", [100, 50000, 150000]) + def test_serialization_bijection(self, masked, length): + frame = websockets.Frame( + os.urandom(length), + fin=True, + opcode=websockets.OPCODE.TEXT, + mask=int(masked), + masking_key=(os.urandom(4) if masked else None) + ) + serialized = bytes(frame) + assert frame == websockets.Frame.from_bytes(serialized) -- cgit v1.2.3