From d12515f84b32b3157fa99ac3c3a7a7318f9626ba Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 18 Aug 2016 17:31:43 +0200 Subject: websockets: refactor implementation and add tests --- netlib/websockets/__init__.py | 34 +++- netlib/websockets/frame.py | 142 ++++++---------- netlib/websockets/masker.py | 33 ++++ netlib/websockets/protocol.py | 112 ------------- netlib/websockets/utils.py | 90 ++++++++++ test/netlib/websockets/test_frame.py | 164 ++++++++++++++++++ test/netlib/websockets/test_masker.py | 23 +++ test/netlib/websockets/test_utils.py | 105 ++++++++++++ test/netlib/websockets/test_websockets.py | 269 ------------------------------ 9 files changed, 499 insertions(+), 473 deletions(-) create mode 100644 netlib/websockets/masker.py delete mode 100644 netlib/websockets/protocol.py create mode 100644 netlib/websockets/utils.py create mode 100644 test/netlib/websockets/test_frame.py create mode 100644 test/netlib/websockets/test_masker.py create mode 100644 test/netlib/websockets/test_utils.py delete mode 100644 test/netlib/websockets/test_websockets.py diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py index fea696d9..e14e8a7d 100644 --- a/netlib/websockets/__init__.py +++ b/netlib/websockets/__init__.py @@ -1,11 +1,37 @@ from __future__ import absolute_import, print_function, division -from .frame import FrameHeader, Frame, OPCODE -from .protocol import Masker, WebsocketsProtocol + +from .frame import FrameHeader +from .frame import Frame +from .frame import OPCODE +from .frame import CLOSE_REASON +from .masker import Masker +from .utils import MAGIC +from .utils import VERSION +from .utils import client_handshake_headers +from .utils import server_handshake_headers +from .utils import check_handshake +from .utils import check_client_version +from .utils import create_server_nonce +from .utils import get_extensions +from .utils import get_protocol +from .utils import get_client_key +from .utils import get_server_accept __all__ = [ "FrameHeader", "Frame", - "Masker", - "WebsocketsProtocol", "OPCODE", + "CLOSE_REASON", + "Masker", + "MAGIC", + "VERSION", + "client_handshake_headers", + "server_handshake_headers", + "check_handshake", + "check_client_version", + "create_server_nonce", + "get_extensions", + "get_protocol", + "get_client_key", + "get_server_accept", ] diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index 7d355699..e62d0e87 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -2,7 +2,6 @@ from __future__ import absolute_import import os import struct import io -import warnings import six @@ -10,7 +9,7 @@ from netlib import tcp from netlib import strutils from netlib import utils from netlib import human -from netlib.websockets import protocol +from .masker import Masker MAX_16_BIT_INT = (1 << 16) @@ -18,6 +17,7 @@ MAX_64_BIT_INT = (1 << 64) DEFAULT = object() +# RFC 6455, Section 5.2 - Base Framing Protocol OPCODE = utils.BiDi( CONTINUE=0x00, TEXT=0x01, @@ -27,6 +27,23 @@ OPCODE = utils.BiDi( PONG=0x0a ) +# RFC 6455, Section 7.4.1 - Defined Status Codes +CLOSE_REASON = utils.BiDi( + NORMAL_CLOSURE=1000, + GOING_AWAY=1001, + PROTOCOL_ERROR=1002, + UNSUPPORTED_DATA=1003, + RESERVED=1004, + RESERVED_NO_STATUS=1005, + RESERVED_ABNORMAL_CLOSURE=1006, + INVALID_PAYLOAD_DATA=1007, + POLICY_VIOLATION=1008, + MESSAGE_TOO_BIG=1009, + MANDATORY_EXTENSION=1010, + INTERNAL_ERROR=1011, + RESERVED_TLS_HANDHSAKE_FAILED=1015, +) + class FrameHeader(object): @@ -103,10 +120,6 @@ class FrameHeader(object): vals.append(" %s" % human.pretty_size(self.payload_length)) return "".join(vals) - def human_readable(self): - warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) - return repr(self) - def __bytes__(self): first_byte = utils.setbit(0, 7, self.fin) first_byte = utils.setbit(first_byte, 6, self.rsv1) @@ -128,6 +141,9 @@ class FrameHeader(object): # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length b += struct.pack('!Q', self.payload_length) + else: + raise ValueError("Payload length exceeds 64bit integer") + if self.masking_key: b += self.masking_key return b @@ -135,10 +151,6 @@ class FrameHeader(object): if six.PY2: __str__ = __bytes__ - def to_bytes(self): - warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) - return bytes(self) - @classmethod def from_file(cls, fp): """ @@ -151,19 +163,17 @@ class FrameHeader(object): rsv1 = utils.getbit(first_byte, 6) rsv2 = utils.getbit(first_byte, 5) rsv3 = utils.getbit(first_byte, 4) - # grab right-most 4 bits - opcode = first_byte & 15 + opcode = first_byte & 0xF mask_bit = utils.getbit(second_byte, 7) - # grab the next 7 bits - length_code = second_byte & 127 + length_code = second_byte & 0x7F - # payload_lengthy > 125 indicates you need to read more bytes + # payload_length > 125 indicates you need to read more bytes # to get the actual payload length if length_code <= 125: payload_length = length_code elif length_code == 126: payload_length, = struct.unpack("!H", fp.safe_read(2)) - elif length_code == 127: + else: # length_code == 127: payload_length, = struct.unpack("!Q", fp.safe_read(8)) # masking key only present if mask bit set @@ -191,31 +201,30 @@ class FrameHeader(object): class Frame(object): - """ - Represents one websockets frame. - Constructor takes human readable forms of the frame components - from_bytes() is also avaliable. - - WebSockets Frame as defined in RFC6455 - - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-------+-+-------------+-------------------------------+ - |F|R|R|R| opcode|M| Payload len | Extended payload length | - |I|S|S|S| (4) |A| (7) | (16/64) | - |N|V|V|V| |S| | (if payload len==126/127) | - | |1|2|3| |K| | | - +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + - | Extended payload length continued, if payload len == 127 | - + - - - - - - - - - - - - - - - +-------------------------------+ - | |Masking-key, if MASK set to 1 | - +-------------------------------+-------------------------------+ - | Masking-key (continued) | Payload Data | - +-------------------------------- - - - - - - - - - - - - - - - + - : Payload Data continued ... : - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + - | Payload Data continued ... | - +---------------------------------------------------------------+ + Represents a single WebSockets frame. + Constructor takes human readable forms of the frame components. + from_bytes() reads from a file-like object to create a new Frame. + + WebSockets Frame as defined in RFC6455 + + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-------+-+-------------+-------------------------------+ + |F|R|R|R| opcode|M| Payload len | Extended payload length | + |I|S|S|S| (4) |A| (7) | (16/64) | + |N|V|V|V| |S| | (if payload len==126/127) | + | |1|2|3| |K| | | + +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + | Extended payload length continued, if payload len == 127 | + + - - - - - - - - - - - - - - - +-------------------------------+ + | |Masking-key, if MASK set to 1 | + +-------------------------------+-------------------------------+ + | Masking-key (continued) | Payload Data | + +-------------------------------- - - - - - - - - - - - - - - - + + : Payload Data continued ... : + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + | Payload Data continued ... | + +---------------------------------------------------------------+ """ def __init__(self, payload=b"", **kwargs): @@ -223,27 +232,6 @@ class Frame(object): kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) self.header = FrameHeader(**kwargs) - @classmethod - def default(cls, message, from_client=False): - """ - Construct a basic websocket frame from some default values. - Creates a non-fragmented text frame. - """ - if from_client: - mask_bit = 1 - masking_key = os.urandom(4) - else: - mask_bit = 0 - masking_key = None - - return cls( - message, - fin=1, # final frame - opcode=OPCODE.TEXT, # text - mask=mask_bit, - masking_key=masking_key, - ) - @classmethod def from_bytes(cls, bytestring): """ @@ -258,17 +246,13 @@ class Frame(object): ret = ret + "\nPayload:\n" + strutils.bytes_to_escaped_str(self.payload) return ret - def human_readable(self): - warnings.warn("Frame.to_bytes is deprecated, use bytes(frame) instead.", DeprecationWarning) - return repr(self) - def __bytes__(self): """ Serialize the frame to wire format. Returns a string. """ b = bytes(self.header) if self.header.masking_key: - b += protocol.Masker(self.header.masking_key)(self.payload) + b += Masker(self.header.masking_key)(self.payload) else: b += self.payload return b @@ -276,15 +260,6 @@ class Frame(object): if six.PY2: __str__ = __bytes__ - def to_bytes(self): - warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) - return bytes(self) - - def to_file(self, writer): - warnings.warn("Frame.to_file is deprecated, use wfile.write(bytes(frame)) instead.", DeprecationWarning) - writer.write(bytes(self)) - writer.flush() - @classmethod def from_file(cls, fp): """ @@ -297,20 +272,11 @@ class Frame(object): payload = fp.safe_read(header.payload_length) if header.mask == 1 and header.masking_key: - payload = protocol.Masker(header.masking_key)(payload) + payload = Masker(header.masking_key)(payload) - return cls( - payload, - fin=header.fin, - opcode=header.opcode, - mask=header.mask, - payload_length=header.payload_length, - masking_key=header.masking_key, - rsv1=header.rsv1, - rsv2=header.rsv2, - rsv3=header.rsv3, - length_code=header.length_code - ) + frame = cls(payload) + frame.header = header + return frame def __eq__(self, other): if isinstance(other, Frame): diff --git a/netlib/websockets/masker.py b/netlib/websockets/masker.py new file mode 100644 index 00000000..bd39ed6a --- /dev/null +++ b/netlib/websockets/masker.py @@ -0,0 +1,33 @@ +from __future__ import absolute_import + +import six + + +class Masker(object): + """ + Data sent from the server must be masked to prevent malicious clients + from sending data over the wire in predictable patterns. + + Servers do not have to mask data they send to the client. + https://tools.ietf.org/html/rfc6455#section-5.3 + """ + + def __init__(self, key): + self.key = key + self.offset = 0 + + def mask(self, offset, data): + result = bytearray(data) + for i in range(len(data)): + if six.PY2: + result[i] ^= ord(self.key[offset % 4]) + else: + result[i] ^= self.key[offset % 4] + offset += 1 + result = bytes(result) + return result + + def __call__(self, data): + ret = self.mask(self.offset, data) + self.offset += len(ret) + return ret diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py deleted file mode 100644 index af0eef7d..00000000 --- a/netlib/websockets/protocol.py +++ /dev/null @@ -1,112 +0,0 @@ -""" -Colleciton of utility functions that implement small portions of the RFC6455 -WebSockets Protocol Useful for building WebSocket clients and servers. - -Emphassis is on readabilty, simplicity and modularity, not performance or -completeness - -This is a work in progress and does not yet contain all the utilites need to -create fully complient client/servers # -Spec: https://tools.ietf.org/html/rfc6455 - -The magic sha that websocket servers must know to prove they understand -RFC6455 -""" - -from __future__ import absolute_import -import base64 -import hashlib -import os - -import six - -from netlib import http, strutils - -websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' -VERSION = "13" - - -class Masker(object): - - """ - Data sent from the server must be masked to prevent malicious clients - from sending data over the wire in predictable patterns - - Servers do not have to mask data they send to the client. - https://tools.ietf.org/html/rfc6455#section-5.3 - """ - - def __init__(self, key): - self.key = key - self.offset = 0 - - def mask(self, offset, data): - result = bytearray(data) - if six.PY2: - for i in range(len(data)): - result[i] ^= ord(self.key[offset % 4]) - offset += 1 - result = str(result) - else: - - for i in range(len(data)): - result[i] ^= self.key[offset % 4] - offset += 1 - result = bytes(result) - return result - - def __call__(self, data): - ret = self.mask(self.offset, data) - self.offset += len(ret) - return ret - - -class WebsocketsProtocol(object): - - def __init__(self): - pass - - @classmethod - def client_handshake_headers(self, key=None, version=VERSION): - """ - Create the headers for a valid HTTP upgrade request. If Key is not - specified, it is generated, and can be found in sec-websocket-key in - the returned header set. - - Returns an instance of http.Headers - """ - if not key: - key = base64.b64encode(os.urandom(16)).decode('ascii') - return http.Headers( - sec_websocket_key=key, - sec_websocket_version=version, - connection="Upgrade", - upgrade="websocket", - ) - - @classmethod - def server_handshake_headers(self, key): - """ - The server response is a valid HTTP 101 response. - """ - return http.Headers( - sec_websocket_accept=self.create_server_nonce(key), - connection="Upgrade", - upgrade="websocket" - ) - - @classmethod - def check_client_handshake(self, headers): - if headers.get("upgrade") != "websocket": - return - return headers.get("sec-websocket-key") - - @classmethod - def check_server_handshake(self, headers): - if headers.get("upgrade") != "websocket": - return - return headers.get("sec-websocket-accept") - - @classmethod - def create_server_nonce(self, client_nonce): - return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + websockets_magic).digest()) diff --git a/netlib/websockets/utils.py b/netlib/websockets/utils.py new file mode 100644 index 00000000..aa0d39a1 --- /dev/null +++ b/netlib/websockets/utils.py @@ -0,0 +1,90 @@ +""" +Collection of WebSockets Protocol utility functions (RFC6455) +Spec: https://tools.ietf.org/html/rfc6455 +""" + +from __future__ import absolute_import + +import base64 +import hashlib +import os + +from netlib import http, strutils + +MAGIC = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +VERSION = "13" + + +def client_handshake_headers(version=None, key=None, protocol=None, extensions=None): + """ + Create the headers for a valid HTTP upgrade request. If Key is not + specified, it is generated, and can be found in sec-websocket-key in + the returned header set. + + Returns an instance of http.Headers + """ + if version is None: + version = VERSION + if key is None: + key = base64.b64encode(os.urandom(16)).decode('ascii') + h = http.Headers( + connection="upgrade", + upgrade="websocket", + sec_websocket_version=version, + sec_websocket_key=key, + ) + if protocol is not None: + h['sec-websocket-protocol'] = protocol + if extensions is not None: + h['sec-websocket-extensions'] = extensions + return h + + +def server_handshake_headers(client_key, protocol=None, extensions=None): + """ + The server response is a valid HTTP 101 response. + + Returns an instance of http.Headers + """ + h = http.Headers( + connection="upgrade", + upgrade="websocket", + sec_websocket_accept=create_server_nonce(client_key), + ) + if protocol is not None: + h['sec-websocket-protocol'] = protocol + if extensions is not None: + h['sec-websocket-extensions'] = extensions + return h + + +def check_handshake(headers): + return ( + "upgrade" in headers.get("connection", "").lower() and + headers.get("upgrade", "").lower() == "websocket" and + (headers.get("sec-websocket-key") is not None or headers.get("sec-websocket-accept") is not None) + ) + + +def create_server_nonce(client_nonce): + return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + MAGIC).digest()) + + +def check_client_version(headers): + return headers.get("sec-websocket-version", "") == VERSION + + +def get_extensions(headers): + return headers.get("sec-websocket-extensions", None) + + +def get_protocol(headers): + return headers.get("sec-websocket-protocol", None) + + +def get_client_key(headers): + return headers.get("sec-websocket-key", None) + + +def get_server_accept(headers): + return headers.get("sec-websocket-accept", None) diff --git a/test/netlib/websockets/test_frame.py b/test/netlib/websockets/test_frame.py new file mode 100644 index 00000000..cce39454 --- /dev/null +++ b/test/netlib/websockets/test_frame.py @@ -0,0 +1,164 @@ +import os +import codecs +import pytest + +from netlib import websockets +from netlib import tutils + + +class TestFrameHeader(object): + + @pytest.mark.parametrize("input,expected", [ + (0, '0100'), + (125, '017D'), + (126, '017E007E'), + (127, '017E007F'), + (142, '017E008E'), + (65534, '017EFFFE'), + (65535, '017EFFFF'), + (65536, '017F0000000000010000'), + (8589934591, '017F00000001FFFFFFFF'), + (2 ** 64 - 1, '017FFFFFFFFFFFFFFFFF'), + ]) + def test_serialization_length(self, input, expected): + h = websockets.FrameHeader( + opcode=websockets.OPCODE.TEXT, + payload_length=input, + ) + assert bytes(h) == codecs.decode(expected, 'hex') + + def test_serialization_too_large(self): + h = websockets.FrameHeader( + payload_length=2 ** 64 + 1, + ) + with pytest.raises(ValueError): + bytes(h) + + @pytest.mark.parametrize("input,expected", [ + ('0100', 0), + ('017D', 125), + ('017E007E', 126), + ('017E007F', 127), + ('017E008E', 142), + ('017EFFFE', 65534), + ('017EFFFF', 65535), + ('017F0000000000010000', 65536), + ('017F00000001FFFFFFFF', 8589934591), + ('017FFFFFFFFFFFFFFFFF', 2 ** 64 - 1), + ]) + def test_deserialization_length(self, input, expected): + h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex'))) + assert h.payload_length == expected + + @pytest.mark.parametrize("input,expected", [ + ('0100', (False, None)), + ('018000000000', (True, '00000000')), + ('018012345678', (True, '12345678')), + ]) + def test_deserialization_masking(self, input, expected): + h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex'))) + assert h.mask == expected[0] + if h.mask: + assert h.masking_key == codecs.decode(expected[1], 'hex') + + def test_equality(self): + h = websockets.FrameHeader(mask=True, masking_key=b'1234') + h2 = websockets.FrameHeader(mask=True, masking_key=b'1234') + assert h == h2 + + h = websockets.FrameHeader(fin=True) + h2 = websockets.FrameHeader(fin=False) + assert h != h2 + + assert h != 'foobar' + + def test_roundtrip(self): + def round(*args, **kwargs): + h = websockets.FrameHeader(*args, **kwargs) + h2 = websockets.FrameHeader.from_file(tutils.treader(bytes(h))) + assert h == h2 + + round() + round(fin=True) + round(rsv1=True) + round(rsv2=True) + round(rsv3=True) + round(payload_length=1) + round(payload_length=100) + round(payload_length=1000) + round(payload_length=10000) + round(opcode=websockets.OPCODE.PING) + round(masking_key=b"test") + + def test_human_readable(self): + f = websockets.FrameHeader( + masking_key=b"test", + fin=True, + payload_length=10 + ) + assert repr(f) + + f = websockets.FrameHeader() + assert repr(f) + + def test_funky(self): + f = websockets.FrameHeader(masking_key=b"test", mask=False) + raw = bytes(f) + f2 = websockets.FrameHeader.from_file(tutils.treader(raw)) + assert not f2.mask + + def test_violations(self): + tutils.raises("opcode", websockets.FrameHeader, opcode=17) + tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x") + + def test_automask(self): + f = websockets.FrameHeader(mask=True) + assert f.masking_key + + f = websockets.FrameHeader(masking_key=b"foob") + assert f.mask + + f = websockets.FrameHeader(masking_key=b"foob", mask=0) + assert not f.mask + assert f.masking_key + + +class TestFrame(object): + def test_equality(self): + f = websockets.Frame(payload=b'1234') + f2 = websockets.Frame(payload=b'1234') + assert f == f2 + + assert f != b'1234' + + def test_roundtrip(self): + def round(*args, **kwargs): + f = websockets.Frame(*args, **kwargs) + raw = bytes(f) + f2 = websockets.Frame.from_file(tutils.treader(raw)) + assert f == f2 + round(b"test") + round(b"test", fin=1) + round(b"test", rsv1=1) + round(b"test", opcode=websockets.OPCODE.PING) + round(b"test", masking_key=b"test") + + def test_human_readable(self): + f = websockets.Frame() + assert repr(f) + + f = websockets.Frame(b"foobar") + assert "foobar" in repr(f) + + @pytest.mark.parametrize("masked", [True, False]) + @pytest.mark.parametrize("length", [100, 50000, 150000]) + def test_serialization_bijection(self, masked, length): + frame = websockets.Frame( + os.urandom(length), + fin=True, + opcode=websockets.OPCODE.TEXT, + mask=int(masked), + masking_key=(os.urandom(4) if masked else None) + ) + serialized = bytes(frame) + assert frame == websockets.Frame.from_bytes(serialized) diff --git a/test/netlib/websockets/test_masker.py b/test/netlib/websockets/test_masker.py new file mode 100644 index 00000000..528fce71 --- /dev/null +++ b/test/netlib/websockets/test_masker.py @@ -0,0 +1,23 @@ +import codecs +import pytest + +from netlib import websockets + + +class TestMasker(object): + + @pytest.mark.parametrize("input,expected", [ + ([b"a"], '00'), + ([b"four"], '070d1616'), + ([b"fourf"], '070d161607'), + ([b"fourfive"], '070d1616070b1501'), + ([b"a", b"aasdfasdfa", b"asdf"], '000302170504021705040205120605'), + ([b"a" * 50, b"aasdfasdfa", b"asdf"], '00030205000302050003020500030205000302050003020500030205000302050003020500030205000302050003020500030205120605051206050500110702'), # noqa + ]) + def test_masker(self, input, expected): + m = websockets.Masker(b"abcd") + data = b"".join([m(t) for t in input]) + assert data == codecs.decode(expected, 'hex') + + data = websockets.Masker(b"abcd")(data) + assert data == b"".join(input) diff --git a/test/netlib/websockets/test_utils.py b/test/netlib/websockets/test_utils.py new file mode 100644 index 00000000..34765e04 --- /dev/null +++ b/test/netlib/websockets/test_utils.py @@ -0,0 +1,105 @@ +import pytest + +from netlib import http +from netlib import websockets + + +class TestUtils(object): + + def test_client_handshake_headers(self): + h = websockets.client_handshake_headers(version='42') + assert h['sec-websocket-version'] == '42' + + h = websockets.client_handshake_headers(key='some-key') + assert h['sec-websocket-key'] == 'some-key' + + h = websockets.client_handshake_headers(protocol='foobar') + assert h['sec-websocket-protocol'] == 'foobar' + + h = websockets.client_handshake_headers(extensions='foo; bar') + assert h['sec-websocket-extensions'] == 'foo; bar' + + def test_server_handshake_headers(self): + h = websockets.server_handshake_headers('some-key') + assert h['sec-websocket-accept'] == '8iILEZtcVdtFD7MDlPKip9ec9nw=' + assert 'sec-websocket-protocol' not in h + assert 'sec-websocket-extensions' not in h + + h = websockets.server_handshake_headers('some-key', 'foobar', 'foo; bar') + assert h['sec-websocket-accept'] == '8iILEZtcVdtFD7MDlPKip9ec9nw=' + assert h['sec-websocket-protocol'] == 'foobar' + assert h['sec-websocket-extensions'] == 'foo; bar' + + @pytest.mark.parametrize("input,expected", [ + ([(b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-key', b'foobar')], True), + ([(b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-accept', b'foobar')], True), + ([(b'Connection', b'UpgRaDe'), (b'Upgrade', b'WebSocKeT'), (b'Sec-WebSockeT-KeY', b'foobar')], True), + ([(b'Connection', b'UpgRaDe'), (b'Upgrade', b'WebSocKeT'), (b'Sec-WebSockeT-AccePt', b'foobar')], True), + ([(b'connection', b'foo'), (b'upgrade', b'bar'), (b'sec-websocket-key', b'foobar')], False), + ([(b'connection', b'upgrade'), (b'upgrade', b'websocket')], False), + ([(b'connection', b'upgrade'), (b'sec-websocket-key', b'foobar')], False), + ([(b'upgrade', b'websocket'), (b'sec-websocket-key', b'foobar')], False), + ([], False), + ]) + def test_check_handshake(self, input, expected): + h = http.Headers(input) + assert websockets.check_handshake(h) == expected + + @pytest.mark.parametrize("input,expected", [ + ([(b'sec-websocket-version', b'13')], True), + ([(b'Sec-WebSockeT-VerSion', b'13')], True), + ([(b'sec-websocket-version', b'9')], False), + ([(b'sec-websocket-version', b'42')], False), + ([(b'sec-websocket-version', b'')], False), + ([], False), + ]) + def test_check_client_version(self, input, expected): + h = http.Headers(input) + assert websockets.check_client_version(h) == expected + + @pytest.mark.parametrize("input,expected", [ + ('foobar', b'AzhRPA4TNwR6I/riJheN0TfR7+I='), + (b'foobar', b'AzhRPA4TNwR6I/riJheN0TfR7+I='), + ]) + def test_create_server_nonce(self, input, expected): + assert websockets.create_server_nonce(input) == expected + + @pytest.mark.parametrize("input,expected", [ + ([(b'sec-websocket-extensions', b'foo; bar')], 'foo; bar'), + ([(b'Sec-WebSockeT-ExteNsionS', b'foo; bar')], 'foo; bar'), + ([(b'sec-websocket-extensions', b'')], ''), + ([], None), + ]) + def test_get_extensions(self, input, expected): + h = http.Headers(input) + assert websockets.get_extensions(h) == expected + + @pytest.mark.parametrize("input,expected", [ + ([(b'sec-websocket-protocol', b'foobar')], 'foobar'), + ([(b'Sec-WebSockeT-ProTocoL', b'foobar')], 'foobar'), + ([(b'sec-websocket-protocol', b'')], ''), + ([], None), + ]) + def test_get_protocol(self, input, expected): + h = http.Headers(input) + assert websockets.get_protocol(h) == expected + + @pytest.mark.parametrize("input,expected", [ + ([(b'sec-websocket-key', b'foobar')], 'foobar'), + ([(b'Sec-WebSockeT-KeY', b'foobar')], 'foobar'), + ([(b'sec-websocket-key', b'')], ''), + ([], None), + ]) + def test_get_client_key(self, input, expected): + h = http.Headers(input) + assert websockets.get_client_key(h) == expected + + @pytest.mark.parametrize("input,expected", [ + ([(b'sec-websocket-accept', b'foobar')], 'foobar'), + ([(b'Sec-WebSockeT-AccepT', b'foobar')], 'foobar'), + ([(b'sec-websocket-accept', b'')], ''), + ([], None), + ]) + def test_get_server_accept(self, input, expected): + h = http.Headers(input) + assert websockets.get_server_accept(h) == expected diff --git a/test/netlib/websockets/test_websockets.py b/test/netlib/websockets/test_websockets.py deleted file mode 100644 index 50fa26e6..00000000 --- a/test/netlib/websockets/test_websockets.py +++ /dev/null @@ -1,269 +0,0 @@ -import os - -from netlib.http.http1 import read_response, read_request - -from netlib import tcp -from netlib import tutils -from netlib import websockets -from netlib.http import status_codes -from netlib.tutils import treq -from netlib import exceptions - -from .. import tservers - - -class WebSocketsEchoHandler(tcp.BaseHandler): - - def __init__(self, connection, address, server): - super(WebSocketsEchoHandler, self).__init__( - connection, address, server - ) - self.protocol = websockets.WebsocketsProtocol() - self.handshake_done = False - - def handle(self): - while True: - if not self.handshake_done: - self.handshake() - else: - self.read_next_message() - - def read_next_message(self): - frame = websockets.Frame.from_file(self.rfile) - self.on_message(frame.payload) - - def send_message(self, message): - frame = websockets.Frame.default(message, from_client=False) - frame.to_file(self.wfile) - - def handshake(self): - - req = read_request(self.rfile) - key = self.protocol.check_client_handshake(req.headers) - - preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) - self.wfile.write(preamble.encode() + b"\r\n") - headers = self.protocol.server_handshake_headers(key) - self.wfile.write(str(headers) + "\r\n") - self.wfile.flush() - self.handshake_done = True - - def on_message(self, message): - if message is not None: - self.send_message(message) - - -class WebSocketsClient(tcp.TCPClient): - - def __init__(self, address, source_address=None): - super(WebSocketsClient, self).__init__(address, source_address) - self.protocol = websockets.WebsocketsProtocol() - self.client_nonce = None - - def connect(self): - super(WebSocketsClient, self).connect() - - preamble = b'GET / HTTP/1.1' - self.wfile.write(preamble + b"\r\n") - headers = self.protocol.client_handshake_headers() - self.client_nonce = headers["sec-websocket-key"].encode("ascii") - self.wfile.write(bytes(headers) + b"\r\n") - self.wfile.flush() - - resp = read_response(self.rfile, treq(method=b"GET")) - server_nonce = self.protocol.check_server_handshake(resp.headers) - - if not server_nonce == self.protocol.create_server_nonce(self.client_nonce): - self.close() - - def read_next_message(self): - return websockets.Frame.from_file(self.rfile).payload - - def send_message(self, message): - frame = websockets.Frame.default(message, from_client=True) - frame.to_file(self.wfile) - - -class TestWebSockets(tservers.ServerTestBase): - handler = WebSocketsEchoHandler - - def __init__(self): - self.protocol = websockets.WebsocketsProtocol() - - def random_bytes(self, n=100): - return os.urandom(n) - - def echo(self, msg): - client = WebSocketsClient(("127.0.0.1", self.port)) - client.connect() - client.send_message(msg) - response = client.read_next_message() - assert response == msg - - def test_simple_echo(self): - self.echo(b"hello I'm the client") - - def test_frame_sizes(self): - # length can fit in the the 7 bit payload length - small_msg = self.random_bytes(100) - # 50kb, sligthly larger than can fit in a 7 bit int - medium_msg = self.random_bytes(50000) - # 150kb, slightly larger than can fit in a 16 bit int - large_msg = self.random_bytes(150000) - - self.echo(small_msg) - self.echo(medium_msg) - self.echo(large_msg) - - def test_default_builder(self): - """ - default builder should always generate valid frames - """ - msg = self.random_bytes() - assert websockets.Frame.default(msg, from_client=True) - assert websockets.Frame.default(msg, from_client=False) - - def test_serialization_bijection(self): - """ - Ensure that various frame types can be serialized/deserialized back - and forth between to_bytes() and from_bytes() - """ - for is_client in [True, False]: - for num_bytes in [100, 50000, 150000]: - frame = websockets.Frame.default( - self.random_bytes(num_bytes), is_client - ) - frame2 = websockets.Frame.from_bytes( - frame.to_bytes() - ) - assert frame == frame2 - - bytes = b'\x81\x03cba' - assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes - - def test_check_server_handshake(self): - headers = self.protocol.server_handshake_headers("key") - assert self.protocol.check_server_handshake(headers) - headers["Upgrade"] = "not_websocket" - assert not self.protocol.check_server_handshake(headers) - - def test_check_client_handshake(self): - headers = self.protocol.client_handshake_headers("key") - assert self.protocol.check_client_handshake(headers) == "key" - headers["Upgrade"] = "not_websocket" - assert not self.protocol.check_client_handshake(headers) - - -class BadHandshakeHandler(WebSocketsEchoHandler): - - def handshake(self): - - client_hs = read_request(self.rfile) - self.protocol.check_client_handshake(client_hs.headers) - - preamble = 'HTTP/1.1 101 %s\r\n' % status_codes.RESPONSES.get(101) - self.wfile.write(preamble.encode()) - headers = self.protocol.server_handshake_headers(b"malformed key") - self.wfile.write(bytes(headers) + b"\r\n") - self.wfile.flush() - self.handshake_done = True - - -class TestBadHandshake(tservers.ServerTestBase): - - """ - Ensure that the client disconnects if the server handshake is malformed - """ - handler = BadHandshakeHandler - - def test(self): - with tutils.raises(exceptions.TcpDisconnect): - client = WebSocketsClient(("127.0.0.1", self.port)) - client.connect() - client.send_message(b"hello") - - -class TestFrameHeader: - - def test_roundtrip(self): - def round(*args, **kwargs): - f = websockets.FrameHeader(*args, **kwargs) - f2 = websockets.FrameHeader.from_file(tutils.treader(bytes(f))) - assert f == f2 - round() - round(fin=1) - round(rsv1=1) - round(rsv2=1) - round(rsv3=1) - round(payload_length=1) - round(payload_length=100) - round(payload_length=1000) - round(payload_length=10000) - round(opcode=websockets.OPCODE.PING) - round(masking_key=b"test") - - def test_human_readable(self): - f = websockets.FrameHeader( - masking_key=b"test", - fin=True, - payload_length=10 - ) - assert repr(f) - f = websockets.FrameHeader() - assert repr(f) - - def test_funky(self): - f = websockets.FrameHeader(masking_key=b"test", mask=False) - raw = bytes(f) - f2 = websockets.FrameHeader.from_file(tutils.treader(raw)) - assert not f2.mask - - def test_violations(self): - tutils.raises("opcode", websockets.FrameHeader, opcode=17) - tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x") - - def test_automask(self): - f = websockets.FrameHeader(mask=True) - assert f.masking_key - - f = websockets.FrameHeader(masking_key=b"foob") - assert f.mask - - f = websockets.FrameHeader(masking_key=b"foob", mask=0) - assert not f.mask - assert f.masking_key - - -class TestFrame: - - def test_roundtrip(self): - def round(*args, **kwargs): - f = websockets.Frame(*args, **kwargs) - raw = bytes(f) - f2 = websockets.Frame.from_file(tutils.treader(raw)) - assert f == f2 - round(b"test") - round(b"test", fin=1) - round(b"test", rsv1=1) - round(b"test", opcode=websockets.OPCODE.PING) - round(b"test", masking_key=b"test") - - def test_human_readable(self): - f = websockets.Frame() - assert repr(f) - - -def test_masker(): - tests = [ - [b"a"], - [b"four"], - [b"fourf"], - [b"fourfive"], - [b"a", b"aasdfasdfa", b"asdf"], - [b"a" * 50, b"aasdfasdfa", b"asdf"], - ] - for i in tests: - m = websockets.Masker(b"abcd") - data = b"".join([m(t) for t in i]) - data2 = websockets.Masker(b"abcd")(data) - assert data2 == b"".join(i) -- cgit v1.2.3 From e5b0dae7e9ef8d2ce62fc263c377c76546190825 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Tue, 16 Aug 2016 18:31:50 +0200 Subject: add websockets support to mitmproxy --- mitmproxy/protocol/http.py | 20 +- mitmproxy/protocol/websockets.py | 140 ++++++++++++++ pathod/language/http.py | 4 +- pathod/pathoc.py | 2 +- pathod/pathod.py | 9 +- pathod/protocols/websockets.py | 2 +- test/mitmproxy/protocol/test_websockets.py | 297 +++++++++++++++++++++++++++++ 7 files changed, 465 insertions(+), 9 deletions(-) create mode 100644 mitmproxy/protocol/websockets.py create mode 100644 test/mitmproxy/protocol/test_websockets.py diff --git a/mitmproxy/protocol/http.py b/mitmproxy/protocol/http.py index d81fc8ca..fbb52c92 100644 --- a/mitmproxy/protocol/http.py +++ b/mitmproxy/protocol/http.py @@ -7,12 +7,15 @@ import traceback import h2.exceptions import six -import netlib.exceptions from mitmproxy import exceptions from mitmproxy import models from mitmproxy.protocol import base +from .websockets import WebSocketsLayer + +import netlib.exceptions from netlib import http from netlib import tcp +from netlib import websockets class _HttpTransmissionLayer(base.Layer): @@ -189,6 +192,21 @@ class HttpLayer(base.Layer): self.process_request_hook(flow) try: + # WebSockets + if websockets.check_handshake(request.headers): + if websockets.check_client_version(request.headers): + layer = WebSocketsLayer(self, request) + layer() + return + else: + # we only support RFC6455 with WebSockets version 13 + self.send_response(models.make_error_response( + 400, + http.status_codes.RESPONSES.get(400), + http.Headers(sec_websocket_version="13") + )) + return + if not flow.response: self.establish_server_connection( flow.request.host, diff --git a/mitmproxy/protocol/websockets.py b/mitmproxy/protocol/websockets.py new file mode 100644 index 00000000..05eaa537 --- /dev/null +++ b/mitmproxy/protocol/websockets.py @@ -0,0 +1,140 @@ +from __future__ import absolute_import, print_function, division + +import socket +import struct + +from OpenSSL import SSL + +from mitmproxy import exceptions +from mitmproxy import models +from mitmproxy.protocol import base + +import netlib.exceptions +from netlib import tcp +from netlib import http +from netlib import websockets + + +class WebSocketsLayer(base.Layer): + """ + WebSockets layer to intercept, modify, and forward WebSockets connections + + Only version 13 is supported (as specified in RFC6455) + Only HTTP/1.1-initiated connections are supported. + + The client starts by sending an Upgrade-request. + In order to determine the handshake and negotiate the correct protocol + and extensions, the Upgrade-request is forwarded to the server. + The response from the server is then parsed and negotiated settings are extracted. + Finally the handshake is completed by forwarding the server-response to the client. + After that, only WebSockets frames are exchanged. + + PING/PONG frames pass through and must be answered by the other endpoint. + + CLOSE frames are forwarded before this WebSocketsLayer terminates. + + This layer is transparent to any negotiated extensions. + This layer is transparent to any negotiated subprotocols. + Only raw frames are forwarded to the other endpoint. + """ + + def __init__(self, ctx, request): + super(WebSocketsLayer, self).__init__(ctx) + self._request = request + + self.client_key = websockets.get_client_key(self._request.headers) + self.client_protocol = websockets.get_protocol(self._request.headers) + self.client_extensions = websockets.get_extensions(self._request.headers) + + self.server_accept = None + self.server_protocol = None + self.server_extensions = None + + def _initiate_server_conn(self): + self.establish_server_connection( + self._request.host, + self._request.port, + self._request.scheme, + ) + + self.server_conn.send(netlib.http.http1.assemble_request(self._request)) + response = netlib.http.http1.read_response(self.server_conn.rfile, self._request, body_size_limit=None) + + if not websockets.check_handshake(response.headers): + raise exceptions.ProtocolException("Establishing WebSockets connection with server failed: {}".format(response.headers)) + + self.server_accept = websockets.get_server_accept(response.headers) + self.server_protocol = websockets.get_protocol(response.headers) + self.server_extensions = websockets.get_extensions(response.headers) + + def _complete_handshake(self): + headers = websockets.server_handshake_headers(self.client_key, self.server_protocol, self.server_extensions) + self.send_response(models.HTTPResponse( + self._request.http_version, + 101, + http.status_codes.RESPONSES.get(101), + headers, + b"", + )) + + def _handle_frame(self, frame, source_conn, other_conn, is_server): + self.log( + "WebSockets Frame received from {}".format("server" if is_server else "client"), + "debug", + [repr(frame)] + ) + + if frame.header.opcode & 0x8 == 0: + # forward the data frame to the other side + other_conn.send(bytes(frame)) + self.log("WebSockets frame received by {}: {}".format(is_server, frame), "debug") + elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG): + # just forward the ping/pong to the other side + other_conn.send(bytes(frame)) + elif frame.header.opcode == websockets.OPCODE.CLOSE: + other_conn.send(bytes(frame)) + + code = '(status code missing)' + msg = None + reason = '(message missing)' + if len(frame.payload) >= 2: + code, = struct.unpack('!H', frame.payload[:2]) + msg = websockets.CLOSE_REASON.get_name(code, default='unknown status code') + if len(frame.payload) > 2: + reason = frame.payload[2:] + self.log("WebSockets connection closed: {} {}, {}".format(code, msg, reason), "info") + + # close the connection + return False + else: + # unknown frame - just forward it + other_conn.send(bytes(frame)) + + # continue the connection + return True + + def __call__(self): + self._initiate_server_conn() + self._complete_handshake() + + client = self.client_conn.connection + server = self.server_conn.connection + conns = [client, server] + + try: + while not self.channel.should_exit.is_set(): + r = tcp.ssl_read_select(conns, 1) + for conn in r: + source_conn = self.client_conn if conn == client else self.server_conn + other_conn = self.server_conn if conn == client else self.client_conn + is_server = (conn == self.server_conn.connection) + + frame = websockets.Frame.from_file(source_conn.rfile) + + if not self._handle_frame(frame, source_conn, other_conn, is_server): + return + except (socket.error, netlib.exceptions.TcpException, SSL.Error) as e: + self.log("WebSockets connection closed unexpectedly by {}: {}".format( + "server" if is_server else "client", repr(e)), "info") + except Exception as e: # pragma: no cover + raise exceptions.ProtocolException("Error in WebSockets connection: {}".format(repr(e))) diff --git a/pathod/language/http.py b/pathod/language/http.py index fdc5bba6..46027ca3 100644 --- a/pathod/language/http.py +++ b/pathod/language/http.py @@ -198,7 +198,7 @@ class Response(_HTTPMessage): 1, StatusCode(101) ) - headers = netlib.websockets.WebsocketsProtocol.server_handshake_headers( + headers = netlib.websockets.server_handshake_headers( settings.websocket_key ) for i in headers.fields: @@ -310,7 +310,7 @@ class Request(_HTTPMessage): 1, Method("get") ) - for i in netlib.websockets.WebsocketsProtocol.client_handshake_headers().fields: + for i in netlib.websockets.client_handshake_headers().fields: if not get_header(i[0], self.headers): tokens.append( Header( diff --git a/pathod/pathoc.py b/pathod/pathoc.py index 5831ba3e..a8923013 100644 --- a/pathod/pathoc.py +++ b/pathod/pathoc.py @@ -139,7 +139,7 @@ class WebsocketFrameReader(basethread.BaseThread): except exceptions.TcpDisconnect: return self.frames_queue.put(frm) - log("<< %s" % frm.header.human_readable()) + log("<< %s" % repr(frm.header)) if self.ws_read_limit is not None: self.ws_read_limit -= 1 starttime = time.time() diff --git a/pathod/pathod.py b/pathod/pathod.py index 7087cba6..bd0feb73 100644 --- a/pathod/pathod.py +++ b/pathod/pathod.py @@ -173,12 +173,13 @@ class PathodHandler(tcp.BaseHandler): retlog["cipher"] = self.get_current_cipher() m = utils.MemBool() - websocket_key = websockets.WebsocketsProtocol.check_client_handshake(headers) - self.settings.websocket_key = websocket_key + + valid_websockets_handshake = websockets.check_handshake(headers) + self.settings.websocket_key = websockets.get_client_key(headers) # If this is a websocket initiation, we respond with a proper # server response, unless over-ridden. - if websocket_key: + if valid_websockets_handshake: anchor_gen = language.parse_pathod("ws") else: anchor_gen = None @@ -225,7 +226,7 @@ class PathodHandler(tcp.BaseHandler): spec, lg ) - if nexthandler and websocket_key: + if nexthandler and valid_websockets_handshake: self.protocol = protocols.websockets.WebsocketsProtocol(self) return self.protocol.handle_websocket, retlog else: diff --git a/pathod/protocols/websockets.py b/pathod/protocols/websockets.py index a34e75e8..df83461a 100644 --- a/pathod/protocols/websockets.py +++ b/pathod/protocols/websockets.py @@ -20,7 +20,7 @@ class WebsocketsProtocol: lg("Error reading websocket frame: %s" % e) return None, None ended = time.time() - lg(frm.human_readable()) + lg(repr(frm)) retlog = dict( type="inbound", protocol="websockets", diff --git a/test/mitmproxy/protocol/test_websockets.py b/test/mitmproxy/protocol/test_websockets.py new file mode 100644 index 00000000..cc478c0b --- /dev/null +++ b/test/mitmproxy/protocol/test_websockets.py @@ -0,0 +1,297 @@ +import pytest +import os +import tempfile +import traceback + +from mitmproxy import options +from mitmproxy.proxy.config import ProxyConfig + +import netlib +from netlib import http +from ...netlib import tservers as netlib_tservers +from .. import tservers + +from netlib import websockets + + +class _WebSocketsServerBase(netlib_tservers.ServerTestBase): + + class handler(netlib.tcp.BaseHandler): + + def handle(self): + try: + request = http.http1.read_request(self.rfile) + assert websockets.check_handshake(request.headers) + + response = http.Response( + "HTTP/1.1", + 101, + reason=http.status_codes.RESPONSES.get(101), + headers=http.Headers( + connection='upgrade', + upgrade='websocket', + sec_websocket_accept=b'', + ), + content=b'', + ) + self.wfile.write(http.http1.assemble_response(response)) + self.wfile.flush() + + self.server.handle_websockets(self.rfile, self.wfile) + except: + traceback.print_exc() + + +class _WebSocketsTestBase(object): + + @classmethod + def setup_class(cls): + opts = cls.get_options() + cls.config = ProxyConfig(opts) + + tmaster = tservers.TestMaster(opts, cls.config) + tmaster.start_app(options.APP_HOST, options.APP_PORT) + cls.proxy = tservers.ProxyThread(tmaster) + cls.proxy.start() + + @classmethod + def teardown_class(cls): + cls.proxy.shutdown() + + @classmethod + def get_options(cls): + opts = options.Options( + listen_port=0, + no_upstream_cert=False, + ssl_insecure=True + ) + opts.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy") + return opts + + @property + def master(self): + return self.proxy.tmaster + + def setup(self): + self.master.clear_log() + self.master.state.clear() + self.server.server.handle_websockets = self.handle_websockets + + def _setup_connection(self): + client = netlib.tcp.TCPClient(("127.0.0.1", self.proxy.port)) + client.connect() + + request = http.Request( + "authority", + "CONNECT", + "", + "localhost", + self.server.server.address.port, + "", + "HTTP/1.1", + content=b'') + client.wfile.write(http.http1.assemble_request(request)) + client.wfile.flush() + + response = http.http1.read_response(client.rfile, request) + + if self.ssl: + client.convert_to_ssl() + assert client.ssl_established + + request = http.Request( + "relative", + "GET", + "http", + "localhost", + self.server.server.address.port, + "/ws", + "HTTP/1.1", + headers=http.Headers( + connection="upgrade", + upgrade="websocket", + sec_websocket_version="13", + sec_websocket_key="1234", + ), + content=b'') + client.wfile.write(http.http1.assemble_request(request)) + client.wfile.flush() + + response = http.http1.read_response(client.rfile, request) + assert websockets.check_handshake(response.headers) + + return client + + +class _WebSocketsTest(_WebSocketsTestBase, _WebSocketsServerBase): + + @classmethod + def setup_class(cls): + _WebSocketsTestBase.setup_class() + _WebSocketsServerBase.setup_class(ssl=cls.ssl) + + @classmethod + def teardown_class(cls): + _WebSocketsTestBase.teardown_class() + _WebSocketsServerBase.teardown_class() + + +class TestSimple(_WebSocketsTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar'))) + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + wfile.write(bytes(frame)) + wfile.flush() + + def test_simple(self): + client = self._setup_connection() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'server-foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar'))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'client-foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + client.wfile.flush() + + +class TestSimpleTLS(_WebSocketsTest): + ssl = True + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar'))) + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + wfile.write(bytes(frame)) + wfile.flush() + + def test_simple_tls(self): + client = self._setup_connection() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'server-foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar'))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'client-foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + client.wfile.flush() + + +class TestPing(_WebSocketsTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + assert frame.header.opcode == websockets.OPCODE.PONG + assert frame.payload == b'foobar' + + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received'))) + wfile.flush() + + def test_ping(self): + client = self._setup_connection() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.header.opcode == websockets.OPCODE.PING + assert frame.payload == b'foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.header.opcode == websockets.OPCODE.TEXT + assert frame.payload == b'pong-received' + + +class TestPong(_WebSocketsTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + frame = websockets.Frame.from_file(rfile) + assert frame.header.opcode == websockets.OPCODE.PING + assert frame.payload == b'foobar' + + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) + wfile.flush() + + def test_pong(self): + client = self._setup_connection() + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.header.opcode == websockets.OPCODE.PONG + assert frame.payload == b'foobar' + + +class TestClose(_WebSocketsTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + frame = websockets.Frame.from_file(rfile) + wfile.write(bytes(frame)) + wfile.flush() + + with pytest.raises(netlib.exceptions.TcpDisconnect): + websockets.Frame.from_file(rfile) + + def test_close(self): + client = self._setup_connection() + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + client.wfile.flush() + + with pytest.raises(netlib.exceptions.TcpDisconnect): + websockets.Frame.from_file(client.rfile) + + def test_close_payload_1(self): + client = self._setup_connection() + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) + client.wfile.flush() + + with pytest.raises(netlib.exceptions.TcpDisconnect): + websockets.Frame.from_file(client.rfile) + + def test_close_payload_2(self): + client = self._setup_connection() + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) + client.wfile.flush() + + with pytest.raises(netlib.exceptions.TcpDisconnect): + websockets.Frame.from_file(client.rfile) + + +class TestInvalidFrame(_WebSocketsTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=15, payload=b'foobar'))) + wfile.flush() + + def test_invalid_frame(self): + client = self._setup_connection() + + # with pytest.raises(netlib.exceptions.TcpDisconnect): + frame = websockets.Frame.from_file(client.rfile) + assert frame.header.opcode == 15 + assert frame.payload == b'foobar' -- cgit v1.2.3 From cd3d30633fae965044d5f320b5544dfbd039693f Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 21 Aug 2016 11:12:41 +0200 Subject: websockets: update protocol detection --- mitmproxy/controller.py | 2 ++ mitmproxy/flow/master.py | 4 ++++ mitmproxy/protocol/__init__.py | 4 ++++ mitmproxy/protocol/http.py | 21 +++++------------- mitmproxy/protocol/websockets.py | 48 +++++++--------------------------------- mitmproxy/proxy/root_context.py | 19 +++++++++++++--- 6 files changed, 39 insertions(+), 59 deletions(-) diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py index c262b192..d886af97 100644 --- a/mitmproxy/controller.py +++ b/mitmproxy/controller.py @@ -28,6 +28,8 @@ Events = frozenset([ "response", "responseheaders", + "websockets_handshake", + "next_layer", "error", diff --git a/mitmproxy/flow/master.py b/mitmproxy/flow/master.py index 0475ef4e..9cdcc8dd 100644 --- a/mitmproxy/flow/master.py +++ b/mitmproxy/flow/master.py @@ -334,6 +334,10 @@ class FlowMaster(controller.Master): self.client_playback.clear(f) return f + @controller.handler + def websockets_handshake(self, f): + return f + def handle_intercept(self, f): self.state.update_flow(f) diff --git a/mitmproxy/protocol/__init__.py b/mitmproxy/protocol/__init__.py index 510cd195..b99b55bd 100644 --- a/mitmproxy/protocol/__init__.py +++ b/mitmproxy/protocol/__init__.py @@ -29,8 +29,10 @@ from __future__ import absolute_import, print_function, division from .base import Layer, ServerConnectionMixin from .http import UpstreamConnectLayer +from .http import HttpLayer from .http1 import Http1Layer from .http2 import Http2Layer +from .websockets import WebSocketsLayer from .rawtcp import RawTCPLayer from .tls import TlsClientHello from .tls import TlsLayer @@ -40,7 +42,9 @@ __all__ = [ "Layer", "ServerConnectionMixin", "TlsLayer", "is_tls_record_magic", "TlsClientHello", "UpstreamConnectLayer", + "HttpLayer", "Http1Layer", "Http2Layer", + "WebSocketsLayer", "RawTCPLayer", ] diff --git a/mitmproxy/protocol/http.py b/mitmproxy/protocol/http.py index fbb52c92..1418d6e9 100644 --- a/mitmproxy/protocol/http.py +++ b/mitmproxy/protocol/http.py @@ -10,7 +10,6 @@ import six from mitmproxy import exceptions from mitmproxy import models from mitmproxy.protocol import base -from .websockets import WebSocketsLayer import netlib.exceptions from netlib import http @@ -192,20 +191,10 @@ class HttpLayer(base.Layer): self.process_request_hook(flow) try: - # WebSockets - if websockets.check_handshake(request.headers): - if websockets.check_client_version(request.headers): - layer = WebSocketsLayer(self, request) - layer() - return - else: - # we only support RFC6455 with WebSockets version 13 - self.send_response(models.make_error_response( - 400, - http.status_codes.RESPONSES.get(400), - http.Headers(sec_websocket_version="13") - )) - return + if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers): + # we only support RFC6455 with WebSockets version 13 + # allow inline scripts to manupulate the client handshake + self.channel.ask("websockets_handshake", flow) if not flow.response: self.establish_server_connection( @@ -230,7 +219,7 @@ class HttpLayer(base.Layer): # It may be useful to pass additional args (such as the upgrade header) # to next_layer in the future if flow.response.status_code == 101: - layer = self.ctx.next_layer(self) + layer = self.ctx.next_layer(self, flow) layer() return diff --git a/mitmproxy/protocol/websockets.py b/mitmproxy/protocol/websockets.py index 05eaa537..f15a38ef 100644 --- a/mitmproxy/protocol/websockets.py +++ b/mitmproxy/protocol/websockets.py @@ -6,12 +6,10 @@ import struct from OpenSSL import SSL from mitmproxy import exceptions -from mitmproxy import models from mitmproxy.protocol import base import netlib.exceptions from netlib import tcp -from netlib import http from netlib import websockets @@ -38,44 +36,17 @@ class WebSocketsLayer(base.Layer): Only raw frames are forwarded to the other endpoint. """ - def __init__(self, ctx, request): + def __init__(self, ctx, flow): super(WebSocketsLayer, self).__init__(ctx) - self._request = request + self._flow = flow - self.client_key = websockets.get_client_key(self._request.headers) - self.client_protocol = websockets.get_protocol(self._request.headers) - self.client_extensions = websockets.get_extensions(self._request.headers) + self.client_key = websockets.get_client_key(self._flow.request.headers) + self.client_protocol = websockets.get_protocol(self._flow.request.headers) + self.client_extensions = websockets.get_extensions(self._flow.request.headers) - self.server_accept = None - self.server_protocol = None - self.server_extensions = None - - def _initiate_server_conn(self): - self.establish_server_connection( - self._request.host, - self._request.port, - self._request.scheme, - ) - - self.server_conn.send(netlib.http.http1.assemble_request(self._request)) - response = netlib.http.http1.read_response(self.server_conn.rfile, self._request, body_size_limit=None) - - if not websockets.check_handshake(response.headers): - raise exceptions.ProtocolException("Establishing WebSockets connection with server failed: {}".format(response.headers)) - - self.server_accept = websockets.get_server_accept(response.headers) - self.server_protocol = websockets.get_protocol(response.headers) - self.server_extensions = websockets.get_extensions(response.headers) - - def _complete_handshake(self): - headers = websockets.server_handshake_headers(self.client_key, self.server_protocol, self.server_extensions) - self.send_response(models.HTTPResponse( - self._request.http_version, - 101, - http.status_codes.RESPONSES.get(101), - headers, - b"", - )) + self.server_accept = websockets.get_server_accept(self._flow.response.headers) + self.server_protocol = websockets.get_protocol(self._flow.response.headers) + self.server_extensions = websockets.get_extensions(self._flow.response.headers) def _handle_frame(self, frame, source_conn, other_conn, is_server): self.log( @@ -114,9 +85,6 @@ class WebSocketsLayer(base.Layer): return True def __call__(self): - self._initiate_server_conn() - self._complete_handshake() - client = self.client_conn.connection server = self.server_conn.connection conns = [client, server] diff --git a/mitmproxy/proxy/root_context.py b/mitmproxy/proxy/root_context.py index 81dd625c..95611362 100644 --- a/mitmproxy/proxy/root_context.py +++ b/mitmproxy/proxy/root_context.py @@ -4,6 +4,7 @@ import sys import six +from netlib import websockets import netlib.exceptions from mitmproxy import exceptions from mitmproxy import protocol @@ -32,7 +33,7 @@ class RootContext(object): self.channel = channel self.config = config - def next_layer(self, top_layer): + def next_layer(self, top_layer, flow=None): """ This function determines the next layer in the protocol stack. @@ -42,10 +43,22 @@ class RootContext(object): Returns: The next layer """ - layer = self._next_layer(top_layer) + layer = self._next_layer(top_layer, flow) return self.channel.ask("next_layer", layer) - def _next_layer(self, top_layer): + def _next_layer(self, top_layer, flow): + if flow is not None: + # We already have a flow, try to derive the next information from it + + # Check for WebSockets handshake + is_websockets = ( + flow and + websockets.check_handshake(flow.request.headers) and + websockets.check_handshake(flow.response.headers) + ) + if isinstance(top_layer, protocol.HttpLayer) and is_websockets: + return protocol.WebSocketsLayer(top_layer, flow) + try: d = top_layer.client_conn.rfile.peek(3) except netlib.exceptions.TcpException as e: -- cgit v1.2.3 From 823d8a5da89755d9f9d85f16434ddad26fd0ac6d Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 21 Aug 2016 11:29:27 +0200 Subject: add docs for websocket_handshake hook --- docs/scripting/inlinescripts.rst | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/scripting/inlinescripts.rst b/docs/scripting/inlinescripts.rst index bc9d5ff5..e1c01b17 100644 --- a/docs/scripting/inlinescripts.rst +++ b/docs/scripting/inlinescripts.rst @@ -126,6 +126,18 @@ HTTP Events :param HTTPFlow flow: The flow containing the error. It is guaranteed to have non-None ``error`` attribute. +WebSockets Events +^^^^^^^^^^^^^^^^^ + +.. py:function:: websockets_handshake(context, flow) + + Called when a client wants to establish a WebSockets connection. + The WebSockets-specific headers can be manipulated to manipulate the handshake. + The ``flow`` object is guaranteed to have a non-None ``request`` attribute. + + :param HTTPFlow flow: The flow containing the request which has been received. + The object is guaranteed to have a non-None ``request`` attribute. + TCP Events ^^^^^^^^^^ -- cgit v1.2.3 From 0d0c2c788df4b60e951e6fcc13b479de8cec22c1 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 1 Sep 2016 10:06:51 +0200 Subject: cleanup tests --- test/mitmproxy/protocol/__init__.py | 0 test/mitmproxy/protocol/test_http1.py | 65 ++ test/mitmproxy/protocol/test_http2.py | 925 +++++++++++++++++++++++++++++ test/mitmproxy/protocol/test_websockets.py | 2 + test/mitmproxy/test_protocol_http1.py | 63 -- test/mitmproxy/test_protocol_http2.py | 925 ----------------------------- 6 files changed, 992 insertions(+), 988 deletions(-) create mode 100644 test/mitmproxy/protocol/__init__.py create mode 100644 test/mitmproxy/protocol/test_http1.py create mode 100644 test/mitmproxy/protocol/test_http2.py delete mode 100644 test/mitmproxy/test_protocol_http1.py delete mode 100644 test/mitmproxy/test_protocol_http2.py diff --git a/test/mitmproxy/protocol/__init__.py b/test/mitmproxy/protocol/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/mitmproxy/protocol/test_http1.py b/test/mitmproxy/protocol/test_http1.py new file mode 100644 index 00000000..7d04c56b --- /dev/null +++ b/test/mitmproxy/protocol/test_http1.py @@ -0,0 +1,65 @@ +from __future__ import (absolute_import, print_function, division) + +from netlib.http import http1 +from netlib.tcp import TCPClient +from netlib.tutils import treq +from .. import tutils, tservers + + +class TestHTTPFlow(object): + + def test_repr(self): + f = tutils.tflow(resp=True, err=True) + assert repr(f) + + +class TestInvalidRequests(tservers.HTTPProxyTest): + ssl = True + + def test_double_connect(self): + p = self.pathoc() + r = p.request("connect:'%s:%s'" % ("127.0.0.1", self.server2.port)) + assert r.status_code == 400 + assert b"Invalid HTTP request form" in r.content + + def test_relative_request(self): + p = self.pathoc_raw() + p.connect() + r = p.request("get:/p/200") + assert r.status_code == 400 + assert b"Invalid HTTP request form" in r.content + + +class TestExpectHeader(tservers.HTTPProxyTest): + + def test_simple(self): + client = TCPClient(("127.0.0.1", self.proxy.port)) + client.connect() + + # call pathod server, wait a second to complete the request + client.wfile.write( + b"POST http://localhost:%d/p/200 HTTP/1.1\r\n" + b"Expect: 100-continue\r\n" + b"Content-Length: 16\r\n" + b"\r\n" % self.server.port + ) + client.wfile.flush() + + assert client.rfile.readline() == b"HTTP/1.1 100 Continue\r\n" + assert client.rfile.readline() == b"\r\n" + + client.wfile.write(b"0123456789abcdef\r\n") + client.wfile.flush() + + resp = http1.read_response(client.rfile, treq()) + assert resp.status_code == 200 + + client.finish() + + +class TestHeadContentLength(tservers.HTTPProxyTest): + + def test_head_content_length(self): + p = self.pathoc() + resp = p.request("""head:'%s/p/200:h"Content-Length"="42"'""" % self.server.urlbase) + assert resp.headers["Content-Length"] == "42" diff --git a/test/mitmproxy/protocol/test_http2.py b/test/mitmproxy/protocol/test_http2.py new file mode 100644 index 00000000..1eabebf1 --- /dev/null +++ b/test/mitmproxy/protocol/test_http2.py @@ -0,0 +1,925 @@ +# coding=utf-8 + +from __future__ import (absolute_import, print_function, division) + +import pytest +import os +import tempfile +import traceback + +import h2 + +from mitmproxy import options +from mitmproxy.proxy.config import ProxyConfig + +import netlib +from ...netlib import tservers as netlib_tservers +from netlib.exceptions import HttpException +from netlib.http.http2 import framereader + +from .. import tservers + +import logging +logging.getLogger("hyper.packages.hpack.hpack").setLevel(logging.WARNING) +logging.getLogger("requests.packages.urllib3.connectionpool").setLevel(logging.WARNING) +logging.getLogger("passlib.utils.compat").setLevel(logging.WARNING) +logging.getLogger("passlib.registry").setLevel(logging.WARNING) +logging.getLogger("PIL.Image").setLevel(logging.WARNING) +logging.getLogger("PIL.PngImagePlugin").setLevel(logging.WARNING) + + +requires_alpn = pytest.mark.skipif( + not netlib.tcp.HAS_ALPN, + reason='requires OpenSSL with ALPN support') + + +class _Http2ServerBase(netlib_tservers.ServerTestBase): + ssl = dict(alpn_select=b'h2') + + class handler(netlib.tcp.BaseHandler): + + def handle(self): + h2_conn = h2.connection.H2Connection(client_side=False, header_encoding=False) + + preamble = self.rfile.read(24) + h2_conn.initiate_connection() + h2_conn.receive_data(preamble) + self.wfile.write(h2_conn.data_to_send()) + self.wfile.flush() + + if 'h2_server_settings' in self.kwargs: + h2_conn.update_settings(self.kwargs['h2_server_settings']) + self.wfile.write(h2_conn.data_to_send()) + self.wfile.flush() + + done = False + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(self.rfile)) + events = h2_conn.receive_data(raw) + except HttpException: + print(traceback.format_exc()) + assert False + except netlib.exceptions.TcpDisconnect: + break + except: + print(traceback.format_exc()) + break + self.wfile.write(h2_conn.data_to_send()) + self.wfile.flush() + + for event in events: + try: + if not self.server.handle_server_event(event, h2_conn, self.rfile, self.wfile): + done = True + break + except netlib.exceptions.TcpDisconnect: + done = True + except: + done = True + print(traceback.format_exc()) + break + + def handle_server_event(self, event, h2_conn, rfile, wfile): + raise NotImplementedError() + + +class _Http2TestBase(object): + + @classmethod + def setup_class(cls): + opts = cls.get_options() + cls.config = ProxyConfig(opts) + + tmaster = tservers.TestMaster(opts, cls.config) + tmaster.start_app(options.APP_HOST, options.APP_PORT) + cls.proxy = tservers.ProxyThread(tmaster) + cls.proxy.start() + + @classmethod + def teardown_class(cls): + cls.proxy.shutdown() + + @classmethod + def get_options(cls): + opts = options.Options( + listen_port=0, + no_upstream_cert=False, + ssl_insecure=True + ) + opts.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy") + return opts + + @property + def master(self): + return self.proxy.tmaster + + def setup(self): + self.master.clear_log() + self.master.state.clear() + self.server.server.handle_server_event = self.handle_server_event + + def _setup_connection(self): + client = netlib.tcp.TCPClient(("127.0.0.1", self.proxy.port)) + client.connect() + + # send CONNECT request + client.wfile.write( + b"CONNECT localhost:%d HTTP/1.1\r\n" + b"Host: localhost:%d\r\n" + b"\r\n" % (self.server.server.address.port, self.server.server.address.port) + ) + client.wfile.flush() + + # read CONNECT response + while client.rfile.readline() != b"\r\n": + pass + + client.convert_to_ssl(alpn_protos=[b'h2']) + + h2_conn = h2.connection.H2Connection(client_side=True, header_encoding=False) + h2_conn.initiate_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + return client, h2_conn + + def _send_request(self, + wfile, + h2_conn, + stream_id=1, + headers=None, + body=b'', + end_stream=None, + priority_exclusive=None, + priority_depends_on=None, + priority_weight=None): + if headers is None: + headers = [] + if end_stream is None: + end_stream = (len(body) == 0) + + h2_conn.send_headers( + stream_id=stream_id, + headers=headers, + end_stream=end_stream, + priority_exclusive=priority_exclusive, + priority_depends_on=priority_depends_on, + priority_weight=priority_weight, + ) + if body: + h2_conn.send_data(stream_id, body) + h2_conn.end_stream(stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + +class _Http2Test(_Http2TestBase, _Http2ServerBase): + + @classmethod + def setup_class(cls): + _Http2TestBase.setup_class() + _Http2ServerBase.setup_class() + + @classmethod + def teardown_class(cls): + _Http2TestBase.teardown_class() + _Http2ServerBase.teardown_class() + + +@requires_alpn +class TestSimple(_Http2Test): + request_body_buffer = b'' + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + assert (b'client-foo', b'client-bar-1') in event.headers + assert (b'client-foo', b'client-bar-2') in event.headers + elif isinstance(event, h2.events.StreamEnded): + import warnings + with warnings.catch_warnings(): + # Ignore UnicodeWarning: + # h2/utilities.py:64: UnicodeWarning: Unicode equal comparison + # failed to convert both arguments to Unicode - interpreting + # them as being unequal. + # elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: + + warnings.simplefilter("ignore") + h2_conn.send_headers(event.stream_id, [ + (':status', '200'), + ('server-foo', 'server-bar'), + ('föo', 'bär'), + ('X-Stream-ID', str(event.stream_id)), + ]) + h2_conn.send_data(event.stream_id, b'response body') + h2_conn.end_stream(event.stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + elif isinstance(event, h2.events.DataReceived): + cls.request_body_buffer += event.data + return True + + def test_simple(self): + response_body_buffer = b'' + client, h2_conn = self._setup_connection() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ('ClIeNt-FoO', 'client-bar-1'), + ('ClIeNt-FoO', 'client-bar-2'), + ], + body=b'request body') + + done = False + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.DataReceived): + response_body_buffer += event.data + elif isinstance(event, h2.events.StreamEnded): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert self.master.state.flows[0].response.status_code == 200 + assert self.master.state.flows[0].response.headers['server-foo'] == 'server-bar' + assert self.master.state.flows[0].response.headers['föo'] == 'bär' + assert self.master.state.flows[0].response.body == b'response body' + assert self.request_body_buffer == b'request body' + assert response_body_buffer == b'response body' + + +@requires_alpn +class TestRequestWithPriority(_Http2Test): + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + import warnings + with warnings.catch_warnings(): + # Ignore UnicodeWarning: + # h2/utilities.py:64: UnicodeWarning: Unicode equal comparison + # failed to convert both arguments to Unicode - interpreting + # them as being unequal. + # elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: + + warnings.simplefilter("ignore") + + headers = [(':status', '200')] + if event.priority_updated: + headers.append(('priority_exclusive', event.priority_updated.exclusive)) + headers.append(('priority_depends_on', event.priority_updated.depends_on)) + headers.append(('priority_weight', event.priority_updated.weight)) + h2_conn.send_headers(event.stream_id, headers) + h2_conn.end_stream(event.stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_request_with_priority(self): + client, h2_conn = self._setup_connection() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + priority_exclusive=True, + priority_depends_on=42424242, + priority_weight=42, + ) + + done = False + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert self.master.state.flows[0].response.headers['priority_exclusive'] == 'True' + assert self.master.state.flows[0].response.headers['priority_depends_on'] == '42424242' + assert self.master.state.flows[0].response.headers['priority_weight'] == '42' + + def test_request_without_priority(self): + client, h2_conn = self._setup_connection() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + ) + + done = False + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert 'priority_exclusive' not in self.master.state.flows[0].response.headers + assert 'priority_depends_on' not in self.master.state.flows[0].response.headers + assert 'priority_weight' not in self.master.state.flows[0].response.headers + + +@requires_alpn +class TestPriority(_Http2Test): + priority_data = None + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.PriorityUpdated): + cls.priority_data = (event.exclusive, event.depends_on, event.weight) + elif isinstance(event, h2.events.RequestReceived): + import warnings + with warnings.catch_warnings(): + # Ignore UnicodeWarning: + # h2/utilities.py:64: UnicodeWarning: Unicode equal comparison + # failed to convert both arguments to Unicode - interpreting + # them as being unequal. + # elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: + + warnings.simplefilter("ignore") + + headers = [(':status', '200')] + h2_conn.send_headers(event.stream_id, headers) + h2_conn.end_stream(event.stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_priority(self): + client, h2_conn = self._setup_connection() + + h2_conn.prioritize(1, exclusive=True, depends_on=0, weight=42) + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + ) + + done = False + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert self.priority_data == (True, 0, 42) + + +@requires_alpn +class TestPriorityWithExistingStream(_Http2Test): + priority_data = [] + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.PriorityUpdated): + cls.priority_data.append((event.exclusive, event.depends_on, event.weight)) + elif isinstance(event, h2.events.RequestReceived): + assert not event.priority_updated + + import warnings + with warnings.catch_warnings(): + # Ignore UnicodeWarning: + # h2/utilities.py:64: UnicodeWarning: Unicode equal comparison + # failed to convert both arguments to Unicode - interpreting + # them as being unequal. + # elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: + + warnings.simplefilter("ignore") + + headers = [(':status', '200')] + h2_conn.send_headers(event.stream_id, headers) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + elif isinstance(event, h2.events.StreamEnded): + h2_conn.end_stream(event.stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_priority_with_existing_stream(self): + client, h2_conn = self._setup_connection() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + end_stream=False, + ) + + h2_conn.prioritize(1, exclusive=True, depends_on=0, weight=42) + h2_conn.end_stream(1) + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + done = False + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert self.priority_data == [(True, 0, 42)] + + +@requires_alpn +class TestStreamResetFromServer(_Http2Test): + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + h2_conn.reset_stream(event.stream_id, 0x8) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_request_with_priority(self): + client, h2_conn = self._setup_connection() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + ) + + done = False + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamReset): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert self.master.state.flows[0].response is None + + +@requires_alpn +class TestBodySizeLimit(_Http2Test): + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + return True + + def test_body_size_limit(self): + self.config.options.body_size_limit = 20 + + client, h2_conn = self._setup_connection() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + body=b'very long body over 20 characters long', + ) + + done = False + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamReset): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 0 + + +@requires_alpn +class TestPushPromise(_Http2Test): + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + if event.stream_id != 1: + # ignore requests initiated by push promises + return True + + h2_conn.send_headers(1, [(':status', '200')]) + h2_conn.push_stream(1, 2, [ + (':authority', "127.0.0.1:{}".format(cls.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/pushed_stream_foo'), + ('foo', 'bar') + ]) + h2_conn.push_stream(1, 4, [ + (':authority', "127.0.0.1:{}".format(cls.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/pushed_stream_bar'), + ('foo', 'bar') + ]) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + h2_conn.send_headers(2, [(':status', '200')]) + h2_conn.send_headers(4, [(':status', '200')]) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + h2_conn.send_data(1, b'regular_stream') + h2_conn.send_data(2, b'pushed_stream_foo') + h2_conn.send_data(4, b'pushed_stream_bar') + wfile.write(h2_conn.data_to_send()) + wfile.flush() + h2_conn.end_stream(1) + h2_conn.end_stream(2) + h2_conn.end_stream(4) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + return True + + def test_push_promise(self): + client, h2_conn = self._setup_connection() + + self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ('foo', 'bar') + ]) + + done = False + ended_streams = 0 + pushed_streams = 0 + responses = 0 + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except HttpException: + print(traceback.format_exc()) + assert False + except: + break + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + ended_streams += 1 + elif isinstance(event, h2.events.PushedStreamReceived): + pushed_streams += 1 + elif isinstance(event, h2.events.ResponseReceived): + responses += 1 + if isinstance(event, h2.events.ConnectionTerminated): + done = True + + if responses == 3 and ended_streams == 3 and pushed_streams == 2: + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert ended_streams == 3 + assert pushed_streams == 2 + + bodies = [flow.response.body for flow in self.master.state.flows] + assert len(bodies) == 3 + assert b'regular_stream' in bodies + assert b'pushed_stream_foo' in bodies + assert b'pushed_stream_bar' in bodies + + def test_push_promise_reset(self): + client, h2_conn = self._setup_connection() + + self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ('foo', 'bar') + ]) + + done = False + ended_streams = 0 + pushed_streams = 0 + responses = 0 + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded) and event.stream_id == 1: + ended_streams += 1 + elif isinstance(event, h2.events.PushedStreamReceived): + pushed_streams += 1 + h2_conn.reset_stream(event.pushed_stream_id, error_code=0x8) + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + elif isinstance(event, h2.events.ResponseReceived): + responses += 1 + if isinstance(event, h2.events.ConnectionTerminated): + done = True + + if responses >= 1 and ended_streams >= 1 and pushed_streams == 2: + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + bodies = [flow.response.body for flow in self.master.state.flows if flow.response] + assert len(bodies) >= 1 + assert b'regular_stream' in bodies + # the other two bodies might not be transmitted before the reset + + +@requires_alpn +class TestConnectionLost(_Http2Test): + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.RequestReceived): + h2_conn.send_headers(1, [(':status', '200')]) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return False + + def test_connection_lost(self): + client, h2_conn = self._setup_connection() + + self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ('foo', 'bar') + ]) + + done = False + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + h2_conn.receive_data(raw) + except HttpException: + print(traceback.format_exc()) + assert False + except: + break + try: + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + except: + break + + if len(self.master.state.flows) == 1: + assert self.master.state.flows[0].response is None + + +@requires_alpn +class TestMaxConcurrentStreams(_Http2Test): + + @classmethod + def setup_class(cls): + _Http2TestBase.setup_class() + _Http2ServerBase.setup_class(h2_server_settings={h2.settings.MAX_CONCURRENT_STREAMS: 2}) + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + h2_conn.send_headers(event.stream_id, [ + (':status', '200'), + ('X-Stream-ID', str(event.stream_id)), + ]) + h2_conn.send_data(event.stream_id, 'Stream-ID {}'.format(event.stream_id).encode()) + h2_conn.end_stream(event.stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_max_concurrent_streams(self): + client, h2_conn = self._setup_connection() + new_streams = [1, 3, 5, 7, 9, 11] + for stream_id in new_streams: + # this will exceed MAX_CONCURRENT_STREAMS on the server connection + # and cause mitmproxy to throttle stream creation to the server + self._send_request(client.wfile, h2_conn, stream_id=stream_id, headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ('X-Stream-ID', str(stream_id)), + ]) + + ended_streams = 0 + while ended_streams != len(new_streams): + try: + header, body = framereader.http2_read_raw_frame(client.rfile) + events = h2_conn.receive_data(b''.join([header, body])) + except: + break + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + ended_streams += 1 + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == len(new_streams) + for flow in self.master.state.flows: + assert flow.response.status_code == 200 + assert b"Stream-ID " in flow.response.body + + +@requires_alpn +class TestConnectionTerminated(_Http2Test): + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.RequestReceived): + h2_conn.close_connection(error_code=5, last_stream_id=42, additional_data=b'foobar') + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_connection_terminated(self): + client, h2_conn = self._setup_connection() + + self._send_request(client.wfile, h2_conn, headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ]) + + done = False + connection_terminated_event = None + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + for event in events: + if isinstance(event, h2.events.ConnectionTerminated): + connection_terminated_event = event + done = True + except: + break + + assert len(self.master.state.flows) == 1 + assert connection_terminated_event is not None + assert connection_terminated_event.error_code == 5 + assert connection_terminated_event.last_stream_id == 42 + assert connection_terminated_event.additional_data == b'foobar' diff --git a/test/mitmproxy/protocol/test_websockets.py b/test/mitmproxy/protocol/test_websockets.py index cc478c0b..e7e2684f 100644 --- a/test/mitmproxy/protocol/test_websockets.py +++ b/test/mitmproxy/protocol/test_websockets.py @@ -1,3 +1,5 @@ +from __future__ import absolute_import, print_function, division + import pytest import os import tempfile diff --git a/test/mitmproxy/test_protocol_http1.py b/test/mitmproxy/test_protocol_http1.py deleted file mode 100644 index cf7bd598..00000000 --- a/test/mitmproxy/test_protocol_http1.py +++ /dev/null @@ -1,63 +0,0 @@ -from netlib.http import http1 -from netlib.tcp import TCPClient -from netlib.tutils import treq -from . import tutils, tservers - - -class TestHTTPFlow(object): - - def test_repr(self): - f = tutils.tflow(resp=True, err=True) - assert repr(f) - - -class TestInvalidRequests(tservers.HTTPProxyTest): - ssl = True - - def test_double_connect(self): - p = self.pathoc() - r = p.request("connect:'%s:%s'" % ("127.0.0.1", self.server2.port)) - assert r.status_code == 400 - assert b"Invalid HTTP request form" in r.content - - def test_relative_request(self): - p = self.pathoc_raw() - p.connect() - r = p.request("get:/p/200") - assert r.status_code == 400 - assert b"Invalid HTTP request form" in r.content - - -class TestExpectHeader(tservers.HTTPProxyTest): - - def test_simple(self): - client = TCPClient(("127.0.0.1", self.proxy.port)) - client.connect() - - # call pathod server, wait a second to complete the request - client.wfile.write( - b"POST http://localhost:%d/p/200 HTTP/1.1\r\n" - b"Expect: 100-continue\r\n" - b"Content-Length: 16\r\n" - b"\r\n" % self.server.port - ) - client.wfile.flush() - - assert client.rfile.readline() == b"HTTP/1.1 100 Continue\r\n" - assert client.rfile.readline() == b"\r\n" - - client.wfile.write(b"0123456789abcdef\r\n") - client.wfile.flush() - - resp = http1.read_response(client.rfile, treq()) - assert resp.status_code == 200 - - client.finish() - - -class TestHeadContentLength(tservers.HTTPProxyTest): - - def test_head_content_length(self): - p = self.pathoc() - resp = p.request("""head:'%s/p/200:h"Content-Length"="42"'""" % self.server.urlbase) - assert resp.headers["Content-Length"] == "42" diff --git a/test/mitmproxy/test_protocol_http2.py b/test/mitmproxy/test_protocol_http2.py deleted file mode 100644 index 873c89c3..00000000 --- a/test/mitmproxy/test_protocol_http2.py +++ /dev/null @@ -1,925 +0,0 @@ -# coding=utf-8 - -from __future__ import (absolute_import, print_function, division) - -import pytest -import os -import tempfile -import traceback - -import h2 - -from mitmproxy import options -from mitmproxy.proxy.config import ProxyConfig - -import netlib -from ..netlib import tservers as netlib_tservers -from netlib.exceptions import HttpException -from netlib.http.http2 import framereader - -from . import tservers - -import logging -logging.getLogger("hyper.packages.hpack.hpack").setLevel(logging.WARNING) -logging.getLogger("requests.packages.urllib3.connectionpool").setLevel(logging.WARNING) -logging.getLogger("passlib.utils.compat").setLevel(logging.WARNING) -logging.getLogger("passlib.registry").setLevel(logging.WARNING) -logging.getLogger("PIL.Image").setLevel(logging.WARNING) -logging.getLogger("PIL.PngImagePlugin").setLevel(logging.WARNING) - - -requires_alpn = pytest.mark.skipif( - not netlib.tcp.HAS_ALPN, - reason='requires OpenSSL with ALPN support') - - -class _Http2ServerBase(netlib_tservers.ServerTestBase): - ssl = dict(alpn_select=b'h2') - - class handler(netlib.tcp.BaseHandler): - - def handle(self): - h2_conn = h2.connection.H2Connection(client_side=False, header_encoding=False) - - preamble = self.rfile.read(24) - h2_conn.initiate_connection() - h2_conn.receive_data(preamble) - self.wfile.write(h2_conn.data_to_send()) - self.wfile.flush() - - if 'h2_server_settings' in self.kwargs: - h2_conn.update_settings(self.kwargs['h2_server_settings']) - self.wfile.write(h2_conn.data_to_send()) - self.wfile.flush() - - done = False - while not done: - try: - raw = b''.join(framereader.http2_read_raw_frame(self.rfile)) - events = h2_conn.receive_data(raw) - except HttpException: - print(traceback.format_exc()) - assert False - except netlib.exceptions.TcpDisconnect: - break - except: - print(traceback.format_exc()) - break - self.wfile.write(h2_conn.data_to_send()) - self.wfile.flush() - - for event in events: - try: - if not self.server.handle_server_event(event, h2_conn, self.rfile, self.wfile): - done = True - break - except netlib.exceptions.TcpDisconnect: - done = True - except: - done = True - print(traceback.format_exc()) - break - - def handle_server_event(self, event, h2_conn, rfile, wfile): - raise NotImplementedError() - - -class _Http2TestBase(object): - - @classmethod - def setup_class(cls): - opts = cls.get_options() - cls.config = ProxyConfig(opts) - - tmaster = tservers.TestMaster(opts, cls.config) - tmaster.start_app(options.APP_HOST, options.APP_PORT) - cls.proxy = tservers.ProxyThread(tmaster) - cls.proxy.start() - - @classmethod - def teardown_class(cls): - cls.proxy.shutdown() - - @classmethod - def get_options(cls): - opts = options.Options( - listen_port=0, - no_upstream_cert=False, - ssl_insecure=True - ) - opts.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy") - return opts - - @property - def master(self): - return self.proxy.tmaster - - def setup(self): - self.master.clear_log() - self.master.state.clear() - self.server.server.handle_server_event = self.handle_server_event - - def _setup_connection(self): - client = netlib.tcp.TCPClient(("127.0.0.1", self.proxy.port)) - client.connect() - - # send CONNECT request - client.wfile.write( - b"CONNECT localhost:%d HTTP/1.1\r\n" - b"Host: localhost:%d\r\n" - b"\r\n" % (self.server.server.address.port, self.server.server.address.port) - ) - client.wfile.flush() - - # read CONNECT response - while client.rfile.readline() != b"\r\n": - pass - - client.convert_to_ssl(alpn_protos=[b'h2']) - - h2_conn = h2.connection.H2Connection(client_side=True, header_encoding=False) - h2_conn.initiate_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - return client, h2_conn - - def _send_request(self, - wfile, - h2_conn, - stream_id=1, - headers=None, - body=b'', - end_stream=None, - priority_exclusive=None, - priority_depends_on=None, - priority_weight=None): - if headers is None: - headers = [] - if end_stream is None: - end_stream = (len(body) == 0) - - h2_conn.send_headers( - stream_id=stream_id, - headers=headers, - end_stream=end_stream, - priority_exclusive=priority_exclusive, - priority_depends_on=priority_depends_on, - priority_weight=priority_weight, - ) - if body: - h2_conn.send_data(stream_id, body) - h2_conn.end_stream(stream_id) - wfile.write(h2_conn.data_to_send()) - wfile.flush() - - -class _Http2Test(_Http2TestBase, _Http2ServerBase): - - @classmethod - def setup_class(cls): - _Http2TestBase.setup_class() - _Http2ServerBase.setup_class() - - @classmethod - def teardown_class(cls): - _Http2TestBase.teardown_class() - _Http2ServerBase.teardown_class() - - -@requires_alpn -class TestSimple(_Http2Test): - request_body_buffer = b'' - - @classmethod - def handle_server_event(cls, event, h2_conn, rfile, wfile): - if isinstance(event, h2.events.ConnectionTerminated): - return False - elif isinstance(event, h2.events.RequestReceived): - assert (b'client-foo', b'client-bar-1') in event.headers - assert (b'client-foo', b'client-bar-2') in event.headers - elif isinstance(event, h2.events.StreamEnded): - import warnings - with warnings.catch_warnings(): - # Ignore UnicodeWarning: - # h2/utilities.py:64: UnicodeWarning: Unicode equal comparison - # failed to convert both arguments to Unicode - interpreting - # them as being unequal. - # elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: - - warnings.simplefilter("ignore") - h2_conn.send_headers(event.stream_id, [ - (':status', '200'), - ('server-foo', 'server-bar'), - ('föo', 'bär'), - ('X-Stream-ID', str(event.stream_id)), - ]) - h2_conn.send_data(event.stream_id, b'response body') - h2_conn.end_stream(event.stream_id) - wfile.write(h2_conn.data_to_send()) - wfile.flush() - elif isinstance(event, h2.events.DataReceived): - cls.request_body_buffer += event.data - return True - - def test_simple(self): - response_body_buffer = b'' - client, h2_conn = self._setup_connection() - - self._send_request( - client.wfile, - h2_conn, - headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), - (':method', 'GET'), - (':scheme', 'https'), - (':path', '/'), - ('ClIeNt-FoO', 'client-bar-1'), - ('ClIeNt-FoO', 'client-bar-2'), - ], - body=b'request body') - - done = False - while not done: - try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) - events = h2_conn.receive_data(raw) - except HttpException: - print(traceback.format_exc()) - assert False - - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - for event in events: - if isinstance(event, h2.events.DataReceived): - response_body_buffer += event.data - elif isinstance(event, h2.events.StreamEnded): - done = True - - h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - assert len(self.master.state.flows) == 1 - assert self.master.state.flows[0].response.status_code == 200 - assert self.master.state.flows[0].response.headers['server-foo'] == 'server-bar' - assert self.master.state.flows[0].response.headers['föo'] == 'bär' - assert self.master.state.flows[0].response.body == b'response body' - assert self.request_body_buffer == b'request body' - assert response_body_buffer == b'response body' - - -@requires_alpn -class TestRequestWithPriority(_Http2Test): - - @classmethod - def handle_server_event(cls, event, h2_conn, rfile, wfile): - if isinstance(event, h2.events.ConnectionTerminated): - return False - elif isinstance(event, h2.events.RequestReceived): - import warnings - with warnings.catch_warnings(): - # Ignore UnicodeWarning: - # h2/utilities.py:64: UnicodeWarning: Unicode equal comparison - # failed to convert both arguments to Unicode - interpreting - # them as being unequal. - # elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: - - warnings.simplefilter("ignore") - - headers = [(':status', '200')] - if event.priority_updated: - headers.append(('priority_exclusive', event.priority_updated.exclusive)) - headers.append(('priority_depends_on', event.priority_updated.depends_on)) - headers.append(('priority_weight', event.priority_updated.weight)) - h2_conn.send_headers(event.stream_id, headers) - h2_conn.end_stream(event.stream_id) - wfile.write(h2_conn.data_to_send()) - wfile.flush() - return True - - def test_request_with_priority(self): - client, h2_conn = self._setup_connection() - - self._send_request( - client.wfile, - h2_conn, - headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), - (':method', 'GET'), - (':scheme', 'https'), - (':path', '/'), - ], - priority_exclusive=True, - priority_depends_on=42424242, - priority_weight=42, - ) - - done = False - while not done: - try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) - events = h2_conn.receive_data(raw) - except HttpException: - print(traceback.format_exc()) - assert False - - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - for event in events: - if isinstance(event, h2.events.StreamEnded): - done = True - - h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - assert len(self.master.state.flows) == 1 - assert self.master.state.flows[0].response.headers['priority_exclusive'] == 'True' - assert self.master.state.flows[0].response.headers['priority_depends_on'] == '42424242' - assert self.master.state.flows[0].response.headers['priority_weight'] == '42' - - def test_request_without_priority(self): - client, h2_conn = self._setup_connection() - - self._send_request( - client.wfile, - h2_conn, - headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), - (':method', 'GET'), - (':scheme', 'https'), - (':path', '/'), - ], - ) - - done = False - while not done: - try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) - events = h2_conn.receive_data(raw) - except HttpException: - print(traceback.format_exc()) - assert False - - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - for event in events: - if isinstance(event, h2.events.StreamEnded): - done = True - - h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - assert len(self.master.state.flows) == 1 - assert 'priority_exclusive' not in self.master.state.flows[0].response.headers - assert 'priority_depends_on' not in self.master.state.flows[0].response.headers - assert 'priority_weight' not in self.master.state.flows[0].response.headers - - -@requires_alpn -class TestPriority(_Http2Test): - priority_data = None - - @classmethod - def handle_server_event(cls, event, h2_conn, rfile, wfile): - if isinstance(event, h2.events.ConnectionTerminated): - return False - elif isinstance(event, h2.events.PriorityUpdated): - cls.priority_data = (event.exclusive, event.depends_on, event.weight) - elif isinstance(event, h2.events.RequestReceived): - import warnings - with warnings.catch_warnings(): - # Ignore UnicodeWarning: - # h2/utilities.py:64: UnicodeWarning: Unicode equal comparison - # failed to convert both arguments to Unicode - interpreting - # them as being unequal. - # elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: - - warnings.simplefilter("ignore") - - headers = [(':status', '200')] - h2_conn.send_headers(event.stream_id, headers) - h2_conn.end_stream(event.stream_id) - wfile.write(h2_conn.data_to_send()) - wfile.flush() - return True - - def test_priority(self): - client, h2_conn = self._setup_connection() - - h2_conn.prioritize(1, exclusive=True, depends_on=0, weight=42) - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - self._send_request( - client.wfile, - h2_conn, - headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), - (':method', 'GET'), - (':scheme', 'https'), - (':path', '/'), - ], - ) - - done = False - while not done: - try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) - events = h2_conn.receive_data(raw) - except HttpException: - print(traceback.format_exc()) - assert False - - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - for event in events: - if isinstance(event, h2.events.StreamEnded): - done = True - - h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - assert len(self.master.state.flows) == 1 - assert self.priority_data == (True, 0, 42) - - -@requires_alpn -class TestPriorityWithExistingStream(_Http2Test): - priority_data = [] - - @classmethod - def handle_server_event(cls, event, h2_conn, rfile, wfile): - if isinstance(event, h2.events.ConnectionTerminated): - return False - elif isinstance(event, h2.events.PriorityUpdated): - cls.priority_data.append((event.exclusive, event.depends_on, event.weight)) - elif isinstance(event, h2.events.RequestReceived): - assert not event.priority_updated - - import warnings - with warnings.catch_warnings(): - # Ignore UnicodeWarning: - # h2/utilities.py:64: UnicodeWarning: Unicode equal comparison - # failed to convert both arguments to Unicode - interpreting - # them as being unequal. - # elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: - - warnings.simplefilter("ignore") - - headers = [(':status', '200')] - h2_conn.send_headers(event.stream_id, headers) - wfile.write(h2_conn.data_to_send()) - wfile.flush() - elif isinstance(event, h2.events.StreamEnded): - h2_conn.end_stream(event.stream_id) - wfile.write(h2_conn.data_to_send()) - wfile.flush() - return True - - def test_priority_with_existing_stream(self): - client, h2_conn = self._setup_connection() - - self._send_request( - client.wfile, - h2_conn, - headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), - (':method', 'GET'), - (':scheme', 'https'), - (':path', '/'), - ], - end_stream=False, - ) - - h2_conn.prioritize(1, exclusive=True, depends_on=0, weight=42) - h2_conn.end_stream(1) - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - done = False - while not done: - try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) - events = h2_conn.receive_data(raw) - except HttpException: - print(traceback.format_exc()) - assert False - - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - for event in events: - if isinstance(event, h2.events.StreamEnded): - done = True - - h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - assert len(self.master.state.flows) == 1 - assert self.priority_data == [(True, 0, 42)] - - -@requires_alpn -class TestStreamResetFromServer(_Http2Test): - - @classmethod - def handle_server_event(cls, event, h2_conn, rfile, wfile): - if isinstance(event, h2.events.ConnectionTerminated): - return False - elif isinstance(event, h2.events.RequestReceived): - h2_conn.reset_stream(event.stream_id, 0x8) - wfile.write(h2_conn.data_to_send()) - wfile.flush() - return True - - def test_request_with_priority(self): - client, h2_conn = self._setup_connection() - - self._send_request( - client.wfile, - h2_conn, - headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), - (':method', 'GET'), - (':scheme', 'https'), - (':path', '/'), - ], - ) - - done = False - while not done: - try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) - events = h2_conn.receive_data(raw) - except HttpException: - print(traceback.format_exc()) - assert False - - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - for event in events: - if isinstance(event, h2.events.StreamReset): - done = True - - h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - assert len(self.master.state.flows) == 1 - assert self.master.state.flows[0].response is None - - -@requires_alpn -class TestBodySizeLimit(_Http2Test): - - @classmethod - def handle_server_event(cls, event, h2_conn, rfile, wfile): - if isinstance(event, h2.events.ConnectionTerminated): - return False - return True - - def test_body_size_limit(self): - self.config.options.body_size_limit = 20 - - client, h2_conn = self._setup_connection() - - self._send_request( - client.wfile, - h2_conn, - headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), - (':method', 'GET'), - (':scheme', 'https'), - (':path', '/'), - ], - body=b'very long body over 20 characters long', - ) - - done = False - while not done: - try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) - events = h2_conn.receive_data(raw) - except HttpException: - print(traceback.format_exc()) - assert False - - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - for event in events: - if isinstance(event, h2.events.StreamReset): - done = True - - h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - assert len(self.master.state.flows) == 0 - - -@requires_alpn -class TestPushPromise(_Http2Test): - - @classmethod - def handle_server_event(cls, event, h2_conn, rfile, wfile): - if isinstance(event, h2.events.ConnectionTerminated): - return False - elif isinstance(event, h2.events.RequestReceived): - if event.stream_id != 1: - # ignore requests initiated by push promises - return True - - h2_conn.send_headers(1, [(':status', '200')]) - h2_conn.push_stream(1, 2, [ - (':authority', "127.0.0.1:{}".format(cls.port)), - (':method', 'GET'), - (':scheme', 'https'), - (':path', '/pushed_stream_foo'), - ('foo', 'bar') - ]) - h2_conn.push_stream(1, 4, [ - (':authority', "127.0.0.1:{}".format(cls.port)), - (':method', 'GET'), - (':scheme', 'https'), - (':path', '/pushed_stream_bar'), - ('foo', 'bar') - ]) - wfile.write(h2_conn.data_to_send()) - wfile.flush() - - h2_conn.send_headers(2, [(':status', '200')]) - h2_conn.send_headers(4, [(':status', '200')]) - wfile.write(h2_conn.data_to_send()) - wfile.flush() - - h2_conn.send_data(1, b'regular_stream') - h2_conn.send_data(2, b'pushed_stream_foo') - h2_conn.send_data(4, b'pushed_stream_bar') - wfile.write(h2_conn.data_to_send()) - wfile.flush() - h2_conn.end_stream(1) - h2_conn.end_stream(2) - h2_conn.end_stream(4) - wfile.write(h2_conn.data_to_send()) - wfile.flush() - - return True - - def test_push_promise(self): - client, h2_conn = self._setup_connection() - - self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), - (':method', 'GET'), - (':scheme', 'https'), - (':path', '/'), - ('foo', 'bar') - ]) - - done = False - ended_streams = 0 - pushed_streams = 0 - responses = 0 - while not done: - try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) - events = h2_conn.receive_data(raw) - except HttpException: - print(traceback.format_exc()) - assert False - except: - break - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - for event in events: - if isinstance(event, h2.events.StreamEnded): - ended_streams += 1 - elif isinstance(event, h2.events.PushedStreamReceived): - pushed_streams += 1 - elif isinstance(event, h2.events.ResponseReceived): - responses += 1 - if isinstance(event, h2.events.ConnectionTerminated): - done = True - - if responses == 3 and ended_streams == 3 and pushed_streams == 2: - done = True - - h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - assert ended_streams == 3 - assert pushed_streams == 2 - - bodies = [flow.response.body for flow in self.master.state.flows] - assert len(bodies) == 3 - assert b'regular_stream' in bodies - assert b'pushed_stream_foo' in bodies - assert b'pushed_stream_bar' in bodies - - def test_push_promise_reset(self): - client, h2_conn = self._setup_connection() - - self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), - (':method', 'GET'), - (':scheme', 'https'), - (':path', '/'), - ('foo', 'bar') - ]) - - done = False - ended_streams = 0 - pushed_streams = 0 - responses = 0 - while not done: - try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) - events = h2_conn.receive_data(raw) - except HttpException: - print(traceback.format_exc()) - assert False - - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - for event in events: - if isinstance(event, h2.events.StreamEnded) and event.stream_id == 1: - ended_streams += 1 - elif isinstance(event, h2.events.PushedStreamReceived): - pushed_streams += 1 - h2_conn.reset_stream(event.pushed_stream_id, error_code=0x8) - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - elif isinstance(event, h2.events.ResponseReceived): - responses += 1 - if isinstance(event, h2.events.ConnectionTerminated): - done = True - - if responses >= 1 and ended_streams >= 1 and pushed_streams == 2: - done = True - - h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - bodies = [flow.response.body for flow in self.master.state.flows if flow.response] - assert len(bodies) >= 1 - assert b'regular_stream' in bodies - # the other two bodies might not be transmitted before the reset - - -@requires_alpn -class TestConnectionLost(_Http2Test): - - @classmethod - def handle_server_event(cls, event, h2_conn, rfile, wfile): - if isinstance(event, h2.events.RequestReceived): - h2_conn.send_headers(1, [(':status', '200')]) - wfile.write(h2_conn.data_to_send()) - wfile.flush() - return False - - def test_connection_lost(self): - client, h2_conn = self._setup_connection() - - self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), - (':method', 'GET'), - (':scheme', 'https'), - (':path', '/'), - ('foo', 'bar') - ]) - - done = False - while not done: - try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) - h2_conn.receive_data(raw) - except HttpException: - print(traceback.format_exc()) - assert False - except: - break - try: - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - except: - break - - if len(self.master.state.flows) == 1: - assert self.master.state.flows[0].response is None - - -@requires_alpn -class TestMaxConcurrentStreams(_Http2Test): - - @classmethod - def setup_class(cls): - _Http2TestBase.setup_class() - _Http2ServerBase.setup_class(h2_server_settings={h2.settings.MAX_CONCURRENT_STREAMS: 2}) - - @classmethod - def handle_server_event(cls, event, h2_conn, rfile, wfile): - if isinstance(event, h2.events.ConnectionTerminated): - return False - elif isinstance(event, h2.events.RequestReceived): - h2_conn.send_headers(event.stream_id, [ - (':status', '200'), - ('X-Stream-ID', str(event.stream_id)), - ]) - h2_conn.send_data(event.stream_id, 'Stream-ID {}'.format(event.stream_id).encode()) - h2_conn.end_stream(event.stream_id) - wfile.write(h2_conn.data_to_send()) - wfile.flush() - return True - - def test_max_concurrent_streams(self): - client, h2_conn = self._setup_connection() - new_streams = [1, 3, 5, 7, 9, 11] - for stream_id in new_streams: - # this will exceed MAX_CONCURRENT_STREAMS on the server connection - # and cause mitmproxy to throttle stream creation to the server - self._send_request(client.wfile, h2_conn, stream_id=stream_id, headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), - (':method', 'GET'), - (':scheme', 'https'), - (':path', '/'), - ('X-Stream-ID', str(stream_id)), - ]) - - ended_streams = 0 - while ended_streams != len(new_streams): - try: - header, body = framereader.http2_read_raw_frame(client.rfile) - events = h2_conn.receive_data(b''.join([header, body])) - except: - break - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - for event in events: - if isinstance(event, h2.events.StreamEnded): - ended_streams += 1 - - h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - assert len(self.master.state.flows) == len(new_streams) - for flow in self.master.state.flows: - assert flow.response.status_code == 200 - assert b"Stream-ID " in flow.response.body - - -@requires_alpn -class TestConnectionTerminated(_Http2Test): - - @classmethod - def handle_server_event(cls, event, h2_conn, rfile, wfile): - if isinstance(event, h2.events.RequestReceived): - h2_conn.close_connection(error_code=5, last_stream_id=42, additional_data=b'foobar') - wfile.write(h2_conn.data_to_send()) - wfile.flush() - return True - - def test_connection_terminated(self): - client, h2_conn = self._setup_connection() - - self._send_request(client.wfile, h2_conn, headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), - (':method', 'GET'), - (':scheme', 'https'), - (':path', '/'), - ]) - - done = False - connection_terminated_event = None - while not done: - try: - raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) - events = h2_conn.receive_data(raw) - for event in events: - if isinstance(event, h2.events.ConnectionTerminated): - connection_terminated_event = event - done = True - except: - break - - assert len(self.master.state.flows) == 1 - assert connection_terminated_event is not None - assert connection_terminated_event.error_code == 5 - assert connection_terminated_event.last_stream_id == 42 - assert connection_terminated_event.additional_data == b'foobar' -- cgit v1.2.3