diff options
Diffstat (limited to 'test/mitmproxy/proxy/protocol')
| -rw-r--r-- | test/mitmproxy/proxy/protocol/__init__.py | 0 | ||||
| -rw-r--r-- | test/mitmproxy/proxy/protocol/test_base.py | 1 | ||||
| -rw-r--r-- | test/mitmproxy/proxy/protocol/test_http.py | 1 | ||||
| -rw-r--r-- | test/mitmproxy/proxy/protocol/test_http1.py | 78 | ||||
| -rw-r--r-- | test/mitmproxy/proxy/protocol/test_http2.py | 944 | ||||
| -rw-r--r-- | test/mitmproxy/proxy/protocol/test_http_replay.py | 1 | ||||
| -rw-r--r-- | test/mitmproxy/proxy/protocol/test_rawtcp.py | 1 | ||||
| -rw-r--r-- | test/mitmproxy/proxy/protocol/test_tls.py | 26 | ||||
| -rw-r--r-- | test/mitmproxy/proxy/protocol/test_websocket.py | 328 |
9 files changed, 1380 insertions, 0 deletions
diff --git a/test/mitmproxy/proxy/protocol/__init__.py b/test/mitmproxy/proxy/protocol/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/mitmproxy/proxy/protocol/__init__.py diff --git a/test/mitmproxy/proxy/protocol/test_base.py b/test/mitmproxy/proxy/protocol/test_base.py new file mode 100644 index 00000000..777ab4dd --- /dev/null +++ b/test/mitmproxy/proxy/protocol/test_base.py @@ -0,0 +1 @@ +# TODO: write tests diff --git a/test/mitmproxy/proxy/protocol/test_http.py b/test/mitmproxy/proxy/protocol/test_http.py new file mode 100644 index 00000000..777ab4dd --- /dev/null +++ b/test/mitmproxy/proxy/protocol/test_http.py @@ -0,0 +1 @@ +# TODO: write tests diff --git a/test/mitmproxy/proxy/protocol/test_http1.py b/test/mitmproxy/proxy/protocol/test_http1.py new file mode 100644 index 00000000..07cd7dcc --- /dev/null +++ b/test/mitmproxy/proxy/protocol/test_http1.py @@ -0,0 +1,78 @@ +from mitmproxy.test import tflow +from mitmproxy.net.http import http1 +from mitmproxy.net.tcp import TCPClient +from mitmproxy.test.tutils import treq +from ... import tservers + + +class TestHTTPFlow: + + def test_repr(self): + f = tflow.tflow(resp=True, err=True) + assert repr(f) + + +class TestInvalidRequests(tservers.HTTPProxyTest): + ssl = True + + def test_double_connect(self): + p = self.pathoc() + with p.connect(): + r = p.request("connect:'%s:%s'" % ("127.0.0.1", self.server2.port)) + assert r.status_code == 400 + assert b"Unexpected CONNECT" in r.content + + def test_relative_request(self): + p = self.pathoc_raw() + with p.connect(): + r = p.request("get:/p/200") + assert r.status_code == 400 + assert b"Invalid HTTP request form" in r.content + + +class TestProxyMisconfiguration(tservers.TransparentProxyTest): + + def test_absolute_request(self): + p = self.pathoc() + with p.connect(): + r = p.request("get:'http://localhost:%d/p/200'" % self.server.port) + assert r.status_code == 400 + assert b"misconfiguration" in r.content + + +class TestExpectHeader(tservers.HTTPProxyTest): + + def test_simple(self): + client = TCPClient(("127.0.0.1", self.proxy.port)) + client.connect() + + # call pathod server, wait a second to complete the request + client.wfile.write( + b"POST http://localhost:%d/p/200 HTTP/1.1\r\n" + b"Expect: 100-continue\r\n" + b"Content-Length: 16\r\n" + b"\r\n" % self.server.port + ) + client.wfile.flush() + + assert client.rfile.readline() == b"HTTP/1.1 100 Continue\r\n" + assert client.rfile.readline() == b"\r\n" + + client.wfile.write(b"0123456789abcdef\r\n") + client.wfile.flush() + + resp = http1.read_response(client.rfile, treq()) + assert resp.status_code == 200 + + client.finish() + + +class TestHeadContentLength(tservers.HTTPProxyTest): + + def test_head_content_length(self): + p = self.pathoc() + with p.connect(): + resp = p.request( + """head:'%s/p/200:h"Content-Length"="42"'""" % self.server.urlbase + ) + assert resp.headers["Content-Length"] == "42" diff --git a/test/mitmproxy/proxy/protocol/test_http2.py b/test/mitmproxy/proxy/protocol/test_http2.py new file mode 100644 index 00000000..f5d9259d --- /dev/null +++ b/test/mitmproxy/proxy/protocol/test_http2.py @@ -0,0 +1,944 @@ +# coding=utf-8 + + +import os +import tempfile +import traceback + +import h2 + +from mitmproxy import options +from mitmproxy.proxy.config import ProxyConfig + +import mitmproxy.net +from ....mitmproxy.net import tservers as net_tservers +from mitmproxy import exceptions +from mitmproxy.net.http import http1, http2 + +from ... import tservers +from ....conftest import requires_alpn + +import logging +logging.getLogger("hyper.packages.hpack.hpack").setLevel(logging.WARNING) +logging.getLogger("requests.packages.urllib3.connectionpool").setLevel(logging.WARNING) +logging.getLogger("passlib.utils.compat").setLevel(logging.WARNING) +logging.getLogger("passlib.registry").setLevel(logging.WARNING) +logging.getLogger("PIL.Image").setLevel(logging.WARNING) +logging.getLogger("PIL.PngImagePlugin").setLevel(logging.WARNING) + + +# inspect the log: +# for msg in self.proxy.tmaster.tlog: +# print(msg) + + +class _Http2ServerBase(net_tservers.ServerTestBase): + ssl = dict(alpn_select=b'h2') + + class handler(mitmproxy.net.tcp.BaseHandler): + + def handle(self): + h2_conn = h2.connection.H2Connection(client_side=False, header_encoding=False) + + preamble = self.rfile.read(24) + h2_conn.initiate_connection() + h2_conn.receive_data(preamble) + self.wfile.write(h2_conn.data_to_send()) + self.wfile.flush() + + if 'h2_server_settings' in self.kwargs: + h2_conn.update_settings(self.kwargs['h2_server_settings']) + self.wfile.write(h2_conn.data_to_send()) + self.wfile.flush() + + done = False + while not done: + try: + raw = b''.join(http2.read_raw_frame(self.rfile)) + events = h2_conn.receive_data(raw) + except exceptions.HttpException: + print(traceback.format_exc()) + assert False + except exceptions.TcpDisconnect: + break + except: + print(traceback.format_exc()) + break + self.wfile.write(h2_conn.data_to_send()) + self.wfile.flush() + + for event in events: + try: + if not self.server.handle_server_event(event, h2_conn, self.rfile, self.wfile): + done = True + break + except exceptions.TcpDisconnect: + done = True + except: + done = True + print(traceback.format_exc()) + break + + def handle_server_event(self, event, h2_conn, rfile, wfile): + raise NotImplementedError() + + +class _Http2TestBase: + + @classmethod + def setup_class(cls): + opts = cls.get_options() + cls.config = ProxyConfig(opts) + + tmaster = tservers.TestMaster(opts, cls.config) + cls.proxy = tservers.ProxyThread(tmaster) + cls.proxy.start() + + @classmethod + def teardown_class(cls): + cls.proxy.shutdown() + + @classmethod + def get_options(cls): + opts = options.Options( + listen_port=0, + no_upstream_cert=False, + ssl_insecure=True + ) + opts.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy") + return opts + + @property + def master(self): + return self.proxy.tmaster + + def setup(self): + self.master.reset([]) + self.server.server.handle_server_event = self.handle_server_event + + def _setup_connection(self): + client = mitmproxy.net.tcp.TCPClient(("127.0.0.1", self.proxy.port)) + client.connect() + + # send CONNECT request + client.wfile.write(http1.assemble_request(mitmproxy.net.http.Request( + 'authority', + b'CONNECT', + b'', + b'localhost', + self.server.server.address.port, + b'/', + b'HTTP/1.1', + [(b'host', b'localhost:%d' % self.server.server.address.port)], + b'', + ))) + client.wfile.flush() + + # read CONNECT response + while client.rfile.readline() != b"\r\n": + pass + + client.convert_to_ssl(alpn_protos=[b'h2']) + + h2_conn = h2.connection.H2Connection(client_side=True, header_encoding=False) + h2_conn.initiate_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + return client, h2_conn + + def _send_request(self, + wfile, + h2_conn, + stream_id=1, + headers=None, + body=b'', + end_stream=None, + priority_exclusive=None, + priority_depends_on=None, + priority_weight=None): + if headers is None: + headers = [] + if end_stream is None: + end_stream = (len(body) == 0) + + h2_conn.send_headers( + stream_id=stream_id, + headers=headers, + end_stream=end_stream, + priority_exclusive=priority_exclusive, + priority_depends_on=priority_depends_on, + priority_weight=priority_weight, + ) + if body: + h2_conn.send_data(stream_id, body) + h2_conn.end_stream(stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + +class _Http2Test(_Http2TestBase, _Http2ServerBase): + + @classmethod + def setup_class(cls): + _Http2TestBase.setup_class() + _Http2ServerBase.setup_class() + + @classmethod + def teardown_class(cls): + _Http2TestBase.teardown_class() + _Http2ServerBase.teardown_class() + + +@requires_alpn +class TestSimple(_Http2Test): + request_body_buffer = b'' + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + assert (b'client-foo', b'client-bar-1') in event.headers + assert (b'client-foo', b'client-bar-2') in event.headers + elif isinstance(event, h2.events.StreamEnded): + import warnings + with warnings.catch_warnings(): + # Ignore UnicodeWarning: + # h2/utilities.py:64: UnicodeWarning: Unicode equal comparison + # failed to convert both arguments to Unicode - interpreting + # them as being unequal. + # elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: + + warnings.simplefilter("ignore") + h2_conn.send_headers(event.stream_id, [ + (':status', '200'), + ('server-foo', 'server-bar'), + ('föo', 'bär'), + ('X-Stream-ID', str(event.stream_id)), + ]) + h2_conn.send_data(event.stream_id, b'response body') + h2_conn.end_stream(event.stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + elif isinstance(event, h2.events.DataReceived): + cls.request_body_buffer += event.data + return True + + def test_simple(self): + response_body_buffer = b'' + client, h2_conn = self._setup_connection() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ('ClIeNt-FoO', 'client-bar-1'), + ('ClIeNt-FoO', 'client-bar-2'), + ], + body=b'request body') + + done = False + while not done: + try: + raw = b''.join(http2.read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except exceptions.HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.DataReceived): + response_body_buffer += event.data + elif isinstance(event, h2.events.StreamEnded): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert self.master.state.flows[0].response.status_code == 200 + assert self.master.state.flows[0].response.headers['server-foo'] == 'server-bar' + assert self.master.state.flows[0].response.headers['föo'] == 'bär' + assert self.master.state.flows[0].response.content == b'response body' + assert self.request_body_buffer == b'request body' + assert response_body_buffer == b'response body' + + +@requires_alpn +class TestRequestWithPriority(_Http2Test): + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + import warnings + with warnings.catch_warnings(): + # Ignore UnicodeWarning: + # h2/utilities.py:64: UnicodeWarning: Unicode equal comparison + # failed to convert both arguments to Unicode - interpreting + # them as being unequal. + # elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: + + warnings.simplefilter("ignore") + + headers = [(':status', '200')] + if event.priority_updated: + headers.append(('priority_exclusive', str(event.priority_updated.exclusive).encode())) + headers.append(('priority_depends_on', str(event.priority_updated.depends_on).encode())) + headers.append(('priority_weight', str(event.priority_updated.weight).encode())) + h2_conn.send_headers(event.stream_id, headers) + h2_conn.end_stream(event.stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_request_with_priority(self): + client, h2_conn = self._setup_connection() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + priority_exclusive=True, + priority_depends_on=42424242, + priority_weight=42, + ) + + done = False + while not done: + try: + raw = b''.join(http2.read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except exceptions.HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert self.master.state.flows[0].response.headers['priority_exclusive'] == 'True' + assert self.master.state.flows[0].response.headers['priority_depends_on'] == '42424242' + assert self.master.state.flows[0].response.headers['priority_weight'] == '42' + + def test_request_without_priority(self): + client, h2_conn = self._setup_connection() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + ) + + done = False + while not done: + try: + raw = b''.join(http2.read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except exceptions.HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert 'priority_exclusive' not in self.master.state.flows[0].response.headers + assert 'priority_depends_on' not in self.master.state.flows[0].response.headers + assert 'priority_weight' not in self.master.state.flows[0].response.headers + + +@requires_alpn +class TestPriority(_Http2Test): + priority_data = None + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.PriorityUpdated): + cls.priority_data = (event.exclusive, event.depends_on, event.weight) + elif isinstance(event, h2.events.RequestReceived): + import warnings + with warnings.catch_warnings(): + # Ignore UnicodeWarning: + # h2/utilities.py:64: UnicodeWarning: Unicode equal comparison + # failed to convert both arguments to Unicode - interpreting + # them as being unequal. + # elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: + + warnings.simplefilter("ignore") + + headers = [(':status', '200')] + h2_conn.send_headers(event.stream_id, headers) + h2_conn.end_stream(event.stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_priority(self): + client, h2_conn = self._setup_connection() + + h2_conn.prioritize(1, exclusive=True, depends_on=0, weight=42) + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + ) + + done = False + while not done: + try: + raw = b''.join(http2.read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except exceptions.HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert self.priority_data == (True, 0, 42) + + +@requires_alpn +class TestPriorityWithExistingStream(_Http2Test): + priority_data = [] + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.PriorityUpdated): + cls.priority_data.append((event.exclusive, event.depends_on, event.weight)) + elif isinstance(event, h2.events.RequestReceived): + assert not event.priority_updated + + import warnings + with warnings.catch_warnings(): + # Ignore UnicodeWarning: + # h2/utilities.py:64: UnicodeWarning: Unicode equal comparison + # failed to convert both arguments to Unicode - interpreting + # them as being unequal. + # elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: + + warnings.simplefilter("ignore") + + headers = [(':status', '200')] + h2_conn.send_headers(event.stream_id, headers) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + elif isinstance(event, h2.events.StreamEnded): + h2_conn.end_stream(event.stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_priority_with_existing_stream(self): + client, h2_conn = self._setup_connection() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + end_stream=False, + ) + + h2_conn.prioritize(1, exclusive=True, depends_on=0, weight=42) + h2_conn.end_stream(1) + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + done = False + while not done: + try: + raw = b''.join(http2.read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except exceptions.HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert self.priority_data == [(True, 0, 42)] + + +@requires_alpn +class TestStreamResetFromServer(_Http2Test): + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + h2_conn.reset_stream(event.stream_id, 0x8) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_request_with_priority(self): + client, h2_conn = self._setup_connection() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + ) + + done = False + while not done: + try: + raw = b''.join(http2.read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except exceptions.HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamReset): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert self.master.state.flows[0].response is None + + +@requires_alpn +class TestBodySizeLimit(_Http2Test): + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + return True + + def test_body_size_limit(self): + self.config.options.body_size_limit = 20 + + client, h2_conn = self._setup_connection() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + body=b'very long body over 20 characters long', + ) + + done = False + while not done: + try: + raw = b''.join(http2.read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except exceptions.HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamReset): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 0 + + +@requires_alpn +class TestPushPromise(_Http2Test): + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + if event.stream_id != 1: + # ignore requests initiated by push promises + return True + + h2_conn.send_headers(1, [(':status', '200')]) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + h2_conn.push_stream(1, 2, [ + (':authority', "127.0.0.1:{}".format(cls.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/pushed_stream_foo'), + ('foo', 'bar') + ]) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + h2_conn.push_stream(1, 4, [ + (':authority', "127.0.0.1:{}".format(cls.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/pushed_stream_bar'), + ('foo', 'bar') + ]) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + h2_conn.send_headers(2, [(':status', '200')]) + h2_conn.send_headers(4, [(':status', '200')]) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + h2_conn.send_data(1, b'regular_stream') + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + h2_conn.end_stream(1) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + h2_conn.send_data(2, b'pushed_stream_foo') + h2_conn.end_stream(2) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + h2_conn.send_data(4, b'pushed_stream_bar') + h2_conn.end_stream(4) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + return True + + def test_push_promise(self): + client, h2_conn = self._setup_connection() + + self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ('foo', 'bar') + ]) + + done = False + ended_streams = 0 + pushed_streams = 0 + responses = 0 + while not done: + try: + raw = b''.join(http2.read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except exceptions.HttpException: + print(traceback.format_exc()) + assert False + except: + break + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + ended_streams += 1 + elif isinstance(event, h2.events.PushedStreamReceived): + pushed_streams += 1 + elif isinstance(event, h2.events.ResponseReceived): + responses += 1 + if isinstance(event, h2.events.ConnectionTerminated): + done = True + + if responses == 3 and ended_streams == 3 and pushed_streams == 2: + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert ended_streams == 3 + assert pushed_streams == 2 + + bodies = [flow.response.content for flow in self.master.state.flows] + assert len(bodies) == 3 + assert b'regular_stream' in bodies + assert b'pushed_stream_foo' in bodies + assert b'pushed_stream_bar' in bodies + + pushed_flows = [flow for flow in self.master.state.flows if 'h2-pushed-stream' in flow.metadata] + assert len(pushed_flows) == 2 + + def test_push_promise_reset(self): + client, h2_conn = self._setup_connection() + + self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ('foo', 'bar') + ]) + + done = False + ended_streams = 0 + pushed_streams = 0 + responses = 0 + while not done: + try: + raw = b''.join(http2.read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except exceptions.HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded) and event.stream_id == 1: + ended_streams += 1 + elif isinstance(event, h2.events.PushedStreamReceived): + pushed_streams += 1 + h2_conn.reset_stream(event.pushed_stream_id, error_code=0x8) + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + elif isinstance(event, h2.events.ResponseReceived): + responses += 1 + if isinstance(event, h2.events.ConnectionTerminated): + done = True + + if responses >= 1 and ended_streams >= 1 and pushed_streams == 2: + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + bodies = [flow.response.content for flow in self.master.state.flows if flow.response] + assert len(bodies) >= 1 + assert b'regular_stream' in bodies + # the other two bodies might not be transmitted before the reset + + +@requires_alpn +class TestConnectionLost(_Http2Test): + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.RequestReceived): + h2_conn.send_headers(1, [(':status', '200')]) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return False + + def test_connection_lost(self): + client, h2_conn = self._setup_connection() + + self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ('foo', 'bar') + ]) + + done = False + while not done: + try: + raw = b''.join(http2.read_raw_frame(client.rfile)) + h2_conn.receive_data(raw) + except exceptions.HttpException: + print(traceback.format_exc()) + assert False + except: + break + try: + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + except: + break + + if len(self.master.state.flows) == 1: + assert self.master.state.flows[0].response is None + + +@requires_alpn +class TestMaxConcurrentStreams(_Http2Test): + + @classmethod + def setup_class(cls): + _Http2TestBase.setup_class() + _Http2ServerBase.setup_class(h2_server_settings={h2.settings.MAX_CONCURRENT_STREAMS: 2}) + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + h2_conn.send_headers(event.stream_id, [ + (':status', '200'), + ('X-Stream-ID', str(event.stream_id)), + ]) + h2_conn.send_data(event.stream_id, 'Stream-ID {}'.format(event.stream_id).encode()) + h2_conn.end_stream(event.stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_max_concurrent_streams(self): + client, h2_conn = self._setup_connection() + new_streams = [1, 3, 5, 7, 9, 11] + for stream_id in new_streams: + # this will exceed MAX_CONCURRENT_STREAMS on the server connection + # and cause mitmproxy to throttle stream creation to the server + self._send_request(client.wfile, h2_conn, stream_id=stream_id, headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ('X-Stream-ID', str(stream_id)), + ]) + + ended_streams = 0 + while ended_streams != len(new_streams): + try: + header, body = http2.read_raw_frame(client.rfile) + events = h2_conn.receive_data(b''.join([header, body])) + except: + break + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + ended_streams += 1 + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == len(new_streams) + for flow in self.master.state.flows: + assert flow.response.status_code == 200 + assert b"Stream-ID " in flow.response.content + + +@requires_alpn +class TestConnectionTerminated(_Http2Test): + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.RequestReceived): + h2_conn.close_connection(error_code=5, last_stream_id=42, additional_data=b'foobar') + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_connection_terminated(self): + client, h2_conn = self._setup_connection() + + self._send_request(client.wfile, h2_conn, headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ]) + + done = False + connection_terminated_event = None + while not done: + try: + raw = b''.join(http2.read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + for event in events: + if isinstance(event, h2.events.ConnectionTerminated): + connection_terminated_event = event + done = True + except: + break + + assert len(self.master.state.flows) == 1 + assert connection_terminated_event is not None + assert connection_terminated_event.error_code == 5 + assert connection_terminated_event.last_stream_id == 42 + assert connection_terminated_event.additional_data == b'foobar' diff --git a/test/mitmproxy/proxy/protocol/test_http_replay.py b/test/mitmproxy/proxy/protocol/test_http_replay.py new file mode 100644 index 00000000..777ab4dd --- /dev/null +++ b/test/mitmproxy/proxy/protocol/test_http_replay.py @@ -0,0 +1 @@ +# TODO: write tests diff --git a/test/mitmproxy/proxy/protocol/test_rawtcp.py b/test/mitmproxy/proxy/protocol/test_rawtcp.py new file mode 100644 index 00000000..777ab4dd --- /dev/null +++ b/test/mitmproxy/proxy/protocol/test_rawtcp.py @@ -0,0 +1 @@ +# TODO: write tests diff --git a/test/mitmproxy/proxy/protocol/test_tls.py b/test/mitmproxy/proxy/protocol/test_tls.py new file mode 100644 index 00000000..e17ee46f --- /dev/null +++ b/test/mitmproxy/proxy/protocol/test_tls.py @@ -0,0 +1,26 @@ +from mitmproxy.proxy.protocol.tls import TlsClientHello + + +class TestClientHello: + + def test_no_extensions(self): + data = bytes.fromhex( + "03015658a756ab2c2bff55f636814deac086b7ca56b65058c7893ffc6074f5245f70205658a75475103a152637" + "78e1bb6d22e8bbd5b6b0a3a59760ad354e91ba20d353001a0035002f000a000500040009000300060008006000" + "61006200640100" + ) + c = TlsClientHello(data) + assert c.sni is None + assert c.alpn_protocols == [] + + def test_extensions(self): + data = bytes.fromhex( + "03033b70638d2523e1cba15f8364868295305e9c52aceabda4b5147210abc783e6e1000022c02bc02fc02cc030" + "cca9cca8cc14cc13c009c013c00ac014009c009d002f0035000a0100006cff0100010000000010000e00000b65" + "78616d706c652e636f6d0017000000230000000d00120010060106030501050304010403020102030005000501" + "00000000001200000010000e000c02683208687474702f312e3175500000000b00020100000a00080006001d00" + "170018" + ) + c = TlsClientHello(data) + assert c.sni == 'example.com' + assert c.alpn_protocols == [b'h2', b'http/1.1'] diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py new file mode 100644 index 00000000..4ea01d34 --- /dev/null +++ b/test/mitmproxy/proxy/protocol/test_websocket.py @@ -0,0 +1,328 @@ +import pytest +import os +import tempfile +import traceback + +from mitmproxy import options +from mitmproxy import exceptions +from mitmproxy.http import HTTPFlow +from mitmproxy.websocket import WebSocketFlow +from mitmproxy.proxy.config import ProxyConfig + +from mitmproxy.net import tcp +from mitmproxy.net import http +from ....mitmproxy.net import tservers as net_tservers +from ... import tservers + +from mitmproxy.net import websockets + + +class _WebSocketServerBase(net_tservers.ServerTestBase): + + class handler(tcp.BaseHandler): + + def handle(self): + try: + request = http.http1.read_request(self.rfile) + assert websockets.check_handshake(request.headers) + + response = http.Response( + "HTTP/1.1", + 101, + reason=http.status_codes.RESPONSES.get(101), + headers=http.Headers( + connection='upgrade', + upgrade='websocket', + sec_websocket_accept=b'', + ), + content=b'', + ) + self.wfile.write(http.http1.assemble_response(response)) + self.wfile.flush() + + self.server.handle_websockets(self.rfile, self.wfile) + except: + traceback.print_exc() + + +class _WebSocketTestBase: + + @classmethod + def setup_class(cls): + opts = cls.get_options() + cls.config = ProxyConfig(opts) + + tmaster = tservers.TestMaster(opts, cls.config) + cls.proxy = tservers.ProxyThread(tmaster) + cls.proxy.start() + + @classmethod + def teardown_class(cls): + cls.proxy.shutdown() + + @classmethod + def get_options(cls): + opts = options.Options( + listen_port=0, + no_upstream_cert=False, + ssl_insecure=True, + websocket=True, + ) + opts.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy") + return opts + + @property + def master(self): + return self.proxy.tmaster + + def setup(self): + self.master.reset([]) + self.server.server.handle_websockets = self.handle_websockets + + def _setup_connection(self): + client = tcp.TCPClient(("127.0.0.1", self.proxy.port)) + client.connect() + + request = http.Request( + "authority", + "CONNECT", + "", + "localhost", + self.server.server.address.port, + "", + "HTTP/1.1", + content=b'') + client.wfile.write(http.http1.assemble_request(request)) + client.wfile.flush() + + response = http.http1.read_response(client.rfile, request) + + if self.ssl: + client.convert_to_ssl() + assert client.ssl_established + + request = http.Request( + "relative", + "GET", + "http", + "localhost", + self.server.server.address.port, + "/ws", + "HTTP/1.1", + headers=http.Headers( + connection="upgrade", + upgrade="websocket", + sec_websocket_version="13", + sec_websocket_key="1234", + ), + content=b'') + client.wfile.write(http.http1.assemble_request(request)) + client.wfile.flush() + + response = http.http1.read_response(client.rfile, request) + assert websockets.check_handshake(response.headers) + + return client + + +class _WebSocketTest(_WebSocketTestBase, _WebSocketServerBase): + + @classmethod + def setup_class(cls): + _WebSocketTestBase.setup_class() + _WebSocketServerBase.setup_class(ssl=cls.ssl) + + @classmethod + def teardown_class(cls): + _WebSocketTestBase.teardown_class() + _WebSocketServerBase.teardown_class() + + +class TestSimple(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar'))) + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + wfile.write(bytes(frame)) + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + wfile.write(bytes(frame)) + wfile.flush() + + def test_simple(self): + client = self._setup_connection() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'server-foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar'))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'client-foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef'))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'\xde\xad\xbe\xef' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + client.wfile.flush() + + assert len(self.master.state.flows) == 2 + assert isinstance(self.master.state.flows[0], HTTPFlow) + assert isinstance(self.master.state.flows[1], WebSocketFlow) + assert len(self.master.state.flows[1].messages) == 5 + assert self.master.state.flows[1].messages[0].content == b'server-foobar' + assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT + assert self.master.state.flows[1].messages[1].content == b'client-foobar' + assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT + assert self.master.state.flows[1].messages[2].content == b'client-foobar' + assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT + assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef' + assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY + assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef' + assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY + + +class TestSimpleTLS(_WebSocketTest): + ssl = True + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar'))) + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + wfile.write(bytes(frame)) + wfile.flush() + + def test_simple_tls(self): + client = self._setup_connection() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'server-foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar'))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'client-foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + client.wfile.flush() + + +class TestPing(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + assert frame.header.opcode == websockets.OPCODE.PONG + assert frame.payload == b'foobar' + + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received'))) + wfile.flush() + + def test_ping(self): + client = self._setup_connection() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.header.opcode == websockets.OPCODE.PING + assert frame.payload == b'foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.header.opcode == websockets.OPCODE.TEXT + assert frame.payload == b'pong-received' + + +class TestPong(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + frame = websockets.Frame.from_file(rfile) + assert frame.header.opcode == websockets.OPCODE.PING + assert frame.payload == b'foobar' + + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) + wfile.flush() + + def test_pong(self): + client = self._setup_connection() + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.header.opcode == websockets.OPCODE.PONG + assert frame.payload == b'foobar' + + +class TestClose(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + frame = websockets.Frame.from_file(rfile) + wfile.write(bytes(frame)) + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + wfile.flush() + + with pytest.raises(exceptions.TcpDisconnect): + websockets.Frame.from_file(rfile) + + def test_close(self): + client = self._setup_connection() + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + client.wfile.flush() + + websockets.Frame.from_file(client.rfile) + with pytest.raises(exceptions.TcpDisconnect): + websockets.Frame.from_file(client.rfile) + + def test_close_payload_1(self): + client = self._setup_connection() + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) + client.wfile.flush() + + websockets.Frame.from_file(client.rfile) + with pytest.raises(exceptions.TcpDisconnect): + websockets.Frame.from_file(client.rfile) + + def test_close_payload_2(self): + client = self._setup_connection() + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) + client.wfile.flush() + + websockets.Frame.from_file(client.rfile) + with pytest.raises(exceptions.TcpDisconnect): + websockets.Frame.from_file(client.rfile) + + +class TestInvalidFrame(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=15, payload=b'foobar'))) + wfile.flush() + + def test_invalid_frame(self): + client = self._setup_connection() + + # with pytest.raises(exceptions.TcpDisconnect): + frame = websockets.Frame.from_file(client.rfile) + assert frame.header.opcode == 15 + assert frame.payload == b'foobar' |
