diff options
author | Thomas Kriechbaumer <thomas@kriechbaumer.name> | 2016-08-18 17:31:43 +0200 |
---|---|---|
committer | Thomas Kriechbaumer <thomas@kriechbaumer.name> | 2016-09-01 09:56:14 +0200 |
commit | d12515f84b32b3157fa99ac3c3a7a7318f9626ba (patch) | |
tree | 5fee9c57965ebbee51227eb374029b89dbe56cca /test/netlib | |
parent | 281d779fa3eb6b81ec76d046337275c0a82eff46 (diff) | |
download | mitmproxy-d12515f84b32b3157fa99ac3c3a7a7318f9626ba.tar.gz mitmproxy-d12515f84b32b3157fa99ac3c3a7a7318f9626ba.tar.bz2 mitmproxy-d12515f84b32b3157fa99ac3c3a7a7318f9626ba.zip |
websockets: refactor implementation and add tests
Diffstat (limited to 'test/netlib')
-rw-r--r-- | test/netlib/websockets/test_frame.py | 164 | ||||
-rw-r--r-- | test/netlib/websockets/test_masker.py | 23 | ||||
-rw-r--r-- | test/netlib/websockets/test_utils.py | 105 | ||||
-rw-r--r-- | test/netlib/websockets/test_websockets.py | 269 |
4 files changed, 292 insertions, 269 deletions
diff --git a/test/netlib/websockets/test_frame.py b/test/netlib/websockets/test_frame.py new file mode 100644 index 00000000..cce39454 --- /dev/null +++ b/test/netlib/websockets/test_frame.py @@ -0,0 +1,164 @@ +import os +import codecs +import pytest + +from netlib import websockets +from netlib import tutils + + +class TestFrameHeader(object): + + @pytest.mark.parametrize("input,expected", [ + (0, '0100'), + (125, '017D'), + (126, '017E007E'), + (127, '017E007F'), + (142, '017E008E'), + (65534, '017EFFFE'), + (65535, '017EFFFF'), + (65536, '017F0000000000010000'), + (8589934591, '017F00000001FFFFFFFF'), + (2 ** 64 - 1, '017FFFFFFFFFFFFFFFFF'), + ]) + def test_serialization_length(self, input, expected): + h = websockets.FrameHeader( + opcode=websockets.OPCODE.TEXT, + payload_length=input, + ) + assert bytes(h) == codecs.decode(expected, 'hex') + + def test_serialization_too_large(self): + h = websockets.FrameHeader( + payload_length=2 ** 64 + 1, + ) + with pytest.raises(ValueError): + bytes(h) + + @pytest.mark.parametrize("input,expected", [ + ('0100', 0), + ('017D', 125), + ('017E007E', 126), + ('017E007F', 127), + ('017E008E', 142), + ('017EFFFE', 65534), + ('017EFFFF', 65535), + ('017F0000000000010000', 65536), + ('017F00000001FFFFFFFF', 8589934591), + ('017FFFFFFFFFFFFFFFFF', 2 ** 64 - 1), + ]) + def test_deserialization_length(self, input, expected): + h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex'))) + assert h.payload_length == expected + + @pytest.mark.parametrize("input,expected", [ + ('0100', (False, None)), + ('018000000000', (True, '00000000')), + ('018012345678', (True, '12345678')), + ]) + def test_deserialization_masking(self, input, expected): + h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex'))) + assert h.mask == expected[0] + if h.mask: + assert h.masking_key == codecs.decode(expected[1], 'hex') + + def test_equality(self): + h = websockets.FrameHeader(mask=True, masking_key=b'1234') + h2 = websockets.FrameHeader(mask=True, masking_key=b'1234') + assert h == h2 + + h = websockets.FrameHeader(fin=True) + h2 = websockets.FrameHeader(fin=False) + assert h != h2 + + assert h != 'foobar' + + def test_roundtrip(self): + def round(*args, **kwargs): + h = websockets.FrameHeader(*args, **kwargs) + h2 = websockets.FrameHeader.from_file(tutils.treader(bytes(h))) + assert h == h2 + + round() + round(fin=True) + round(rsv1=True) + round(rsv2=True) + round(rsv3=True) + round(payload_length=1) + round(payload_length=100) + round(payload_length=1000) + round(payload_length=10000) + round(opcode=websockets.OPCODE.PING) + round(masking_key=b"test") + + def test_human_readable(self): + f = websockets.FrameHeader( + masking_key=b"test", + fin=True, + payload_length=10 + ) + assert repr(f) + + f = websockets.FrameHeader() + assert repr(f) + + def test_funky(self): + f = websockets.FrameHeader(masking_key=b"test", mask=False) + raw = bytes(f) + f2 = websockets.FrameHeader.from_file(tutils.treader(raw)) + assert not f2.mask + + def test_violations(self): + tutils.raises("opcode", websockets.FrameHeader, opcode=17) + tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x") + + def test_automask(self): + f = websockets.FrameHeader(mask=True) + assert f.masking_key + + f = websockets.FrameHeader(masking_key=b"foob") + assert f.mask + + f = websockets.FrameHeader(masking_key=b"foob", mask=0) + assert not f.mask + assert f.masking_key + + +class TestFrame(object): + def test_equality(self): + f = websockets.Frame(payload=b'1234') + f2 = websockets.Frame(payload=b'1234') + assert f == f2 + + assert f != b'1234' + + def test_roundtrip(self): + def round(*args, **kwargs): + f = websockets.Frame(*args, **kwargs) + raw = bytes(f) + f2 = websockets.Frame.from_file(tutils.treader(raw)) + assert f == f2 + round(b"test") + round(b"test", fin=1) + round(b"test", rsv1=1) + round(b"test", opcode=websockets.OPCODE.PING) + round(b"test", masking_key=b"test") + + def test_human_readable(self): + f = websockets.Frame() + assert repr(f) + + f = websockets.Frame(b"foobar") + assert "foobar" in repr(f) + + @pytest.mark.parametrize("masked", [True, False]) + @pytest.mark.parametrize("length", [100, 50000, 150000]) + def test_serialization_bijection(self, masked, length): + frame = websockets.Frame( + os.urandom(length), + fin=True, + opcode=websockets.OPCODE.TEXT, + mask=int(masked), + masking_key=(os.urandom(4) if masked else None) + ) + serialized = bytes(frame) + assert frame == websockets.Frame.from_bytes(serialized) diff --git a/test/netlib/websockets/test_masker.py b/test/netlib/websockets/test_masker.py new file mode 100644 index 00000000..528fce71 --- /dev/null +++ b/test/netlib/websockets/test_masker.py @@ -0,0 +1,23 @@ +import codecs +import pytest + +from netlib import websockets + + +class TestMasker(object): + + @pytest.mark.parametrize("input,expected", [ + ([b"a"], '00'), + ([b"four"], '070d1616'), + ([b"fourf"], '070d161607'), + ([b"fourfive"], '070d1616070b1501'), + ([b"a", b"aasdfasdfa", b"asdf"], '000302170504021705040205120605'), + ([b"a" * 50, b"aasdfasdfa", b"asdf"], '00030205000302050003020500030205000302050003020500030205000302050003020500030205000302050003020500030205120605051206050500110702'), # noqa + ]) + def test_masker(self, input, expected): + m = websockets.Masker(b"abcd") + data = b"".join([m(t) for t in input]) + assert data == codecs.decode(expected, 'hex') + + data = websockets.Masker(b"abcd")(data) + assert data == b"".join(input) diff --git a/test/netlib/websockets/test_utils.py b/test/netlib/websockets/test_utils.py new file mode 100644 index 00000000..34765e04 --- /dev/null +++ b/test/netlib/websockets/test_utils.py @@ -0,0 +1,105 @@ +import pytest + +from netlib import http +from netlib import websockets + + +class TestUtils(object): + + def test_client_handshake_headers(self): + h = websockets.client_handshake_headers(version='42') + assert h['sec-websocket-version'] == '42' + + h = websockets.client_handshake_headers(key='some-key') + assert h['sec-websocket-key'] == 'some-key' + + h = websockets.client_handshake_headers(protocol='foobar') + assert h['sec-websocket-protocol'] == 'foobar' + + h = websockets.client_handshake_headers(extensions='foo; bar') + assert h['sec-websocket-extensions'] == 'foo; bar' + + def test_server_handshake_headers(self): + h = websockets.server_handshake_headers('some-key') + assert h['sec-websocket-accept'] == '8iILEZtcVdtFD7MDlPKip9ec9nw=' + assert 'sec-websocket-protocol' not in h + assert 'sec-websocket-extensions' not in h + + h = websockets.server_handshake_headers('some-key', 'foobar', 'foo; bar') + assert h['sec-websocket-accept'] == '8iILEZtcVdtFD7MDlPKip9ec9nw=' + assert h['sec-websocket-protocol'] == 'foobar' + assert h['sec-websocket-extensions'] == 'foo; bar' + + @pytest.mark.parametrize("input,expected", [ + ([(b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-key', b'foobar')], True), + ([(b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-accept', b'foobar')], True), + ([(b'Connection', b'UpgRaDe'), (b'Upgrade', b'WebSocKeT'), (b'Sec-WebSockeT-KeY', b'foobar')], True), + ([(b'Connection', b'UpgRaDe'), (b'Upgrade', b'WebSocKeT'), (b'Sec-WebSockeT-AccePt', b'foobar')], True), + ([(b'connection', b'foo'), (b'upgrade', b'bar'), (b'sec-websocket-key', b'foobar')], False), + ([(b'connection', b'upgrade'), (b'upgrade', b'websocket')], False), + ([(b'connection', b'upgrade'), (b'sec-websocket-key', b'foobar')], False), + ([(b'upgrade', b'websocket'), (b'sec-websocket-key', b'foobar')], False), + ([], False), + ]) + def test_check_handshake(self, input, expected): + h = http.Headers(input) + assert websockets.check_handshake(h) == expected + + @pytest.mark.parametrize("input,expected", [ + ([(b'sec-websocket-version', b'13')], True), + ([(b'Sec-WebSockeT-VerSion', b'13')], True), + ([(b'sec-websocket-version', b'9')], False), + ([(b'sec-websocket-version', b'42')], False), + ([(b'sec-websocket-version', b'')], False), + ([], False), + ]) + def test_check_client_version(self, input, expected): + h = http.Headers(input) + assert websockets.check_client_version(h) == expected + + @pytest.mark.parametrize("input,expected", [ + ('foobar', b'AzhRPA4TNwR6I/riJheN0TfR7+I='), + (b'foobar', b'AzhRPA4TNwR6I/riJheN0TfR7+I='), + ]) + def test_create_server_nonce(self, input, expected): + assert websockets.create_server_nonce(input) == expected + + @pytest.mark.parametrize("input,expected", [ + ([(b'sec-websocket-extensions', b'foo; bar')], 'foo; bar'), + ([(b'Sec-WebSockeT-ExteNsionS', b'foo; bar')], 'foo; bar'), + ([(b'sec-websocket-extensions', b'')], ''), + ([], None), + ]) + def test_get_extensions(self, input, expected): + h = http.Headers(input) + assert websockets.get_extensions(h) == expected + + @pytest.mark.parametrize("input,expected", [ + ([(b'sec-websocket-protocol', b'foobar')], 'foobar'), + ([(b'Sec-WebSockeT-ProTocoL', b'foobar')], 'foobar'), + ([(b'sec-websocket-protocol', b'')], ''), + ([], None), + ]) + def test_get_protocol(self, input, expected): + h = http.Headers(input) + assert websockets.get_protocol(h) == expected + + @pytest.mark.parametrize("input,expected", [ + ([(b'sec-websocket-key', b'foobar')], 'foobar'), + ([(b'Sec-WebSockeT-KeY', b'foobar')], 'foobar'), + ([(b'sec-websocket-key', b'')], ''), + ([], None), + ]) + def test_get_client_key(self, input, expected): + h = http.Headers(input) + assert websockets.get_client_key(h) == expected + + @pytest.mark.parametrize("input,expected", [ + ([(b'sec-websocket-accept', b'foobar')], 'foobar'), + ([(b'Sec-WebSockeT-AccepT', b'foobar')], 'foobar'), + ([(b'sec-websocket-accept', b'')], ''), + ([], None), + ]) + def test_get_server_accept(self, input, expected): + h = http.Headers(input) + assert websockets.get_server_accept(h) == expected diff --git a/test/netlib/websockets/test_websockets.py b/test/netlib/websockets/test_websockets.py deleted file mode 100644 index 50fa26e6..00000000 --- a/test/netlib/websockets/test_websockets.py +++ /dev/null @@ -1,269 +0,0 @@ -import os - -from netlib.http.http1 import read_response, read_request - -from netlib import tcp -from netlib import tutils -from netlib import websockets -from netlib.http import status_codes -from netlib.tutils import treq -from netlib import exceptions - -from .. import tservers - - -class WebSocketsEchoHandler(tcp.BaseHandler): - - def __init__(self, connection, address, server): - super(WebSocketsEchoHandler, self).__init__( - connection, address, server - ) - self.protocol = websockets.WebsocketsProtocol() - self.handshake_done = False - - def handle(self): - while True: - if not self.handshake_done: - self.handshake() - else: - self.read_next_message() - - def read_next_message(self): - frame = websockets.Frame.from_file(self.rfile) - self.on_message(frame.payload) - - def send_message(self, message): - frame = websockets.Frame.default(message, from_client=False) - frame.to_file(self.wfile) - - def handshake(self): - - req = read_request(self.rfile) - key = self.protocol.check_client_handshake(req.headers) - - preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) - self.wfile.write(preamble.encode() + b"\r\n") - headers = self.protocol.server_handshake_headers(key) - self.wfile.write(str(headers) + "\r\n") - self.wfile.flush() - self.handshake_done = True - - def on_message(self, message): - if message is not None: - self.send_message(message) - - -class WebSocketsClient(tcp.TCPClient): - - def __init__(self, address, source_address=None): - super(WebSocketsClient, self).__init__(address, source_address) - self.protocol = websockets.WebsocketsProtocol() - self.client_nonce = None - - def connect(self): - super(WebSocketsClient, self).connect() - - preamble = b'GET / HTTP/1.1' - self.wfile.write(preamble + b"\r\n") - headers = self.protocol.client_handshake_headers() - self.client_nonce = headers["sec-websocket-key"].encode("ascii") - self.wfile.write(bytes(headers) + b"\r\n") - self.wfile.flush() - - resp = read_response(self.rfile, treq(method=b"GET")) - server_nonce = self.protocol.check_server_handshake(resp.headers) - - if not server_nonce == self.protocol.create_server_nonce(self.client_nonce): - self.close() - - def read_next_message(self): - return websockets.Frame.from_file(self.rfile).payload - - def send_message(self, message): - frame = websockets.Frame.default(message, from_client=True) - frame.to_file(self.wfile) - - -class TestWebSockets(tservers.ServerTestBase): - handler = WebSocketsEchoHandler - - def __init__(self): - self.protocol = websockets.WebsocketsProtocol() - - def random_bytes(self, n=100): - return os.urandom(n) - - def echo(self, msg): - client = WebSocketsClient(("127.0.0.1", self.port)) - client.connect() - client.send_message(msg) - response = client.read_next_message() - assert response == msg - - def test_simple_echo(self): - self.echo(b"hello I'm the client") - - def test_frame_sizes(self): - # length can fit in the the 7 bit payload length - small_msg = self.random_bytes(100) - # 50kb, sligthly larger than can fit in a 7 bit int - medium_msg = self.random_bytes(50000) - # 150kb, slightly larger than can fit in a 16 bit int - large_msg = self.random_bytes(150000) - - self.echo(small_msg) - self.echo(medium_msg) - self.echo(large_msg) - - def test_default_builder(self): - """ - default builder should always generate valid frames - """ - msg = self.random_bytes() - assert websockets.Frame.default(msg, from_client=True) - assert websockets.Frame.default(msg, from_client=False) - - def test_serialization_bijection(self): - """ - Ensure that various frame types can be serialized/deserialized back - and forth between to_bytes() and from_bytes() - """ - for is_client in [True, False]: - for num_bytes in [100, 50000, 150000]: - frame = websockets.Frame.default( - self.random_bytes(num_bytes), is_client - ) - frame2 = websockets.Frame.from_bytes( - frame.to_bytes() - ) - assert frame == frame2 - - bytes = b'\x81\x03cba' - assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes - - def test_check_server_handshake(self): - headers = self.protocol.server_handshake_headers("key") - assert self.protocol.check_server_handshake(headers) - headers["Upgrade"] = "not_websocket" - assert not self.protocol.check_server_handshake(headers) - - def test_check_client_handshake(self): - headers = self.protocol.client_handshake_headers("key") - assert self.protocol.check_client_handshake(headers) == "key" - headers["Upgrade"] = "not_websocket" - assert not self.protocol.check_client_handshake(headers) - - -class BadHandshakeHandler(WebSocketsEchoHandler): - - def handshake(self): - - client_hs = read_request(self.rfile) - self.protocol.check_client_handshake(client_hs.headers) - - preamble = 'HTTP/1.1 101 %s\r\n' % status_codes.RESPONSES.get(101) - self.wfile.write(preamble.encode()) - headers = self.protocol.server_handshake_headers(b"malformed key") - self.wfile.write(bytes(headers) + b"\r\n") - self.wfile.flush() - self.handshake_done = True - - -class TestBadHandshake(tservers.ServerTestBase): - - """ - Ensure that the client disconnects if the server handshake is malformed - """ - handler = BadHandshakeHandler - - def test(self): - with tutils.raises(exceptions.TcpDisconnect): - client = WebSocketsClient(("127.0.0.1", self.port)) - client.connect() - client.send_message(b"hello") - - -class TestFrameHeader: - - def test_roundtrip(self): - def round(*args, **kwargs): - f = websockets.FrameHeader(*args, **kwargs) - f2 = websockets.FrameHeader.from_file(tutils.treader(bytes(f))) - assert f == f2 - round() - round(fin=1) - round(rsv1=1) - round(rsv2=1) - round(rsv3=1) - round(payload_length=1) - round(payload_length=100) - round(payload_length=1000) - round(payload_length=10000) - round(opcode=websockets.OPCODE.PING) - round(masking_key=b"test") - - def test_human_readable(self): - f = websockets.FrameHeader( - masking_key=b"test", - fin=True, - payload_length=10 - ) - assert repr(f) - f = websockets.FrameHeader() - assert repr(f) - - def test_funky(self): - f = websockets.FrameHeader(masking_key=b"test", mask=False) - raw = bytes(f) - f2 = websockets.FrameHeader.from_file(tutils.treader(raw)) - assert not f2.mask - - def test_violations(self): - tutils.raises("opcode", websockets.FrameHeader, opcode=17) - tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x") - - def test_automask(self): - f = websockets.FrameHeader(mask=True) - assert f.masking_key - - f = websockets.FrameHeader(masking_key=b"foob") - assert f.mask - - f = websockets.FrameHeader(masking_key=b"foob", mask=0) - assert not f.mask - assert f.masking_key - - -class TestFrame: - - def test_roundtrip(self): - def round(*args, **kwargs): - f = websockets.Frame(*args, **kwargs) - raw = bytes(f) - f2 = websockets.Frame.from_file(tutils.treader(raw)) - assert f == f2 - round(b"test") - round(b"test", fin=1) - round(b"test", rsv1=1) - round(b"test", opcode=websockets.OPCODE.PING) - round(b"test", masking_key=b"test") - - def test_human_readable(self): - f = websockets.Frame() - assert repr(f) - - -def test_masker(): - tests = [ - [b"a"], - [b"four"], - [b"fourf"], - [b"fourfive"], - [b"a", b"aasdfasdfa", b"asdf"], - [b"a" * 50, b"aasdfasdfa", b"asdf"], - ] - for i in tests: - m = websockets.Masker(b"abcd") - data = b"".join([m(t) for t in i]) - data2 = websockets.Masker(b"abcd")(data) - assert data2 == b"".join(i) |