aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorThomas Kriechbaumer <thomas@kriechbaumer.name>2017-12-12 21:47:24 +0100
committerThomas Kriechbaumer <thomas@kriechbaumer.name>2017-12-12 22:09:46 +0100
commitf5fafbfcb56bbc3fb7cca7ed32dd7b3b41c39e83 (patch)
tree6851c801d2609db1531a8bcdd8acc526a84e4b55
parent70e1409261adfd165b8473f1d21aa760023795d7 (diff)
downloadmitmproxy-f5fafbfcb56bbc3fb7cca7ed32dd7b3b41c39e83.tar.gz
mitmproxy-f5fafbfcb56bbc3fb7cca7ed32dd7b3b41c39e83.tar.bz2
mitmproxy-f5fafbfcb56bbc3fb7cca7ed32dd7b3b41c39e83.zip
vendoring of wsproto
https://github.com/python-hyper/wsproto.git commit 5ea2da61266796666f5de6461aaae22e6b00deba
-rw-r--r--mitmproxy/contrib/wsproto/compat.py20
-rw-r--r--mitmproxy/contrib/wsproto/connection.py477
-rw-r--r--mitmproxy/contrib/wsproto/events.py81
-rw-r--r--mitmproxy/contrib/wsproto/extensions.py257
-rw-r--r--mitmproxy/contrib/wsproto/frame_protocol.py579
-rw-r--r--mitmproxy/proxy/protocol/websocket.py8
-rw-r--r--setup.py1
7 files changed, 1419 insertions, 4 deletions
diff --git a/mitmproxy/contrib/wsproto/compat.py b/mitmproxy/contrib/wsproto/compat.py
new file mode 100644
index 00000000..1911f83c
--- /dev/null
+++ b/mitmproxy/contrib/wsproto/compat.py
@@ -0,0 +1,20 @@
+# flake8: noqa
+
+import sys
+
+
+PY2 = sys.version_info.major == 2
+PY3 = sys.version_info.major == 3
+
+
+if PY3:
+ unicode = str
+
+ def Utf8Validator():
+ return None
+else:
+ unicode = unicode
+ try:
+ from wsaccel.utf8validator import Utf8Validator
+ except ImportError:
+ from .utf8validator import Utf8Validator
diff --git a/mitmproxy/contrib/wsproto/connection.py b/mitmproxy/contrib/wsproto/connection.py
new file mode 100644
index 00000000..f994cd3a
--- /dev/null
+++ b/mitmproxy/contrib/wsproto/connection.py
@@ -0,0 +1,477 @@
+# -*- coding: utf-8 -*-
+"""
+wsproto/connection
+~~~~~~~~~~~~~~
+
+An implementation of a WebSocket connection.
+"""
+
+import os
+import base64
+import hashlib
+from collections import deque
+
+from enum import Enum
+
+import h11
+
+from .events import (
+ ConnectionRequested, ConnectionEstablished, ConnectionClosed,
+ ConnectionFailed, TextReceived, BytesReceived, PingReceived, PongReceived
+)
+from .frame_protocol import FrameProtocol, ParseFailed, CloseReason, Opcode
+
+
+# RFC6455, Section 1.3 - Opening Handshake
+ACCEPT_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
+
+
+class ConnectionState(Enum):
+ """
+ RFC 6455, Section 4 - Opening Handshake
+ """
+ CONNECTING = 0
+ OPEN = 1
+ CLOSING = 2
+ CLOSED = 3
+
+
+class ConnectionType(Enum):
+ CLIENT = 1
+ SERVER = 2
+
+
+CLIENT = ConnectionType.CLIENT
+SERVER = ConnectionType.SERVER
+
+
+# Some convenience utilities for working with HTTP headers
+def _normed_header_dict(h11_headers):
+ # This mangles Set-Cookie headers. But it happens that we don't care about
+ # any of those, so it's OK. For every other HTTP header, if there are
+ # multiple instances then you're allowed to join them together with
+ # commas.
+ name_to_values = {}
+ for name, value in h11_headers:
+ name_to_values.setdefault(name, []).append(value)
+ name_to_normed_value = {}
+ for name, values in name_to_values.items():
+ name_to_normed_value[name] = b", ".join(values)
+ return name_to_normed_value
+
+
+# We use this for parsing the proposed protocol list, and for parsing the
+# proposed and accepted extension lists. For the proposed protocol list it's
+# fine, because the ABNF is just 1#token. But for the extension lists, it's
+# wrong, because those can contain quoted strings, which can in turn contain
+# commas. XX FIXME
+def _split_comma_header(value):
+ return [piece.decode('ascii').strip() for piece in value.split(b',')]
+
+
+class WSConnection(object):
+ """
+ A low-level WebSocket connection object.
+
+ This wraps two other protocol objects, an HTTP/1.1 protocol object used
+ to do the initial HTTP upgrade handshake and a WebSocket frame protocol
+ object used to exchange messages and other control frames.
+
+ :param conn_type: Whether this object is on the client- or server-side of
+ a connection. To initialise as a client pass ``CLIENT`` otherwise
+ pass ``SERVER``.
+ :type conn_type: ``ConnectionType``
+
+ :param host: The hostname to pass to the server when acting as a client.
+ :type host: ``str``
+
+ :param resource: The resource (aka path) to pass to the server when acting
+ as a client.
+ :type resource: ``str``
+
+ :param extensions: A list of extensions to use on this connection.
+ Extensions should be instances of a subclass of
+ :class:`Extension <wsproto.extensions.Extension>`.
+
+ :param subprotocols: A list of subprotocols to request when acting as a
+ client, ordered by preference. This has no impact on the connection
+ itself.
+ :type subprotocol: ``list`` of ``str``
+ """
+
+ def __init__(self, conn_type, host=None, resource=None, extensions=None,
+ subprotocols=None):
+ self.client = conn_type is ConnectionType.CLIENT
+
+ self.host = host
+ self.resource = resource
+
+ self.subprotocols = subprotocols or []
+ self.extensions = extensions or []
+
+ self.version = b'13'
+
+ self._state = ConnectionState.CONNECTING
+ self._close_reason = None
+
+ self._nonce = None
+ self._outgoing = b''
+ self._events = deque()
+ self._proto = None
+
+ if self.client:
+ self._upgrade_connection = h11.Connection(h11.CLIENT)
+ else:
+ self._upgrade_connection = h11.Connection(h11.SERVER)
+
+ if self.client:
+ if self.host is None:
+ raise ValueError(
+ "Host must not be None for a client-side connection.")
+ if self.resource is None:
+ raise ValueError(
+ "Resource must not be None for a client-side connection.")
+ self.initiate_connection()
+
+ def initiate_connection(self):
+ self._generate_nonce()
+
+ headers = {
+ b"Host": self.host.encode('ascii'),
+ b"Upgrade": b'WebSocket',
+ b"Connection": b'Upgrade',
+ b"Sec-WebSocket-Key": self._nonce,
+ b"Sec-WebSocket-Version": self.version,
+ }
+
+ if self.subprotocols:
+ headers[b"Sec-WebSocket-Protocol"] = ", ".join(self.subprotocols)
+
+ if self.extensions:
+ offers = {e.name: e.offer(self) for e in self.extensions}
+ extensions = []
+ for name, params in offers.items():
+ if params is True:
+ extensions.append(name.encode('ascii'))
+ elif params:
+ # py34 annoyance: doesn't support bytestring formatting
+ extensions.append(('%s; %s' % (name, params))
+ .encode("ascii"))
+ if extensions:
+ headers[b'Sec-WebSocket-Extensions'] = b', '.join(extensions)
+
+ upgrade = h11.Request(method=b'GET', target=self.resource,
+ headers=headers.items())
+ self._outgoing += self._upgrade_connection.send(upgrade)
+
+ def send_data(self, payload, final=True):
+ """
+ Send a message or part of a message to the remote peer.
+
+ If ``final`` is ``False`` it indicates that this is part of a longer
+ message. If ``final`` is ``True`` it indicates that this is either a
+ self-contained message or the last part of a longer message.
+
+ If ``payload`` is of type ``bytes`` then the message is flagged as
+ being binary If it is of type ``str`` encoded as UTF-8 and sent as
+ text.
+
+ :param payload: The message body to send.
+ :type payload: ``bytes`` or ``str``
+
+ :param final: Whether there are more parts to this message to be sent.
+ :type final: ``bool``
+ """
+
+ self._outgoing += self._proto.send_data(payload, final)
+
+ def close(self, code=CloseReason.NORMAL_CLOSURE, reason=None):
+ self._outgoing += self._proto.close(code, reason)
+ self._state = ConnectionState.CLOSING
+
+ @property
+ def closed(self):
+ return self._state is ConnectionState.CLOSED
+
+ def bytes_to_send(self, amount=None):
+ """
+ Return any data that is to be sent to the remote peer.
+
+ :param amount: (optional) The maximum number of bytes to be provided.
+ If ``None`` or not provided it will return all available bytes.
+ :type amount: ``int``
+ """
+
+ if amount is None:
+ data = self._outgoing
+ self._outgoing = b''
+ else:
+ data = self._outgoing[:amount]
+ self._outgoing = self._outgoing[amount:]
+
+ return data
+
+ def receive_bytes(self, data):
+ """
+ Pass some received bytes to the connection for processing.
+
+ :param data: The data received from the remote peer.
+ :type data: ``bytes``
+ """
+
+ if data is None and self._state is ConnectionState.OPEN:
+ # "If _The WebSocket Connection is Closed_ and no Close control
+ # frame was received by the endpoint (such as could occur if the
+ # underlying transport connection is lost), _The WebSocket
+ # Connection Close Code_ is considered to be 1006."
+ self._events.append(ConnectionClosed(CloseReason.ABNORMAL_CLOSURE))
+ self._state = ConnectionState.CLOSED
+ return
+ elif data is None:
+ self._state = ConnectionState.CLOSED
+ return
+
+ if self._state is ConnectionState.CONNECTING:
+ event, data = self._process_upgrade(data)
+ if event is not None:
+ self._events.append(event)
+
+ if self._state is ConnectionState.OPEN:
+ self._proto.receive_bytes(data)
+
+ def _process_upgrade(self, data):
+ self._upgrade_connection.receive_data(data)
+ while True:
+ try:
+ event = self._upgrade_connection.next_event()
+ except h11.RemoteProtocolError:
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Bad HTTP message"), b''
+ if event is h11.NEED_DATA:
+ break
+ elif self.client and isinstance(event, (h11.InformationalResponse,
+ h11.Response)):
+ data = self._upgrade_connection.trailing_data[0]
+ return self._establish_client_connection(event), data
+ elif not self.client and isinstance(event, h11.Request):
+ return self._process_connection_request(event), None
+ else:
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Bad HTTP message"), b''
+
+ self._incoming = b''
+ return None, None
+
+ def events(self):
+ """
+ Return a generator that provides any events that have been generated
+ by protocol activity.
+
+ :returns: generator
+ """
+
+ while self._events:
+ yield self._events.popleft()
+
+ if self._proto is None:
+ return
+
+ try:
+ for frame in self._proto.received_frames():
+ if frame.opcode is Opcode.PING:
+ assert frame.frame_finished and frame.message_finished
+ self._outgoing += self._proto.pong(frame.payload)
+ yield PingReceived(frame.payload)
+
+ elif frame.opcode is Opcode.PONG:
+ assert frame.frame_finished and frame.message_finished
+ yield PongReceived(frame.payload)
+
+ elif frame.opcode is Opcode.CLOSE:
+ code, reason = frame.payload
+ self.close(code, reason)
+ yield ConnectionClosed(code, reason)
+
+ elif frame.opcode is Opcode.TEXT:
+ yield TextReceived(frame.payload,
+ frame.frame_finished,
+ frame.message_finished)
+
+ elif frame.opcode is Opcode.BINARY:
+ yield BytesReceived(frame.payload,
+ frame.frame_finished,
+ frame.message_finished)
+ except ParseFailed as exc:
+ # XX FIXME: apparently autobahn intentionally deviates from the
+ # spec in that on protocol errors it just closes the connection
+ # rather than trying to send a CLOSE frame. Investigate whether we
+ # should do the same.
+ self.close(code=exc.code, reason=str(exc))
+ yield ConnectionClosed(exc.code, reason=str(exc))
+
+ def _generate_nonce(self):
+ # os.urandom may be overkill for this use case, but I don't think this
+ # is a bottleneck, and better safe than sorry...
+ self._nonce = base64.b64encode(os.urandom(16))
+
+ def _generate_accept_token(self, token):
+ accept_token = token + ACCEPT_GUID
+ accept_token = hashlib.sha1(accept_token).digest()
+ return base64.b64encode(accept_token)
+
+ def _establish_client_connection(self, event):
+ if event.status_code != 101:
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Bad status code from server")
+ headers = _normed_header_dict(event.headers)
+ if headers[b'connection'].lower() != b'upgrade':
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Missing Connection: Upgrade header")
+ if headers[b'upgrade'].lower() != b'websocket':
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Missing Upgrade: WebSocket header")
+
+ accept_token = self._generate_accept_token(self._nonce)
+ if headers[b'sec-websocket-accept'] != accept_token:
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Bad accept token")
+
+ subprotocol = headers.get(b'sec-websocket-protocol', None)
+ if subprotocol is not None:
+ subprotocol = subprotocol.decode('ascii')
+ if subprotocol not in self.subprotocols:
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "unrecognized subprotocol {!r}"
+ .format(subprotocol))
+
+ extensions = headers.get(b'sec-websocket-extensions', None)
+ if extensions:
+ accepts = _split_comma_header(extensions)
+
+ for accept in accepts:
+ name = accept.split(';', 1)[0].strip()
+ for extension in self.extensions:
+ if extension.name == name:
+ extension.finalize(self, accept)
+ break
+ else:
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "unrecognized extension {!r}"
+ .format(name))
+
+ self._proto = FrameProtocol(self.client, self.extensions)
+ self._state = ConnectionState.OPEN
+ return ConnectionEstablished(subprotocol, extensions)
+
+ def _process_connection_request(self, event):
+ if event.method != b'GET':
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Request method must be GET")
+ headers = _normed_header_dict(event.headers)
+ if headers[b'connection'].lower() != b'upgrade':
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Missing Connection: Upgrade header")
+ if headers[b'upgrade'].lower() != b'websocket':
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Missing Upgrade: WebSocket header")
+
+ if b'sec-websocket-version' not in headers:
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Missing Sec-WebSocket-Version header")
+ # XX FIXME: need to check Sec-Websocket-Version, and respond with a
+ # 400 if it's not what we expect
+
+ if b'sec-websocket-protocol' in headers:
+ proposed_subprotocols = _split_comma_header(
+ headers[b'sec-websocket-protocol'])
+ else:
+ proposed_subprotocols = []
+
+ if b'sec-websocket-key' not in headers:
+ return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
+ "Missing Sec-WebSocket-Key header")
+
+ return ConnectionRequested(proposed_subprotocols, event)
+
+ def _extension_accept(self, extensions_header):
+ accepts = {}
+ offers = _split_comma_header(extensions_header)
+
+ for offer in offers:
+ name = offer.split(';', 1)[0].strip()
+ for extension in self.extensions:
+ if extension.name == name:
+ accept = extension.accept(self, offer)
+ if accept is True:
+ accepts[extension.name] = True
+ elif accept is not False and accept is not None:
+ accepts[extension.name] = accept.encode('ascii')
+
+ if accepts:
+ extensions = []
+ for name, params in accepts.items():
+ if params is True:
+ extensions.append(name.encode('ascii'))
+ else:
+ # py34 annoyance: doesn't support bytestring formatting
+ params = params.decode("ascii")
+ extensions.append(('%s; %s' % (name, params))
+ .encode("ascii"))
+ return b', '.join(extensions)
+
+ return None
+
+ def accept(self, event, subprotocol=None):
+ request = event.h11request
+ request_headers = _normed_header_dict(request.headers)
+
+ nonce = request_headers[b'sec-websocket-key']
+ accept_token = self._generate_accept_token(nonce)
+
+ headers = {
+ b"Upgrade": b'WebSocket',
+ b"Connection": b'Upgrade',
+ b"Sec-WebSocket-Accept": accept_token,
+ }
+
+ if subprotocol is not None:
+ if subprotocol not in event.proposed_subprotocols:
+ raise ValueError(
+ "unexpected subprotocol {!r}".format(subprotocol))
+ headers[b'Sec-WebSocket-Protocol'] = subprotocol
+
+ extensions = request_headers.get(b'sec-websocket-extensions', None)
+ if extensions:
+ accepts = self._extension_accept(extensions)
+ if accepts:
+ headers[b"Sec-WebSocket-Extensions"] = accepts
+
+ response = h11.InformationalResponse(status_code=101,
+ headers=headers.items())
+ self._outgoing += self._upgrade_connection.send(response)
+ self._proto = FrameProtocol(self.client, self.extensions)
+ self._state = ConnectionState.OPEN
+
+ def ping(self, payload=None):
+ """
+ Send a PING message to the peer.
+
+ :param payload: an optional payload to send with the message
+ """
+
+ payload = bytes(payload or b'')
+ self._outgoing += self._proto.ping(payload)
+
+ def pong(self, payload=None):
+ """
+ Send a PONG message to the peer.
+
+ This method can be used to send an unsolicted PONG to the peer.
+ It is not needed otherwise since every received PING causes a
+ corresponding PONG to be sent automatically.
+
+ :param payload: an optional payload to send with the message
+ """
+
+ payload = bytes(payload or b'')
+ self._outgoing += self._proto.pong(payload)
diff --git a/mitmproxy/contrib/wsproto/events.py b/mitmproxy/contrib/wsproto/events.py
new file mode 100644
index 00000000..73ce27aa
--- /dev/null
+++ b/mitmproxy/contrib/wsproto/events.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+"""
+wsproto/events
+~~~~~~~~~~
+
+Events that result from processing data on a WebSocket connection.
+"""
+
+
+class ConnectionRequested(object):
+ def __init__(self, proposed_subprotocols, h11request):
+ self.proposed_subprotocols = proposed_subprotocols
+ self.h11request = h11request
+
+ def __repr__(self):
+ path = self.h11request.target
+
+ headers = dict(self.h11request.headers)
+ host = headers[b'host']
+ version = headers[b'sec-websocket-version']
+ subprotocol = headers.get(b'sec-websocket-protocol', None)
+ extensions = []
+
+ fmt = '<%s host=%s path=%s version=%s subprotocol=%r extensions=%r>'
+ return fmt % (self.__class__.__name__, host, path, version,
+ subprotocol, extensions)
+
+
+class ConnectionEstablished(object):
+ def __init__(self, subprotocol=None, extensions=None):
+ self.subprotocol = subprotocol
+ self.extensions = extensions
+ if self.extensions is None:
+ self.extensions = []
+
+ def __repr__(self):
+ return '<ConnectionEstablished subprotocol=%r extensions=%r>' % \
+ (self.subprotocol, self.extensions)
+
+
+class ConnectionClosed(object):
+ def __init__(self, code, reason=None):
+ self.code = code
+ self.reason = reason
+
+ def __repr__(self):
+ return '<%s code=%r reason="%s">' % (self.__class__.__name__,
+ self.code, self.reason)
+
+
+class ConnectionFailed(ConnectionClosed):
+ pass
+
+
+class DataReceived(object):
+ def __init__(self, data, frame_finished, message_finished):
+ self.data = data
+ # This has no semantic content, but is provided just in case some
+ # weird edge case user wants to be able to reconstruct the
+ # fragmentation pattern of the original stream. You don't want it:
+ self.frame_finished = frame_finished
+ # This is the field that you almost certainly want:
+ self.message_finished = message_finished
+
+
+class TextReceived(DataReceived):
+ pass
+
+
+class BytesReceived(DataReceived):
+ pass
+
+
+class PingReceived(object):
+ def __init__(self, payload):
+ self.payload = payload
+
+
+class PongReceived(object):
+ def __init__(self, payload):
+ self.payload = payload
diff --git a/mitmproxy/contrib/wsproto/extensions.py b/mitmproxy/contrib/wsproto/extensions.py
new file mode 100644
index 00000000..f7cf4fb6
--- /dev/null
+++ b/mitmproxy/contrib/wsproto/extensions.py
@@ -0,0 +1,257 @@
+# -*- coding: utf-8 -*-
+"""
+wsproto/extensions
+~~~~~~~~~~~~~~
+
+WebSocket extensions.
+"""
+
+import zlib
+
+from .frame_protocol import CloseReason, Opcode, RsvBits
+
+
+class Extension(object):
+ name = None
+
+ def enabled(self):
+ return False
+
+ def offer(self, connection):
+ pass
+
+ def accept(self, connection, offer):
+ pass
+
+ def finalize(self, connection, offer):
+ pass
+
+ def frame_inbound_header(self, proto, opcode, rsv, payload_length):
+ return RsvBits(False, False, False)
+
+ def frame_inbound_payload_data(self, proto, data):
+ return data
+
+ def frame_inbound_complete(self, proto, fin):
+ pass
+
+ def frame_outbound(self, proto, opcode, rsv, data, fin):
+ return (rsv, data)
+
+
+class PerMessageDeflate(Extension):
+ name = 'permessage-deflate'
+
+ DEFAULT_CLIENT_MAX_WINDOW_BITS = 15
+ DEFAULT_SERVER_MAX_WINDOW_BITS = 15
+
+ def __init__(self, client_no_context_takeover=False,
+ client_max_window_bits=None, server_no_context_takeover=False,
+ server_max_window_bits=None):
+ self.client_no_context_takeover = client_no_context_takeover
+ if client_max_window_bits is None:
+ client_max_window_bits = self.DEFAULT_CLIENT_MAX_WINDOW_BITS
+ self.client_max_window_bits = client_max_window_bits
+ self.server_no_context_takeover = server_no_context_takeover
+ if server_max_window_bits is None:
+ server_max_window_bits = self.DEFAULT_SERVER_MAX_WINDOW_BITS
+ self.server_max_window_bits = server_max_window_bits
+
+ self._compressor = None
+ self._decompressor = None
+ # This refers to the current frame
+ self._inbound_is_compressible = None
+ # This refers to the ongoing message (which might span multiple
+ # frames). Only the first frame in a fragmented message is flagged for
+ # compression, so this carries that bit forward.
+ self._inbound_compressed = None
+
+ self._enabled = False
+
+ def _compressible_opcode(self, opcode):
+ return opcode in (Opcode.TEXT, Opcode.BINARY, Opcode.CONTINUATION)
+
+ def enabled(self):
+ return self._enabled
+
+ def offer(self, connection):
+ parameters = [
+ 'client_max_window_bits=%d' % self.client_max_window_bits,
+ 'server_max_window_bits=%d' % self.server_max_window_bits,
+ ]
+
+ if self.client_no_context_takeover:
+ parameters.append('client_no_context_takeover')
+ if self.server_no_context_takeover:
+ parameters.append('server_no_context_takeover')
+
+ return '; '.join(parameters)
+
+ def finalize(self, connection, offer):
+ bits = [b.strip() for b in offer.split(';')]
+ for bit in bits[1:]:
+ if bit.startswith('client_no_context_takeover'):
+ self.client_no_context_takeover = True
+ elif bit.startswith('server_no_context_takeover'):
+ self.server_no_context_takeover = True
+ elif bit.startswith('client_max_window_bits'):
+ self.client_max_window_bits = int(bit.split('=', 1)[1].strip())
+ elif bit.startswith('server_max_window_bits'):
+ self.server_max_window_bits = int(bit.split('=', 1)[1].strip())
+
+ self._enabled = True
+
+ def _parse_params(self, params):
+ client_max_window_bits = None
+ server_max_window_bits = None
+
+ bits = [b.strip() for b in params.split(';')]
+ for bit in bits[1:]:
+ if bit.startswith('client_no_context_takeover'):
+ self.client_no_context_takeover = True
+ elif bit.startswith('server_no_context_takeover'):
+ self.server_no_context_takeover = True
+ elif bit.startswith('client_max_window_bits'):
+ if '=' in bit:
+ client_max_window_bits = int(bit.split('=', 1)[1].strip())
+ else:
+ client_max_window_bits = self.client_max_window_bits
+ elif bit.startswith('server_max_window_bits'):
+ if '=' in bit:
+ server_max_window_bits = int(bit.split('=', 1)[1].strip())
+ else:
+ server_max_window_bits = self.server_max_window_bits
+
+ return client_max_window_bits, server_max_window_bits
+
+ def accept(self, connection, offer):
+ client_max_window_bits, server_max_window_bits = \
+ self._parse_params(offer)
+
+ self._enabled = True
+
+ parameters = []
+
+ if self.client_no_context_takeover:
+ parameters.append('client_no_context_takeover')
+ if client_max_window_bits is not None:
+ parameters.append('client_max_window_bits=%d' %
+ client_max_window_bits)
+ self.client_max_window_bits = client_max_window_bits
+ if self.server_no_context_takeover:
+ parameters.append('server_no_context_takeover')
+ if server_max_window_bits is not None:
+ parameters.append('server_max_window_bits=%d' %
+ server_max_window_bits)
+ self.server_max_window_bits = server_max_window_bits
+
+ return '; '.join(parameters)
+
+ def frame_inbound_header(self, proto, opcode, rsv, payload_length):
+ if rsv.rsv1 and opcode.iscontrol():
+ return CloseReason.PROTOCOL_ERROR
+ elif rsv.rsv1 and opcode is Opcode.CONTINUATION:
+ return CloseReason.PROTOCOL_ERROR
+
+ self._inbound_is_compressible = self._compressible_opcode(opcode)
+
+ if self._inbound_compressed is None:
+ self._inbound_compressed = rsv.rsv1
+ if self._inbound_compressed:
+ assert self._inbound_is_compressible
+ if proto.client:
+ bits = self.server_max_window_bits
+ else:
+ bits = self.client_max_window_bits
+ if self._decompressor is None:
+ self._decompressor = zlib.decompressobj(-int(bits))
+
+ return RsvBits(True, False, False)
+
+ def frame_inbound_payload_data(self, proto, data):
+ if not self._inbound_compressed or not self._inbound_is_compressible:
+ return data
+
+ try:
+ return self._decompressor.decompress(bytes(data))
+ except zlib.error:
+ return CloseReason.INVALID_FRAME_PAYLOAD_DATA
+
+ def frame_inbound_complete(self, proto, fin):
+ if not fin:
+ return
+ elif not self._inbound_is_compressible:
+ return
+ elif not self._inbound_compressed:
+ return
+
+ try:
+ data = self._decompressor.decompress(b'\x00\x00\xff\xff')
+ data += self._decompressor.flush()
+ except zlib.error:
+ return CloseReason.INVALID_FRAME_PAYLOAD_DATA
+
+ if proto.client:
+ no_context_takeover = self.server_no_context_takeover
+ else:
+ no_context_takeover = self.client_no_context_takeover
+
+ if no_context_takeover:
+ self._decompressor = None
+
+ self._inbound_compressed = None
+
+ return data
+
+ def frame_outbound(self, proto, opcode, rsv, data, fin):
+ if not self._compressible_opcode(opcode):
+ return (rsv, data)
+
+ if opcode is not Opcode.CONTINUATION:
+ rsv = RsvBits(True, *rsv[1:])
+
+ if self._compressor is None:
+ assert opcode is not Opcode.CONTINUATION
+ if proto.client:
+ bits = self.client_max_window_bits
+ else:
+ bits = self.server_max_window_bits
+ self._compressor = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
+ zlib.DEFLATED, -int(bits))
+
+ data = self._compressor.compress(bytes(data))
+
+ if fin:
+ data += self._compressor.flush(zlib.Z_SYNC_FLUSH)
+ data = data[:-4]
+
+ if proto.client:
+ no_context_takeover = self.client_no_context_takeover
+ else:
+ no_context_takeover = self.server_no_context_takeover
+
+ if no_context_takeover:
+ self._compressor = None
+
+ return (rsv, data)
+
+ def __repr__(self):
+ descr = ['client_max_window_bits=%d' % self.client_max_window_bits]
+ if self.client_no_context_takeover:
+ descr.append('client_no_context_takeover')
+ descr.append('server_max_window_bits=%d' % self.server_max_window_bits)
+ if self.server_no_context_takeover:
+ descr.append('server_no_context_takeover')
+
+ descr = '; '.join(descr)
+
+ return '<%s %s>' % (self.__class__.__name__, descr)
+
+
+#: SUPPORTED_EXTENSIONS maps all supported extension names to their class.
+#: This can be used to iterate all supported extensions of wsproto, instantiate
+#: new extensions based on their name, or check if a given extension is
+#: supported or not.
+SUPPORTED_EXTENSIONS = {
+ PerMessageDeflate.name: PerMessageDeflate
+}
diff --git a/mitmproxy/contrib/wsproto/frame_protocol.py b/mitmproxy/contrib/wsproto/frame_protocol.py
new file mode 100644
index 00000000..b95dceec
--- /dev/null
+++ b/mitmproxy/contrib/wsproto/frame_protocol.py
@@ -0,0 +1,579 @@
+# -*- coding: utf-8 -*-
+"""
+wsproto/frame_protocol
+~~~~~~~~~~~~~~
+
+WebSocket frame protocol implementation.
+"""
+
+import os
+import itertools
+import struct
+from codecs import getincrementaldecoder
+from collections import namedtuple
+
+from enum import Enum, IntEnum
+
+from .compat import unicode, Utf8Validator
+
+try:
+ from wsaccel.xormask import XorMaskerSimple
+except ImportError:
+ class XorMaskerSimple:
+ def __init__(self, masking_key):
+ self._maskbytes = itertools.cycle(bytearray(masking_key))
+
+ def process(self, data):
+ maskbytes = self._maskbytes
+ return bytearray(b ^ next(maskbytes) for b in bytearray(data))
+
+
+class XorMaskerNull:
+ def process(self, data):
+ return data
+
+
+# RFC6455, Section 5.2 - Base Framing Protocol
+
+# Payload length constants
+PAYLOAD_LENGTH_TWO_BYTE = 126
+PAYLOAD_LENGTH_EIGHT_BYTE = 127
+MAX_PAYLOAD_NORMAL = 125
+MAX_PAYLOAD_TWO_BYTE = 2 ** 16 - 1
+MAX_PAYLOAD_EIGHT_BYTE = 2 ** 64 - 1
+MAX_FRAME_PAYLOAD = MAX_PAYLOAD_EIGHT_BYTE
+
+# MASK and PAYLOAD LEN are packed into a byte
+MASK_MASK = 0x80
+PAYLOAD_LEN_MASK = 0x7f
+
+# FIN, RSV[123] and OPCODE are packed into a single byte
+FIN_MASK = 0x80
+RSV1_MASK = 0x40
+RSV2_MASK = 0x20
+RSV3_MASK = 0x10
+OPCODE_MASK = 0x0f
+
+
+class Opcode(IntEnum):
+ """
+ RFC 6455, Section 5.2 - Base Framing Protocol
+ """
+ CONTINUATION = 0x0
+ TEXT = 0x1
+ BINARY = 0x2
+ CLOSE = 0x8
+ PING = 0x9
+ PONG = 0xA
+
+ def iscontrol(self):
+ return bool(self & 0x08)
+
+
+class CloseReason(IntEnum):
+ """
+ RFC 6455, Section 7.4.1 - Defined Status Codes
+ """
+ NORMAL_CLOSURE = 1000
+ GOING_AWAY = 1001
+ PROTOCOL_ERROR = 1002
+ UNSUPPORTED_DATA = 1003
+ NO_STATUS_RCVD = 1005
+ ABNORMAL_CLOSURE = 1006
+ INVALID_FRAME_PAYLOAD_DATA = 1007
+ POLICY_VIOLATION = 1008
+ MESSAGE_TOO_BIG = 1009
+ MANDATORY_EXT = 1010
+ INTERNAL_ERROR = 1011
+ SERVICE_RESTART = 1012
+ TRY_AGAIN_LATER = 1013
+ TLS_HANDSHAKE_FAILED = 1015
+
+
+# RFC 6455, Section 7.4.1 - Defined Status Codes
+LOCAL_ONLY_CLOSE_REASONS = (
+ CloseReason.NO_STATUS_RCVD,
+ CloseReason.ABNORMAL_CLOSURE,
+ CloseReason.TLS_HANDSHAKE_FAILED,
+)
+
+
+# RFC 6455, Section 7.4.2 - Status Code Ranges
+MIN_CLOSE_REASON = 1000
+MIN_PROTOCOL_CLOSE_REASON = 1000
+MAX_PROTOCOL_CLOSE_REASON = 2999
+MIN_LIBRARY_CLOSE_REASON = 3000
+MAX_LIBRARY_CLOSE_REASON = 3999
+MIN_PRIVATE_CLOSE_REASON = 4000
+MAX_PRIVATE_CLOSE_REASON = 4999
+MAX_CLOSE_REASON = 4999
+
+
+NULL_MASK = struct.pack("!I", 0)
+
+
+class ParseFailed(Exception):
+ def __init__(self, msg, code=CloseReason.PROTOCOL_ERROR):
+ super(ParseFailed, self).__init__(msg)
+ self.code = code
+
+
+Header = namedtuple("Header", "fin rsv opcode payload_len masking_key".split())
+
+
+Frame = namedtuple("Frame",
+ "opcode payload frame_finished message_finished".split())
+
+
+RsvBits = namedtuple("RsvBits", "rsv1 rsv2 rsv3".split())
+
+
+def _truncate_utf8(data, nbytes):
+ if len(data) <= nbytes:
+ return data
+
+ # Truncate
+ data = data[:nbytes]
+ # But we might have cut a codepoint in half, in which case we want to
+ # discard the partial character so the data is at least
+ # well-formed. This is a little inefficient since it processes the
+ # whole message twice when in theory we could just peek at the last
+ # few characters, but since this is only used for close messages (max
+ # length = 125 bytes) it really doesn't matter.
+ data = data.decode("utf-8", errors="ignore").encode("utf-8")
+ return data
+
+
+class Buffer(object):
+ def __init__(self, initial_bytes=None):
+ self.buffer = bytearray()
+ self.bytes_used = 0
+ if initial_bytes:
+ self.feed(initial_bytes)
+
+ def feed(self, new_bytes):
+ self.buffer += new_bytes
+
+ def consume_at_most(self, nbytes):
+ if not nbytes:
+ return bytearray()
+
+ data = self.buffer[self.bytes_used:self.bytes_used + nbytes]
+ self.bytes_used += len(data)
+ return data
+
+ def consume_exactly(self, nbytes):
+ if len(self.buffer) - self.bytes_used < nbytes:
+ return None
+
+ return self.consume_at_most(nbytes)
+
+ def commit(self):
+ # In CPython 3.4+, del[:n] is amortized O(n), *not* quadratic
+ del self.buffer[:self.bytes_used]
+ self.bytes_used = 0
+
+ def rollback(self):
+ self.bytes_used = 0
+
+ def __len__(self):
+ return len(self.buffer)
+
+
+class MessageDecoder(object):
+ def __init__(self):
+ self.opcode = None
+ self.validator = None
+ self.decoder = None
+
+ def process_frame(self, frame):
+ assert not frame.opcode.iscontrol()
+
+ if self.opcode is None:
+ if frame.opcode is Opcode.CONTINUATION:
+ raise ParseFailed("unexpected CONTINUATION")
+ self.opcode = frame.opcode
+ elif frame.opcode is not Opcode.CONTINUATION:
+ raise ParseFailed("expected CONTINUATION, got %r" % frame.opcode)
+
+ if frame.opcode is Opcode.TEXT:
+ self.validator = Utf8Validator()
+ self.decoder = getincrementaldecoder("utf-8")()
+
+ finished = frame.frame_finished and frame.message_finished
+
+ if self.decoder is not None:
+ data = self.decode_payload(frame.payload, finished)
+ else:
+ data = frame.payload
+
+ frame = Frame(self.opcode, data, frame.frame_finished, finished)
+
+ if finished:
+ self.opcode = None
+ self.decoder = None
+
+ return frame
+
+ def decode_payload(self, data, finished):
+ if self.validator is not None:
+ results = self.validator.validate(bytes(data))
+ if not results[0] or (finished and not results[1]):
+ raise ParseFailed(u'encountered invalid UTF-8 while processing'
+ ' text message at payload octet index %d' %
+ results[3],
+ CloseReason.INVALID_FRAME_PAYLOAD_DATA)
+
+ try:
+ return self.decoder.decode(data, finished)
+ except UnicodeDecodeError as exc:
+ raise ParseFailed(str(exc), CloseReason.INVALID_FRAME_PAYLOAD_DATA)
+
+
+class FrameDecoder(object):
+ def __init__(self, client, extensions=None):
+ self.client = client
+ self.extensions = extensions or []
+
+ self.buffer = Buffer()
+
+ self.header = None
+ self.effective_opcode = None
+ self.masker = None
+ self.payload_required = 0
+ self.payload_consumed = 0
+
+ def receive_bytes(self, data):
+ self.buffer.feed(data)
+
+ def process_buffer(self):
+ if not self.header:
+ if not self.parse_header():
+ return None
+
+ if len(self.buffer) < self.payload_required:
+ return None
+
+ payload_remaining = self.header.payload_len - self.payload_consumed
+ payload = self.buffer.consume_at_most(payload_remaining)
+ if not payload and self.header.payload_len > 0:
+ return None
+ self.buffer.commit()
+
+ self.payload_consumed += len(payload)
+ finished = self.payload_consumed == self.header.payload_len
+
+ payload = self.masker.process(payload)
+
+ for extension in self.extensions:
+ payload = extension.frame_inbound_payload_data(self, payload)
+ if isinstance(payload, CloseReason):
+ raise ParseFailed("error in extension", payload)
+
+ if finished:
+ final = bytearray()
+ for extension in self.extensions:
+ result = extension.frame_inbound_complete(self,
+ self.header.fin)
+ if isinstance(result, CloseReason):
+ raise ParseFailed("error in extension", result)
+ if result is not None:
+ final += result
+ payload += final
+
+ frame = Frame(self.effective_opcode, payload, finished,
+ self.header.fin)
+
+ if finished:
+ self.header = None
+ self.effective_opcode = None
+ self.masker = None
+ else:
+ self.effective_opcode = Opcode.CONTINUATION
+
+ return frame
+
+ def parse_header(self):
+ data = self.buffer.consume_exactly(2)
+ if data is None:
+ self.buffer.rollback()
+ return False
+
+ fin = bool(data[0] & FIN_MASK)
+ rsv = RsvBits(bool(data[0] & RSV1_MASK),
+ bool(data[0] & RSV2_MASK),
+ bool(data[0] & RSV3_MASK))
+ opcode = data[0] & OPCODE_MASK
+ try:
+ opcode = Opcode(opcode)
+ except ValueError:
+ raise ParseFailed("Invalid opcode {:#x}".format(opcode))
+
+ if opcode.iscontrol() and not fin:
+ raise ParseFailed("Invalid attempt to fragment control frame")
+
+ has_mask = bool(data[1] & MASK_MASK)
+ payload_len = data[1] & PAYLOAD_LEN_MASK
+ payload_len = self.parse_extended_payload_length(opcode, payload_len)
+ if payload_len is None:
+ self.buffer.rollback()
+ return False
+
+ self.extension_processing(opcode, rsv, payload_len)
+
+ if has_mask and self.client:
+ raise ParseFailed("client received unexpected masked frame")
+ if not has_mask and not self.client:
+ raise ParseFailed("server received unexpected unmasked frame")
+ if has_mask:
+ masking_key = self.buffer.consume_exactly(4)
+ if masking_key is None:
+ self.buffer.rollback()
+ return False
+ self.masker = XorMaskerSimple(masking_key)
+ else:
+ self.masker = XorMaskerNull()
+
+ self.buffer.commit()
+ self.header = Header(fin, rsv, opcode, payload_len, None)
+ self.effective_opcode = self.header.opcode
+ if self.header.opcode.iscontrol():
+ self.payload_required = payload_len
+ else:
+ self.payload_required = 0
+ self.payload_consumed = 0
+ return True
+
+ def parse_extended_payload_length(self, opcode, payload_len):
+ if opcode.iscontrol() and payload_len > MAX_PAYLOAD_NORMAL:
+ raise ParseFailed("Control frame with payload len > 125")
+ if payload_len == PAYLOAD_LENGTH_TWO_BYTE:
+ data = self.buffer.consume_exactly(2)
+ if data is None:
+ return None
+ (payload_len,) = struct.unpack("!H", data)
+ if payload_len <= MAX_PAYLOAD_NORMAL:
+ raise ParseFailed(
+ "Payload length used 2 bytes when 1 would have sufficed")
+ elif payload_len == PAYLOAD_LENGTH_EIGHT_BYTE:
+ data = self.buffer.consume_exactly(8)
+ if data is None:
+ return None
+ (payload_len,) = struct.unpack("!Q", data)
+ if payload_len <= MAX_PAYLOAD_TWO_BYTE:
+ raise ParseFailed(
+ "Payload length used 8 bytes when 2 would have sufficed")
+ if payload_len >> 63:
+ # I'm not sure why this is illegal, but that's what the RFC
+ # says, so...
+ raise ParseFailed("8-byte payload length with non-zero MSB")
+
+ return payload_len
+
+ def extension_processing(self, opcode, rsv, payload_len):
+ rsv_used = [False, False, False]
+ for extension in self.extensions:
+ result = extension.frame_inbound_header(self, opcode, rsv,
+ payload_len)
+ if isinstance(result, CloseReason):
+ raise ParseFailed("error in extension", result)
+ for bit, used in enumerate(result):
+ if used:
+ rsv_used[bit] = True
+ for expected, found in zip(rsv_used, rsv):
+ if found and not expected:
+ raise ParseFailed("Reserved bit set unexpectedly")
+
+
+class FrameProtocol(object):
+ class State(Enum):
+ HEADER = 1
+ PAYLOAD = 2
+ FRAME_COMPLETE = 3
+ FAILED = 4
+
+ def __init__(self, client, extensions):
+ self.client = client
+ self.extensions = [ext for ext in extensions if ext.enabled()]
+
+ # Global state
+ self._frame_decoder = FrameDecoder(self.client, self.extensions)
+ self._message_decoder = MessageDecoder()
+ self._parse_more = self.parse_more_gen()
+
+ self._outbound_opcode = None
+
+ def _process_close(self, frame):
+ data = frame.payload
+
+ if not data:
+ # "If this Close control frame contains no status code, _The
+ # WebSocket Connection Close Code_ is considered to be 1005"
+ data = (CloseReason.NO_STATUS_RCVD, "")
+ elif len(data) == 1:
+ raise ParseFailed("CLOSE with 1 byte payload")
+ else:
+ (code,) = struct.unpack("!H", data[:2])
+ if code < MIN_CLOSE_REASON or code > MAX_CLOSE_REASON:
+ raise ParseFailed("CLOSE with invalid code")
+ try:
+ code = CloseReason(code)
+ except ValueError:
+ pass
+ if code in LOCAL_ONLY_CLOSE_REASONS:
+ raise ParseFailed(
+ "remote CLOSE with local-only reason")
+ if not isinstance(code, CloseReason) and \
+ code <= MAX_PROTOCOL_CLOSE_REASON:
+ raise ParseFailed(
+ "CLOSE with unknown reserved code")
+ validator = Utf8Validator()
+ if validator is not None:
+ results = validator.validate(bytes(data[2:]))
+ if not (results[0] and results[1]):
+ raise ParseFailed(u'encountered invalid UTF-8 while'
+ ' processing close message at payload'
+ ' octet index %d' %
+ results[3],
+ CloseReason.INVALID_FRAME_PAYLOAD_DATA)
+ try:
+ reason = data[2:].decode("utf-8")
+ except UnicodeDecodeError as exc:
+ raise ParseFailed(
+ "Error decoding CLOSE reason: " + str(exc),
+ CloseReason.INVALID_FRAME_PAYLOAD_DATA)
+ data = (code, reason)
+
+ return Frame(frame.opcode, data, frame.frame_finished,
+ frame.message_finished)
+
+ def parse_more_gen(self):
+ # Consume as much as we can from self._buffer, yielding events, and
+ # then yield None when we need more data. Or raise ParseFailed.
+
+ # XX FIXME this should probably be refactored so that we never see
+ # disabled extensions in the first place...
+ self.extensions = [ext for ext in self.extensions if ext.enabled()]
+ closed = False
+
+ while not closed:
+ frame = self._frame_decoder.process_buffer()
+
+ if frame is not None:
+ if not frame.opcode.iscontrol():
+ frame = self._message_decoder.process_frame(frame)
+ elif frame.opcode == Opcode.CLOSE:
+ frame = self._process_close(frame)
+ closed = True
+
+ yield frame
+
+ def receive_bytes(self, data):
+ self._frame_decoder.receive_bytes(data)
+
+ def received_frames(self):
+ for event in self._parse_more:
+ if event is None:
+ break
+ else:
+ yield event
+
+ def close(self, code=None, reason=None):
+ payload = bytearray()
+ if code is None and reason is not None:
+ raise TypeError("cannot specify a reason without a code")
+ if code in LOCAL_ONLY_CLOSE_REASONS:
+ code = CloseReason.NORMAL_CLOSURE
+ if code is not None:
+ payload += bytearray(struct.pack('!H', code))
+ if reason is not None:
+ payload += _truncate_utf8(reason.encode('utf-8'),
+ MAX_PAYLOAD_NORMAL - 2)
+
+ return self._serialize_frame(Opcode.CLOSE, payload)
+
+ def ping(self, payload=b''):
+ return self._serialize_frame(Opcode.PING, payload)
+
+ def pong(self, payload=b''):
+ return self._serialize_frame(Opcode.PONG, payload)
+
+ def send_data(self, payload=b'', fin=True):
+ if isinstance(payload, (bytes, bytearray, memoryview)):
+ opcode = Opcode.BINARY
+ elif isinstance(payload, unicode):
+ opcode = Opcode.TEXT
+ payload = payload.encode('utf-8')
+ else:
+ raise ValueError('Must provide bytes or text')
+
+ if self._outbound_opcode is None:
+ self._outbound_opcode = opcode
+ elif self._outbound_opcode is not opcode:
+ raise TypeError('Data type mismatch inside message')
+ else:
+ opcode = Opcode.CONTINUATION
+
+ if fin:
+ self._outbound_opcode = None
+
+ return self._serialize_frame(opcode, payload, fin)
+
+ def _make_fin_rsv_opcode(self, fin, rsv, opcode):
+ fin = int(fin) << 7
+ rsv = (int(rsv.rsv1) << 6) + (int(rsv.rsv2) << 5) + \
+ (int(rsv.rsv3) << 4)
+ opcode = int(opcode)
+
+ return fin | rsv | opcode
+
+ def _serialize_frame(self, opcode, payload=b'', fin=True):
+ rsv = RsvBits(False, False, False)
+ for extension in reversed(self.extensions):
+ rsv, payload = extension.frame_outbound(self, opcode, rsv, payload,
+ fin)
+
+ fin_rsv_opcode = self._make_fin_rsv_opcode(fin, rsv, opcode)
+
+ payload_length = len(payload)
+ quad_payload = False
+ if payload_length <= MAX_PAYLOAD_NORMAL:
+ first_payload = payload_length
+ second_payload = None
+ elif payload_length <= MAX_PAYLOAD_TWO_BYTE:
+ first_payload = PAYLOAD_LENGTH_TWO_BYTE
+ second_payload = payload_length
+ else:
+ first_payload = PAYLOAD_LENGTH_EIGHT_BYTE
+ second_payload = payload_length
+ quad_payload = True
+
+ if self.client:
+ first_payload |= 1 << 7
+
+ header = bytearray([fin_rsv_opcode, first_payload])
+ if second_payload is not None:
+ if opcode.iscontrol():
+ raise ValueError("payload too long for control frame")
+ if quad_payload:
+ header += bytearray(struct.pack('!Q', second_payload))
+ else:
+ header += bytearray(struct.pack('!H', second_payload))
+
+ if self.client:
+ # "The masking key is a 32-bit value chosen at random by the
+ # client. When preparing a masked frame, the client MUST pick a
+ # fresh masking key from the set of allowed 32-bit values. The
+ # masking key needs to be unpredictable; thus, the masking key
+ # MUST be derived from a strong source of entropy, and the masking
+ # key for a given frame MUST NOT make it simple for a server/proxy
+ # to predict the masking key for a subsequent frame. The
+ # unpredictability of the masking key is essential to prevent
+ # authors of malicious applications from selecting the bytes that
+ # appear on the wire."
+ # -- https://tools.ietf.org/html/rfc6455#section-5.3
+ masking_key = os.urandom(4)
+ masker = XorMaskerSimple(masking_key)
+ return header + masking_key + masker.process(payload)
+
+ return header + payload
diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py
index 54d8120d..34dcba06 100644
--- a/mitmproxy/proxy/protocol/websocket.py
+++ b/mitmproxy/proxy/protocol/websocket.py
@@ -1,10 +1,10 @@
import socket
from OpenSSL import SSL
-from wsproto import events
-from wsproto.connection import ConnectionType, WSConnection
-from wsproto.extensions import PerMessageDeflate
-from wsproto.frame_protocol import Opcode
+from mitmproxy.contrib.wsproto import events
+from mitmproxy.contrib.wsproto.connection import ConnectionType, WSConnection
+from mitmproxy.contrib.wsproto.extensions import PerMessageDeflate
+from mitmproxy.contrib.wsproto.frame_protocol import Opcode
from mitmproxy import exceptions
from mitmproxy import flow
diff --git a/setup.py b/setup.py
index 54c2811d..ad792881 100644
--- a/setup.py
+++ b/setup.py
@@ -65,6 +65,7 @@ setup(
"certifi>=2015.11.20.1", # no semver here - this should always be on the last release!
"click>=6.2, <7",
"cryptography>=2.0,<2.2",
+ 'h11>=0.7.0,<0.8',
"h2>=3.0, <4",
"hyperframe>=5.0, <6",
"kaitaistruct>=0.7, <0.8",