From ffb3988dc9ef3f7f8137b913edb7986e148e0dc4 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 13 Nov 2016 16:18:29 +0100 Subject: rename WebSocket{s,} protocol --- docs/features/tcpproxy.rst | 2 +- docs/scripting/events.rst | 8 +- mitmproxy/net/websockets/frame.py | 10 +- mitmproxy/net/websockets/utils.py | 2 +- mitmproxy/proxy/protocol/__init__.py | 8 +- mitmproxy/proxy/protocol/http.py | 15 +- mitmproxy/proxy/protocol/http2.py | 2 +- mitmproxy/proxy/protocol/websocket.py | 111 +++++++++++ mitmproxy/proxy/protocol/websockets.py | 111 ----------- test/mitmproxy/protocol/test_websocket.py | 297 +++++++++++++++++++++++++++++ test/mitmproxy/protocol/test_websockets.py | 297 ----------------------------- 11 files changed, 432 insertions(+), 431 deletions(-) create mode 100644 mitmproxy/proxy/protocol/websocket.py delete mode 100644 mitmproxy/proxy/protocol/websockets.py create mode 100644 test/mitmproxy/protocol/test_websocket.py delete mode 100644 test/mitmproxy/protocol/test_websockets.py diff --git a/docs/features/tcpproxy.rst b/docs/features/tcpproxy.rst index 1d6fbd12..e24620e2 100644 --- a/docs/features/tcpproxy.rst +++ b/docs/features/tcpproxy.rst @@ -3,7 +3,7 @@ TCP Proxy ========= -WebSockets or other non-HTTP protocols are not supported by mitmproxy yet. However, you can exempt +Non-HTTP protocols are not supported by mitmproxy yet. However, you can exempt hostnames from processing, so that mitmproxy acts as a generic TCP forwarder. This feature is closely related to the :ref:`passthrough` functionality, but differs in two important aspects: diff --git a/docs/scripting/events.rst b/docs/scripting/events.rst index 5f560e58..69b829a3 100644 --- a/docs/scripting/events.rst +++ b/docs/scripting/events.rst @@ -162,15 +162,15 @@ WebSocket Events :widths: 40 60 :header-rows: 0 - * - .. py:function:: websockets_handshake(flow) + * - .. py:function:: websocket_handshake(flow) - - Called when a client wants to establish a WebSockets connection. The - WebSockets-specific headers can be manipulated to manipulate the + - Called when a client wants to establish a WebSocket connection. The + WebSocket-specific headers can be manipulated to manipulate the handshake. The ``flow`` object is guaranteed to have a non-None ``request`` attribute. *flow* - The flow containing the HTTP websocket handshake request. The + The flow containing the HTTP WebSocket handshake request. The object is guaranteed to have a non-None ``request`` attribute. diff --git a/mitmproxy/net/websockets/frame.py b/mitmproxy/net/websockets/frame.py index bd5f67dd..28881f64 100644 --- a/mitmproxy/net/websockets/frame.py +++ b/mitmproxy/net/websockets/frame.py @@ -90,7 +90,7 @@ class FrameHeader: @classmethod def _make_length_code(self, length): """ - A websockets frame contains an initial length_code, and an optional + A WebSocket frame contains an initial length_code, and an optional extended length code to represent the actual length if length code is larger than 125 """ @@ -149,7 +149,7 @@ class FrameHeader: @classmethod def from_file(cls, fp): """ - read a websockets frame header + read a WebSocket frame header """ first_byte, second_byte = fp.safe_read(2) fin = bits.getbit(first_byte, 7) @@ -195,11 +195,11 @@ class FrameHeader: class Frame: """ - Represents a single WebSockets frame. + Represents a single WebSocket frame. Constructor takes human readable forms of the frame components. from_bytes() reads from a file-like object to create a new Frame. - WebSockets Frame as defined in RFC6455 + WebSocket frame as defined in RFC6455 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-------+-+-------------+-------------------------------+ @@ -253,7 +253,7 @@ class Frame: @classmethod def from_file(cls, fp): """ - read a websockets frame sent by a server or client + read a WebSocket frame sent by a server or client fp is a "file like" object that could be backed by a network stream or a disk or an in memory stream reader diff --git a/mitmproxy/net/websockets/utils.py b/mitmproxy/net/websockets/utils.py index d0b168ce..2f13f2b2 100644 --- a/mitmproxy/net/websockets/utils.py +++ b/mitmproxy/net/websockets/utils.py @@ -1,5 +1,5 @@ """ -Collection of WebSockets Protocol utility functions (RFC6455) +Collection of WebSocket protocol utility functions (RFC6455) Spec: https://tools.ietf.org/html/rfc6455 """ diff --git a/mitmproxy/proxy/protocol/__init__.py b/mitmproxy/proxy/protocol/__init__.py index 89b60386..6dbdd13c 100644 --- a/mitmproxy/proxy/protocol/__init__.py +++ b/mitmproxy/proxy/protocol/__init__.py @@ -2,7 +2,7 @@ In mitmproxy, protocols are implemented as a set of layers, which are composed on top each other. The first layer is usually the proxy mode, e.g. transparent proxy or normal HTTP proxy. Next, various protocol layers are stacked on top of -each other - imagine WebSockets on top of an HTTP Upgrade request. An actual +each other - imagine WebSocket on top of an HTTP Upgrade request. An actual mitmproxy connection may look as follows (outermost layer first): Transparent HTTP proxy, no TLS: @@ -10,7 +10,7 @@ mitmproxy connection may look as follows (outermost layer first): - Http1Layer - HttpLayer - Regular proxy, CONNECT request with WebSockets over SSL: + Regular proxy, CONNECT request with WebSocket over SSL: - ReverseProxy - Http1Layer - HttpLayer @@ -34,7 +34,7 @@ from .http import UpstreamConnectLayer from .http import HttpLayer from .http1 import Http1Layer from .http2 import Http2Layer -from .websockets import WebSocketsLayer +from .websocket import WebSocketLayer from .rawtcp import RawTCPLayer from .tls import TlsClientHello from .tls import TlsLayer @@ -47,6 +47,6 @@ __all__ = [ "HttpLayer", "Http1Layer", "Http2Layer", - "WebSocketsLayer", + "WebSocketLayer", "RawTCPLayer", ] diff --git a/mitmproxy/proxy/protocol/http.py b/mitmproxy/proxy/protocol/http.py index dcedfc5a..2da4ddbf 100644 --- a/mitmproxy/proxy/protocol/http.py +++ b/mitmproxy/proxy/protocol/http.py @@ -8,7 +8,8 @@ from mitmproxy import exceptions from mitmproxy import http from mitmproxy import flow from mitmproxy.proxy.protocol import base -from mitmproxy.proxy.protocol import websockets as pwebsockets +from mitmproxy.proxy.protocol.websocket import WebSocketLayer +import mitmproxy.net.http from mitmproxy.net import tcp from mitmproxy.net import websockets @@ -300,7 +301,7 @@ class HttpLayer(base.Layer): try: if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers): - # We only support RFC6455 with WebSockets version 13 + # We only support RFC6455 with WebSocket version 13 # allow inline scripts to manipulate the client handshake self.channel.ask("websocket_handshake", f) @@ -392,19 +393,19 @@ class HttpLayer(base.Layer): if f.response.status_code == 101: # Handle a successful HTTP 101 Switching Protocols Response, # received after e.g. a WebSocket upgrade request. - # Check for WebSockets handshake - is_websockets = ( + # Check for WebSocket handshake + is_websocket = ( websockets.check_handshake(f.request.headers) and websockets.check_handshake(f.response.headers) ) - if is_websockets and not self.config.options.websockets: + if is_websocket and not self.config.options.websockets: self.log( "Client requested WebSocket connection, but the protocol is disabled.", "info" ) - if is_websockets and self.config.options.websockets: - layer = pwebsockets.WebSocketsLayer(self, f) + if is_websocket and self.config.options.websockets: + layer = WebSocketLayer(self, f) else: layer = self.ctx.next_layer(self) layer() diff --git a/mitmproxy/proxy/protocol/http2.py b/mitmproxy/proxy/protocol/http2.py index 835f86d0..41707096 100644 --- a/mitmproxy/proxy/protocol/http2.py +++ b/mitmproxy/proxy/protocol/http2.py @@ -121,7 +121,7 @@ class Http2Layer(base.Layer): self.client_conn.send(self.connections[self.client_conn].data_to_send()) def next_layer(self): # pragma: no cover - # WebSockets over HTTP/2? + # WebSocket over HTTP/2? # CONNECT for proxying? raise NotImplementedError() diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py new file mode 100644 index 00000000..47628013 --- /dev/null +++ b/mitmproxy/proxy/protocol/websocket.py @@ -0,0 +1,111 @@ +import socket +import struct +from OpenSSL import SSL +from mitmproxy import exceptions +from mitmproxy.proxy.protocol import base +from mitmproxy.utils import strutils +from mitmproxy.net import tcp +from mitmproxy.net import websockets + + +class WebSocketLayer(base.Layer): + """ + WebSocket layer to intercept, modify, and forward WebSocket connections + + Only version 13 is supported (as specified in RFC6455) + Only HTTP/1.1-initiated connections are supported. + + The client starts by sending an Upgrade-request. + In order to determine the handshake and negotiate the correct protocol + and extensions, the Upgrade-request is forwarded to the server. + The response from the server is then parsed and negotiated settings are extracted. + Finally the handshake is completed by forwarding the server-response to the client. + After that, only WebSocket frames are exchanged. + + PING/PONG frames pass through and must be answered by the other endpoint. + + CLOSE frames are forwarded before this WebSocketLayer terminates. + + This layer is transparent to any negotiated extensions. + This layer is transparent to any negotiated subprotocols. + Only raw frames are forwarded to the other endpoint. + """ + + def __init__(self, ctx, flow): + super().__init__(ctx) + self._flow = flow + + self.client_key = websockets.get_client_key(self._flow.request.headers) + self.client_protocol = websockets.get_protocol(self._flow.request.headers) + self.client_extensions = websockets.get_extensions(self._flow.request.headers) + + self.server_accept = websockets.get_server_accept(self._flow.response.headers) + self.server_protocol = websockets.get_protocol(self._flow.response.headers) + self.server_extensions = websockets.get_extensions(self._flow.response.headers) + + def _handle_frame(self, frame, source_conn, other_conn, is_server): + sender = "server" if is_server else "client" + self.log( + "WebSocket frame received from {}".format(sender), + "debug", + [repr(frame)] + ) + + if frame.header.opcode & 0x8 == 0: + self.log( + "{direction} websocket {direction} {server}".format( + server=repr(self.server_conn.address), + direction="<-" if is_server else "->", + ), + "info", + strutils.bytes_to_escaped_str(frame.payload, keep_spacing=True).splitlines() + ) + # forward the data frame to the other side + other_conn.send(bytes(frame)) + elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG): + # just forward the ping/pong to the other side + other_conn.send(bytes(frame)) + elif frame.header.opcode == websockets.OPCODE.CLOSE: + code = '(status code missing)' + msg = None + reason = '(message missing)' + if len(frame.payload) >= 2: + code, = struct.unpack('!H', frame.payload[:2]) + msg = websockets.CLOSE_REASON.get_name(code, default='unknown status code') + if len(frame.payload) > 2: + reason = frame.payload[2:] + self.log("WebSocket connection closed by {}: {} {}, {}".format(sender, code, msg, reason), "info") + + other_conn.send(bytes(frame)) + # close the connection + return False + else: + self.log("Unknown WebSocket frame received from {}".format(sender), "info", [repr(frame)]) + # unknown frame - just forward it + other_conn.send(bytes(frame)) + + # continue the connection + return True + + def __call__(self): + client = self.client_conn.connection + server = self.server_conn.connection + conns = [client, server] + + try: + while not self.channel.should_exit.is_set(): + r = tcp.ssl_read_select(conns, 1) + for conn in r: + source_conn = self.client_conn if conn == client else self.server_conn + other_conn = self.server_conn if conn == client else self.client_conn + is_server = (conn == self.server_conn.connection) + + frame = websockets.Frame.from_file(source_conn.rfile) + + if not self._handle_frame(frame, source_conn, other_conn, is_server): + return + except (socket.error, exceptions.TcpException, SSL.Error) as e: + self.log("WebSocket connection closed unexpectedly by {}: {}".format( + "server" if is_server else "client", repr(e)), "info") + except Exception as e: # pragma: no cover + raise exceptions.ProtocolException("Error in WebSocket connection: {}".format(repr(e))) diff --git a/mitmproxy/proxy/protocol/websockets.py b/mitmproxy/proxy/protocol/websockets.py deleted file mode 100644 index ca1d05cb..00000000 --- a/mitmproxy/proxy/protocol/websockets.py +++ /dev/null @@ -1,111 +0,0 @@ -import socket -import struct -from OpenSSL import SSL -from mitmproxy import exceptions -from mitmproxy.proxy.protocol import base -from mitmproxy.utils import strutils -from mitmproxy.net import tcp -from mitmproxy.net import websockets - - -class WebSocketsLayer(base.Layer): - """ - WebSockets layer to intercept, modify, and forward WebSockets connections - - Only version 13 is supported (as specified in RFC6455) - Only HTTP/1.1-initiated connections are supported. - - The client starts by sending an Upgrade-request. - In order to determine the handshake and negotiate the correct protocol - and extensions, the Upgrade-request is forwarded to the server. - The response from the server is then parsed and negotiated settings are extracted. - Finally the handshake is completed by forwarding the server-response to the client. - After that, only WebSockets frames are exchanged. - - PING/PONG frames pass through and must be answered by the other endpoint. - - CLOSE frames are forwarded before this WebSocketsLayer terminates. - - This layer is transparent to any negotiated extensions. - This layer is transparent to any negotiated subprotocols. - Only raw frames are forwarded to the other endpoint. - """ - - def __init__(self, ctx, flow): - super().__init__(ctx) - self._flow = flow - - self.client_key = websockets.get_client_key(self._flow.request.headers) - self.client_protocol = websockets.get_protocol(self._flow.request.headers) - self.client_extensions = websockets.get_extensions(self._flow.request.headers) - - self.server_accept = websockets.get_server_accept(self._flow.response.headers) - self.server_protocol = websockets.get_protocol(self._flow.response.headers) - self.server_extensions = websockets.get_extensions(self._flow.response.headers) - - def _handle_frame(self, frame, source_conn, other_conn, is_server): - sender = "server" if is_server else "client" - self.log( - "WebSockets Frame received from {}".format(sender), - "debug", - [repr(frame)] - ) - - if frame.header.opcode & 0x8 == 0: - self.log( - "{direction} websocket {direction} {server}".format( - server=repr(self.server_conn.address), - direction="<-" if is_server else "->", - ), - "info", - strutils.bytes_to_escaped_str(frame.payload, keep_spacing=True).splitlines() - ) - # forward the data frame to the other side - other_conn.send(bytes(frame)) - elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG): - # just forward the ping/pong to the other side - other_conn.send(bytes(frame)) - elif frame.header.opcode == websockets.OPCODE.CLOSE: - code = '(status code missing)' - msg = None - reason = '(message missing)' - if len(frame.payload) >= 2: - code, = struct.unpack('!H', frame.payload[:2]) - msg = websockets.CLOSE_REASON.get_name(code, default='unknown status code') - if len(frame.payload) > 2: - reason = frame.payload[2:] - self.log("WebSockets connection closed by {}: {} {}, {}".format(sender, code, msg, reason), "info") - - other_conn.send(bytes(frame)) - # close the connection - return False - else: - self.log("Unknown WebSockets frame received from {}".format(sender), "info", [repr(frame)]) - # unknown frame - just forward it - other_conn.send(bytes(frame)) - - # continue the connection - return True - - def __call__(self): - client = self.client_conn.connection - server = self.server_conn.connection - conns = [client, server] - - try: - while not self.channel.should_exit.is_set(): - r = tcp.ssl_read_select(conns, 1) - for conn in r: - source_conn = self.client_conn if conn == client else self.server_conn - other_conn = self.server_conn if conn == client else self.client_conn - is_server = (conn == self.server_conn.connection) - - frame = websockets.Frame.from_file(source_conn.rfile) - - if not self._handle_frame(frame, source_conn, other_conn, is_server): - return - except (socket.error, exceptions.TcpException, SSL.Error) as e: - self.log("WebSockets connection closed unexpectedly by {}: {}".format( - "server" if is_server else "client", repr(e)), "info") - except Exception as e: # pragma: no cover - raise exceptions.ProtocolException("Error in WebSockets connection: {}".format(repr(e))) diff --git a/test/mitmproxy/protocol/test_websocket.py b/test/mitmproxy/protocol/test_websocket.py new file mode 100644 index 00000000..93997045 --- /dev/null +++ b/test/mitmproxy/protocol/test_websocket.py @@ -0,0 +1,297 @@ +import pytest +import os +import tempfile +import traceback + +from mitmproxy import options +from mitmproxy import exceptions +from mitmproxy.proxy.config import ProxyConfig + +import mitmproxy.net +from mitmproxy.net import http +from ...mitmproxy.net import tservers as net_tservers +from .. import tservers + +from mitmproxy.net import websockets + + +class _WebSocketServerBase(net_tservers.ServerTestBase): + + class handler(mitmproxy.net.tcp.BaseHandler): + + def handle(self): + try: + request = http.http1.read_request(self.rfile) + assert websockets.check_handshake(request.headers) + + response = http.Response( + "HTTP/1.1", + 101, + reason=http.status_codes.RESPONSES.get(101), + headers=http.Headers( + connection='upgrade', + upgrade='websocket', + sec_websocket_accept=b'', + ), + content=b'', + ) + self.wfile.write(http.http1.assemble_response(response)) + self.wfile.flush() + + self.server.handle_websockets(self.rfile, self.wfile) + except: + traceback.print_exc() + + +class _WebSocketTestBase: + + @classmethod + def setup_class(cls): + opts = cls.get_options() + cls.config = ProxyConfig(opts) + + tmaster = tservers.TestMaster(opts, cls.config) + cls.proxy = tservers.ProxyThread(tmaster) + cls.proxy.start() + + @classmethod + def teardown_class(cls): + cls.proxy.shutdown() + + @classmethod + def get_options(cls): + opts = options.Options( + listen_port=0, + no_upstream_cert=False, + ssl_insecure=True, + websockets=True, + ) + opts.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy") + return opts + + @property + def master(self): + return self.proxy.tmaster + + def setup(self): + self.master.reset([]) + self.server.server.handle_websockets = self.handle_websockets + + def _setup_connection(self): + client = mitmproxy.net.tcp.TCPClient(("127.0.0.1", self.proxy.port)) + client.connect() + + request = http.Request( + "authority", + "CONNECT", + "", + "localhost", + self.server.server.address.port, + "", + "HTTP/1.1", + content=b'') + client.wfile.write(http.http1.assemble_request(request)) + client.wfile.flush() + + response = http.http1.read_response(client.rfile, request) + + if self.ssl: + client.convert_to_ssl() + assert client.ssl_established + + request = http.Request( + "relative", + "GET", + "http", + "localhost", + self.server.server.address.port, + "/ws", + "HTTP/1.1", + headers=http.Headers( + connection="upgrade", + upgrade="websocket", + sec_websocket_version="13", + sec_websocket_key="1234", + ), + content=b'') + client.wfile.write(http.http1.assemble_request(request)) + client.wfile.flush() + + response = http.http1.read_response(client.rfile, request) + assert websockets.check_handshake(response.headers) + + return client + + +class _WebSocketTest(_WebSocketTestBase, _WebSocketServerBase): + + @classmethod + def setup_class(cls): + _WebSocketTestBase.setup_class() + _WebSocketServerBase.setup_class(ssl=cls.ssl) + + @classmethod + def teardown_class(cls): + _WebSocketTestBase.teardown_class() + _WebSocketServerBase.teardown_class() + + +class TestSimple(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar'))) + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + wfile.write(bytes(frame)) + wfile.flush() + + def test_simple(self): + client = self._setup_connection() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'server-foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar'))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'client-foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + client.wfile.flush() + + +class TestSimpleTLS(_WebSocketTest): + ssl = True + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar'))) + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + wfile.write(bytes(frame)) + wfile.flush() + + def test_simple_tls(self): + client = self._setup_connection() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'server-foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar'))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'client-foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + client.wfile.flush() + + +class TestPing(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + assert frame.header.opcode == websockets.OPCODE.PONG + assert frame.payload == b'foobar' + + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received'))) + wfile.flush() + + def test_ping(self): + client = self._setup_connection() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.header.opcode == websockets.OPCODE.PING + assert frame.payload == b'foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.header.opcode == websockets.OPCODE.TEXT + assert frame.payload == b'pong-received' + + +class TestPong(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + frame = websockets.Frame.from_file(rfile) + assert frame.header.opcode == websockets.OPCODE.PING + assert frame.payload == b'foobar' + + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) + wfile.flush() + + def test_pong(self): + client = self._setup_connection() + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.header.opcode == websockets.OPCODE.PONG + assert frame.payload == b'foobar' + + +class TestClose(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + frame = websockets.Frame.from_file(rfile) + wfile.write(bytes(frame)) + wfile.flush() + + with pytest.raises(exceptions.TcpDisconnect): + websockets.Frame.from_file(rfile) + + def test_close(self): + client = self._setup_connection() + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + client.wfile.flush() + + with pytest.raises(exceptions.TcpDisconnect): + websockets.Frame.from_file(client.rfile) + + def test_close_payload_1(self): + client = self._setup_connection() + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) + client.wfile.flush() + + with pytest.raises(exceptions.TcpDisconnect): + websockets.Frame.from_file(client.rfile) + + def test_close_payload_2(self): + client = self._setup_connection() + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) + client.wfile.flush() + + with pytest.raises(exceptions.TcpDisconnect): + websockets.Frame.from_file(client.rfile) + + +class TestInvalidFrame(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=15, payload=b'foobar'))) + wfile.flush() + + def test_invalid_frame(self): + client = self._setup_connection() + + # with pytest.raises(exceptions.TcpDisconnect): + frame = websockets.Frame.from_file(client.rfile) + assert frame.header.opcode == 15 + assert frame.payload == b'foobar' diff --git a/test/mitmproxy/protocol/test_websockets.py b/test/mitmproxy/protocol/test_websockets.py deleted file mode 100644 index 71cbb5f4..00000000 --- a/test/mitmproxy/protocol/test_websockets.py +++ /dev/null @@ -1,297 +0,0 @@ -import pytest -import os -import tempfile -import traceback - -from mitmproxy import options -from mitmproxy import exceptions -from mitmproxy.proxy.config import ProxyConfig - -import mitmproxy.net -from mitmproxy.net import http -from ...mitmproxy.net import tservers as net_tservers -from .. import tservers - -from mitmproxy.net import websockets - - -class _WebSocketsServerBase(net_tservers.ServerTestBase): - - class handler(mitmproxy.net.tcp.BaseHandler): - - def handle(self): - try: - request = http.http1.read_request(self.rfile) - assert websockets.check_handshake(request.headers) - - response = http.Response( - "HTTP/1.1", - 101, - reason=http.status_codes.RESPONSES.get(101), - headers=http.Headers( - connection='upgrade', - upgrade='websocket', - sec_websocket_accept=b'', - ), - content=b'', - ) - self.wfile.write(http.http1.assemble_response(response)) - self.wfile.flush() - - self.server.handle_websockets(self.rfile, self.wfile) - except: - traceback.print_exc() - - -class _WebSocketsTestBase: - - @classmethod - def setup_class(cls): - opts = cls.get_options() - cls.config = ProxyConfig(opts) - - tmaster = tservers.TestMaster(opts, cls.config) - cls.proxy = tservers.ProxyThread(tmaster) - cls.proxy.start() - - @classmethod - def teardown_class(cls): - cls.proxy.shutdown() - - @classmethod - def get_options(cls): - opts = options.Options( - listen_port=0, - no_upstream_cert=False, - ssl_insecure=True, - websockets=True, - ) - opts.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy") - return opts - - @property - def master(self): - return self.proxy.tmaster - - def setup(self): - self.master.reset([]) - self.server.server.handle_websockets = self.handle_websockets - - def _setup_connection(self): - client = mitmproxy.net.tcp.TCPClient(("127.0.0.1", self.proxy.port)) - client.connect() - - request = http.Request( - "authority", - "CONNECT", - "", - "localhost", - self.server.server.address.port, - "", - "HTTP/1.1", - content=b'') - client.wfile.write(http.http1.assemble_request(request)) - client.wfile.flush() - - response = http.http1.read_response(client.rfile, request) - - if self.ssl: - client.convert_to_ssl() - assert client.ssl_established - - request = http.Request( - "relative", - "GET", - "http", - "localhost", - self.server.server.address.port, - "/ws", - "HTTP/1.1", - headers=http.Headers( - connection="upgrade", - upgrade="websocket", - sec_websocket_version="13", - sec_websocket_key="1234", - ), - content=b'') - client.wfile.write(http.http1.assemble_request(request)) - client.wfile.flush() - - response = http.http1.read_response(client.rfile, request) - assert websockets.check_handshake(response.headers) - - return client - - -class _WebSocketsTest(_WebSocketsTestBase, _WebSocketsServerBase): - - @classmethod - def setup_class(cls): - _WebSocketsTestBase.setup_class() - _WebSocketsServerBase.setup_class(ssl=cls.ssl) - - @classmethod - def teardown_class(cls): - _WebSocketsTestBase.teardown_class() - _WebSocketsServerBase.teardown_class() - - -class TestSimple(_WebSocketsTest): - - @classmethod - def handle_websockets(cls, rfile, wfile): - wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar'))) - wfile.flush() - - frame = websockets.Frame.from_file(rfile) - wfile.write(bytes(frame)) - wfile.flush() - - def test_simple(self): - client = self._setup_connection() - - frame = websockets.Frame.from_file(client.rfile) - assert frame.payload == b'server-foobar' - - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar'))) - client.wfile.flush() - - frame = websockets.Frame.from_file(client.rfile) - assert frame.payload == b'client-foobar' - - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) - client.wfile.flush() - - -class TestSimpleTLS(_WebSocketsTest): - ssl = True - - @classmethod - def handle_websockets(cls, rfile, wfile): - wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar'))) - wfile.flush() - - frame = websockets.Frame.from_file(rfile) - wfile.write(bytes(frame)) - wfile.flush() - - def test_simple_tls(self): - client = self._setup_connection() - - frame = websockets.Frame.from_file(client.rfile) - assert frame.payload == b'server-foobar' - - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar'))) - client.wfile.flush() - - frame = websockets.Frame.from_file(client.rfile) - assert frame.payload == b'client-foobar' - - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) - client.wfile.flush() - - -class TestPing(_WebSocketsTest): - - @classmethod - def handle_websockets(cls, rfile, wfile): - wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) - wfile.flush() - - frame = websockets.Frame.from_file(rfile) - assert frame.header.opcode == websockets.OPCODE.PONG - assert frame.payload == b'foobar' - - wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received'))) - wfile.flush() - - def test_ping(self): - client = self._setup_connection() - - frame = websockets.Frame.from_file(client.rfile) - assert frame.header.opcode == websockets.OPCODE.PING - assert frame.payload == b'foobar' - - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) - client.wfile.flush() - - frame = websockets.Frame.from_file(client.rfile) - assert frame.header.opcode == websockets.OPCODE.TEXT - assert frame.payload == b'pong-received' - - -class TestPong(_WebSocketsTest): - - @classmethod - def handle_websockets(cls, rfile, wfile): - frame = websockets.Frame.from_file(rfile) - assert frame.header.opcode == websockets.OPCODE.PING - assert frame.payload == b'foobar' - - wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) - wfile.flush() - - def test_pong(self): - client = self._setup_connection() - - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) - client.wfile.flush() - - frame = websockets.Frame.from_file(client.rfile) - assert frame.header.opcode == websockets.OPCODE.PONG - assert frame.payload == b'foobar' - - -class TestClose(_WebSocketsTest): - - @classmethod - def handle_websockets(cls, rfile, wfile): - frame = websockets.Frame.from_file(rfile) - wfile.write(bytes(frame)) - wfile.flush() - - with pytest.raises(exceptions.TcpDisconnect): - websockets.Frame.from_file(rfile) - - def test_close(self): - client = self._setup_connection() - - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) - client.wfile.flush() - - with pytest.raises(exceptions.TcpDisconnect): - websockets.Frame.from_file(client.rfile) - - def test_close_payload_1(self): - client = self._setup_connection() - - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) - client.wfile.flush() - - with pytest.raises(exceptions.TcpDisconnect): - websockets.Frame.from_file(client.rfile) - - def test_close_payload_2(self): - client = self._setup_connection() - - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) - client.wfile.flush() - - with pytest.raises(exceptions.TcpDisconnect): - websockets.Frame.from_file(client.rfile) - - -class TestInvalidFrame(_WebSocketsTest): - - @classmethod - def handle_websockets(cls, rfile, wfile): - wfile.write(bytes(websockets.Frame(fin=1, opcode=15, payload=b'foobar'))) - wfile.flush() - - def test_invalid_frame(self): - client = self._setup_connection() - - # with pytest.raises(exceptions.TcpDisconnect): - frame = websockets.Frame.from_file(client.rfile) - assert frame.header.opcode == 15 - assert frame.payload == b'foobar' -- cgit v1.2.3