aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorThomas Kriechbaumer <Kriechi@users.noreply.github.com>2016-09-01 10:39:57 +0200
committerGitHub <noreply@github.com>2016-09-01 10:39:57 +0200
commit55d938b880fd861a22ac66da0da9a741bdd9abd5 (patch)
treed469bbd0dd5b1966591a332bf2094d4389100219
parent281d779fa3eb6b81ec76d046337275c0a82eff46 (diff)
parent0d0c2c788df4b60e951e6fcc13b479de8cec22c1 (diff)
downloadmitmproxy-55d938b880fd861a22ac66da0da9a741bdd9abd5.tar.gz
mitmproxy-55d938b880fd861a22ac66da0da9a741bdd9abd5.tar.bz2
mitmproxy-55d938b880fd861a22ac66da0da9a741bdd9abd5.zip
Merge pull request #1488 from mitmproxy/websockets
add WebSockets support
-rw-r--r--docs/scripting/inlinescripts.rst12
-rw-r--r--mitmproxy/controller.py2
-rw-r--r--mitmproxy/flow/master.py4
-rw-r--r--mitmproxy/protocol/__init__.py4
-rw-r--r--mitmproxy/protocol/http.py11
-rw-r--r--mitmproxy/protocol/websockets.py108
-rw-r--r--mitmproxy/proxy/root_context.py19
-rw-r--r--netlib/websockets/__init__.py34
-rw-r--r--netlib/websockets/frame.py142
-rw-r--r--netlib/websockets/masker.py33
-rw-r--r--netlib/websockets/protocol.py112
-rw-r--r--netlib/websockets/utils.py90
-rw-r--r--pathod/language/http.py4
-rw-r--r--pathod/pathoc.py2
-rw-r--r--pathod/pathod.py9
-rw-r--r--pathod/protocols/websockets.py2
-rw-r--r--test/mitmproxy/protocol/__init__.py0
-rw-r--r--test/mitmproxy/protocol/test_http1.py (renamed from test/mitmproxy/test_protocol_http1.py)4
-rw-r--r--test/mitmproxy/protocol/test_http2.py (renamed from test/mitmproxy/test_protocol_http2.py)4
-rw-r--r--test/mitmproxy/protocol/test_websockets.py299
-rw-r--r--test/netlib/websockets/test_frame.py164
-rw-r--r--test/netlib/websockets/test_masker.py23
-rw-r--r--test/netlib/websockets/test_utils.py105
-rw-r--r--test/netlib/websockets/test_websockets.py269
24 files changed, 967 insertions, 489 deletions
diff --git a/docs/scripting/inlinescripts.rst b/docs/scripting/inlinescripts.rst
index bc9d5ff5..e1c01b17 100644
--- a/docs/scripting/inlinescripts.rst
+++ b/docs/scripting/inlinescripts.rst
@@ -126,6 +126,18 @@ HTTP Events
:param HTTPFlow flow: The flow containing the error.
It is guaranteed to have non-None ``error`` attribute.
+WebSockets Events
+^^^^^^^^^^^^^^^^^
+
+.. py:function:: websockets_handshake(context, flow)
+
+ Called when a client wants to establish a WebSockets connection.
+ The WebSockets-specific headers can be manipulated to manipulate the handshake.
+ The ``flow`` object is guaranteed to have a non-None ``request`` attribute.
+
+ :param HTTPFlow flow: The flow containing the request which has been received.
+ The object is guaranteed to have a non-None ``request`` attribute.
+
TCP Events
^^^^^^^^^^
diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py
index c262b192..d886af97 100644
--- a/mitmproxy/controller.py
+++ b/mitmproxy/controller.py
@@ -28,6 +28,8 @@ Events = frozenset([
"response",
"responseheaders",
+ "websockets_handshake",
+
"next_layer",
"error",
diff --git a/mitmproxy/flow/master.py b/mitmproxy/flow/master.py
index 0475ef4e..9cdcc8dd 100644
--- a/mitmproxy/flow/master.py
+++ b/mitmproxy/flow/master.py
@@ -334,6 +334,10 @@ class FlowMaster(controller.Master):
self.client_playback.clear(f)
return f
+ @controller.handler
+ def websockets_handshake(self, f):
+ return f
+
def handle_intercept(self, f):
self.state.update_flow(f)
diff --git a/mitmproxy/protocol/__init__.py b/mitmproxy/protocol/__init__.py
index 510cd195..b99b55bd 100644
--- a/mitmproxy/protocol/__init__.py
+++ b/mitmproxy/protocol/__init__.py
@@ -29,8 +29,10 @@ from __future__ import absolute_import, print_function, division
from .base import Layer, ServerConnectionMixin
from .http import UpstreamConnectLayer
+from .http import HttpLayer
from .http1 import Http1Layer
from .http2 import Http2Layer
+from .websockets import WebSocketsLayer
from .rawtcp import RawTCPLayer
from .tls import TlsClientHello
from .tls import TlsLayer
@@ -40,7 +42,9 @@ __all__ = [
"Layer", "ServerConnectionMixin",
"TlsLayer", "is_tls_record_magic", "TlsClientHello",
"UpstreamConnectLayer",
+ "HttpLayer",
"Http1Layer",
"Http2Layer",
+ "WebSocketsLayer",
"RawTCPLayer",
]
diff --git a/mitmproxy/protocol/http.py b/mitmproxy/protocol/http.py
index d81fc8ca..1418d6e9 100644
--- a/mitmproxy/protocol/http.py
+++ b/mitmproxy/protocol/http.py
@@ -7,12 +7,14 @@ import traceback
import h2.exceptions
import six
-import netlib.exceptions
from mitmproxy import exceptions
from mitmproxy import models
from mitmproxy.protocol import base
+
+import netlib.exceptions
from netlib import http
from netlib import tcp
+from netlib import websockets
class _HttpTransmissionLayer(base.Layer):
@@ -189,6 +191,11 @@ class HttpLayer(base.Layer):
self.process_request_hook(flow)
try:
+ if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers):
+ # we only support RFC6455 with WebSockets version 13
+ # allow inline scripts to manupulate the client handshake
+ self.channel.ask("websockets_handshake", flow)
+
if not flow.response:
self.establish_server_connection(
flow.request.host,
@@ -212,7 +219,7 @@ class HttpLayer(base.Layer):
# It may be useful to pass additional args (such as the upgrade header)
# to next_layer in the future
if flow.response.status_code == 101:
- layer = self.ctx.next_layer(self)
+ layer = self.ctx.next_layer(self, flow)
layer()
return
diff --git a/mitmproxy/protocol/websockets.py b/mitmproxy/protocol/websockets.py
new file mode 100644
index 00000000..f15a38ef
--- /dev/null
+++ b/mitmproxy/protocol/websockets.py
@@ -0,0 +1,108 @@
+from __future__ import absolute_import, print_function, division
+
+import socket
+import struct
+
+from OpenSSL import SSL
+
+from mitmproxy import exceptions
+from mitmproxy.protocol import base
+
+import netlib.exceptions
+from netlib import tcp
+from netlib import websockets
+
+
+class WebSocketsLayer(base.Layer):
+ """
+ WebSockets layer to intercept, modify, and forward WebSockets connections
+
+ Only version 13 is supported (as specified in RFC6455)
+ Only HTTP/1.1-initiated connections are supported.
+
+ The client starts by sending an Upgrade-request.
+ In order to determine the handshake and negotiate the correct protocol
+ and extensions, the Upgrade-request is forwarded to the server.
+ The response from the server is then parsed and negotiated settings are extracted.
+ Finally the handshake is completed by forwarding the server-response to the client.
+ After that, only WebSockets frames are exchanged.
+
+ PING/PONG frames pass through and must be answered by the other endpoint.
+
+ CLOSE frames are forwarded before this WebSocketsLayer terminates.
+
+ This layer is transparent to any negotiated extensions.
+ This layer is transparent to any negotiated subprotocols.
+ Only raw frames are forwarded to the other endpoint.
+ """
+
+ def __init__(self, ctx, flow):
+ super(WebSocketsLayer, self).__init__(ctx)
+ self._flow = flow
+
+ self.client_key = websockets.get_client_key(self._flow.request.headers)
+ self.client_protocol = websockets.get_protocol(self._flow.request.headers)
+ self.client_extensions = websockets.get_extensions(self._flow.request.headers)
+
+ self.server_accept = websockets.get_server_accept(self._flow.response.headers)
+ self.server_protocol = websockets.get_protocol(self._flow.response.headers)
+ self.server_extensions = websockets.get_extensions(self._flow.response.headers)
+
+ def _handle_frame(self, frame, source_conn, other_conn, is_server):
+ self.log(
+ "WebSockets Frame received from {}".format("server" if is_server else "client"),
+ "debug",
+ [repr(frame)]
+ )
+
+ if frame.header.opcode & 0x8 == 0:
+ # forward the data frame to the other side
+ other_conn.send(bytes(frame))
+ self.log("WebSockets frame received by {}: {}".format(is_server, frame), "debug")
+ elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG):
+ # just forward the ping/pong to the other side
+ other_conn.send(bytes(frame))
+ elif frame.header.opcode == websockets.OPCODE.CLOSE:
+ other_conn.send(bytes(frame))
+
+ code = '(status code missing)'
+ msg = None
+ reason = '(message missing)'
+ if len(frame.payload) >= 2:
+ code, = struct.unpack('!H', frame.payload[:2])
+ msg = websockets.CLOSE_REASON.get_name(code, default='unknown status code')
+ if len(frame.payload) > 2:
+ reason = frame.payload[2:]
+ self.log("WebSockets connection closed: {} {}, {}".format(code, msg, reason), "info")
+
+ # close the connection
+ return False
+ else:
+ # unknown frame - just forward it
+ other_conn.send(bytes(frame))
+
+ # continue the connection
+ return True
+
+ def __call__(self):
+ client = self.client_conn.connection
+ server = self.server_conn.connection
+ conns = [client, server]
+
+ try:
+ while not self.channel.should_exit.is_set():
+ r = tcp.ssl_read_select(conns, 1)
+ for conn in r:
+ source_conn = self.client_conn if conn == client else self.server_conn
+ other_conn = self.server_conn if conn == client else self.client_conn
+ is_server = (conn == self.server_conn.connection)
+
+ frame = websockets.Frame.from_file(source_conn.rfile)
+
+ if not self._handle_frame(frame, source_conn, other_conn, is_server):
+ return
+ except (socket.error, netlib.exceptions.TcpException, SSL.Error) as e:
+ self.log("WebSockets connection closed unexpectedly by {}: {}".format(
+ "server" if is_server else "client", repr(e)), "info")
+ except Exception as e: # pragma: no cover
+ raise exceptions.ProtocolException("Error in WebSockets connection: {}".format(repr(e)))
diff --git a/mitmproxy/proxy/root_context.py b/mitmproxy/proxy/root_context.py
index 81dd625c..95611362 100644
--- a/mitmproxy/proxy/root_context.py
+++ b/mitmproxy/proxy/root_context.py
@@ -4,6 +4,7 @@ import sys
import six
+from netlib import websockets
import netlib.exceptions
from mitmproxy import exceptions
from mitmproxy import protocol
@@ -32,7 +33,7 @@ class RootContext(object):
self.channel = channel
self.config = config
- def next_layer(self, top_layer):
+ def next_layer(self, top_layer, flow=None):
"""
This function determines the next layer in the protocol stack.
@@ -42,10 +43,22 @@ class RootContext(object):
Returns:
The next layer
"""
- layer = self._next_layer(top_layer)
+ layer = self._next_layer(top_layer, flow)
return self.channel.ask("next_layer", layer)
- def _next_layer(self, top_layer):
+ def _next_layer(self, top_layer, flow):
+ if flow is not None:
+ # We already have a flow, try to derive the next information from it
+
+ # Check for WebSockets handshake
+ is_websockets = (
+ flow and
+ websockets.check_handshake(flow.request.headers) and
+ websockets.check_handshake(flow.response.headers)
+ )
+ if isinstance(top_layer, protocol.HttpLayer) and is_websockets:
+ return protocol.WebSocketsLayer(top_layer, flow)
+
try:
d = top_layer.client_conn.rfile.peek(3)
except netlib.exceptions.TcpException as e:
diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py
index fea696d9..e14e8a7d 100644
--- a/netlib/websockets/__init__.py
+++ b/netlib/websockets/__init__.py
@@ -1,11 +1,37 @@
from __future__ import absolute_import, print_function, division
-from .frame import FrameHeader, Frame, OPCODE
-from .protocol import Masker, WebsocketsProtocol
+
+from .frame import FrameHeader
+from .frame import Frame
+from .frame import OPCODE
+from .frame import CLOSE_REASON
+from .masker import Masker
+from .utils import MAGIC
+from .utils import VERSION
+from .utils import client_handshake_headers
+from .utils import server_handshake_headers
+from .utils import check_handshake
+from .utils import check_client_version
+from .utils import create_server_nonce
+from .utils import get_extensions
+from .utils import get_protocol
+from .utils import get_client_key
+from .utils import get_server_accept
__all__ = [
"FrameHeader",
"Frame",
- "Masker",
- "WebsocketsProtocol",
"OPCODE",
+ "CLOSE_REASON",
+ "Masker",
+ "MAGIC",
+ "VERSION",
+ "client_handshake_headers",
+ "server_handshake_headers",
+ "check_handshake",
+ "check_client_version",
+ "create_server_nonce",
+ "get_extensions",
+ "get_protocol",
+ "get_client_key",
+ "get_server_accept",
]
diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py
index 7d355699..e62d0e87 100644
--- a/netlib/websockets/frame.py
+++ b/netlib/websockets/frame.py
@@ -2,7 +2,6 @@ from __future__ import absolute_import
import os
import struct
import io
-import warnings
import six
@@ -10,7 +9,7 @@ from netlib import tcp
from netlib import strutils
from netlib import utils
from netlib import human
-from netlib.websockets import protocol
+from .masker import Masker
MAX_16_BIT_INT = (1 << 16)
@@ -18,6 +17,7 @@ MAX_64_BIT_INT = (1 << 64)
DEFAULT = object()
+# RFC 6455, Section 5.2 - Base Framing Protocol
OPCODE = utils.BiDi(
CONTINUE=0x00,
TEXT=0x01,
@@ -27,6 +27,23 @@ OPCODE = utils.BiDi(
PONG=0x0a
)
+# RFC 6455, Section 7.4.1 - Defined Status Codes
+CLOSE_REASON = utils.BiDi(
+ NORMAL_CLOSURE=1000,
+ GOING_AWAY=1001,
+ PROTOCOL_ERROR=1002,
+ UNSUPPORTED_DATA=1003,
+ RESERVED=1004,
+ RESERVED_NO_STATUS=1005,
+ RESERVED_ABNORMAL_CLOSURE=1006,
+ INVALID_PAYLOAD_DATA=1007,
+ POLICY_VIOLATION=1008,
+ MESSAGE_TOO_BIG=1009,
+ MANDATORY_EXTENSION=1010,
+ INTERNAL_ERROR=1011,
+ RESERVED_TLS_HANDHSAKE_FAILED=1015,
+)
+
class FrameHeader(object):
@@ -103,10 +120,6 @@ class FrameHeader(object):
vals.append(" %s" % human.pretty_size(self.payload_length))
return "".join(vals)
- def human_readable(self):
- warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
- return repr(self)
-
def __bytes__(self):
first_byte = utils.setbit(0, 7, self.fin)
first_byte = utils.setbit(first_byte, 6, self.rsv1)
@@ -128,6 +141,9 @@ class FrameHeader(object):
# '!Q' = pack as 64 bit unsigned long long
# add 8 bytes extended payload length
b += struct.pack('!Q', self.payload_length)
+ else:
+ raise ValueError("Payload length exceeds 64bit integer")
+
if self.masking_key:
b += self.masking_key
return b
@@ -135,10 +151,6 @@ class FrameHeader(object):
if six.PY2:
__str__ = __bytes__
- def to_bytes(self):
- warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
- return bytes(self)
-
@classmethod
def from_file(cls, fp):
"""
@@ -151,19 +163,17 @@ class FrameHeader(object):
rsv1 = utils.getbit(first_byte, 6)
rsv2 = utils.getbit(first_byte, 5)
rsv3 = utils.getbit(first_byte, 4)
- # grab right-most 4 bits
- opcode = first_byte & 15
+ opcode = first_byte & 0xF
mask_bit = utils.getbit(second_byte, 7)
- # grab the next 7 bits
- length_code = second_byte & 127
+ length_code = second_byte & 0x7F
- # payload_lengthy > 125 indicates you need to read more bytes
+ # payload_length > 125 indicates you need to read more bytes
# to get the actual payload length
if length_code <= 125:
payload_length = length_code
elif length_code == 126:
payload_length, = struct.unpack("!H", fp.safe_read(2))
- elif length_code == 127:
+ else: # length_code == 127:
payload_length, = struct.unpack("!Q", fp.safe_read(8))
# masking key only present if mask bit set
@@ -191,31 +201,30 @@ class FrameHeader(object):
class Frame(object):
-
"""
- Represents one websockets frame.
- Constructor takes human readable forms of the frame components
- from_bytes() is also avaliable.
-
- WebSockets Frame as defined in RFC6455
-
- 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
- +-+-+-+-+-------+-+-------------+-------------------------------+
- |F|R|R|R| opcode|M| Payload len | Extended payload length |
- |I|S|S|S| (4) |A| (7) | (16/64) |
- |N|V|V|V| |S| | (if payload len==126/127) |
- | |1|2|3| |K| | |
- +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
- | Extended payload length continued, if payload len == 127 |
- + - - - - - - - - - - - - - - - +-------------------------------+
- | |Masking-key, if MASK set to 1 |
- +-------------------------------+-------------------------------+
- | Masking-key (continued) | Payload Data |
- +-------------------------------- - - - - - - - - - - - - - - - +
- : Payload Data continued ... :
- + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
- | Payload Data continued ... |
- +---------------------------------------------------------------+
+ Represents a single WebSockets frame.
+ Constructor takes human readable forms of the frame components.
+ from_bytes() reads from a file-like object to create a new Frame.
+
+ WebSockets Frame as defined in RFC6455
+
+ 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+ +-+-+-+-+-------+-+-------------+-------------------------------+
+ |F|R|R|R| opcode|M| Payload len | Extended payload length |
+ |I|S|S|S| (4) |A| (7) | (16/64) |
+ |N|V|V|V| |S| | (if payload len==126/127) |
+ | |1|2|3| |K| | |
+ +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
+ | Extended payload length continued, if payload len == 127 |
+ + - - - - - - - - - - - - - - - +-------------------------------+
+ | |Masking-key, if MASK set to 1 |
+ +-------------------------------+-------------------------------+
+ | Masking-key (continued) | Payload Data |
+ +-------------------------------- - - - - - - - - - - - - - - - +
+ : Payload Data continued ... :
+ + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
+ | Payload Data continued ... |
+ +---------------------------------------------------------------+
"""
def __init__(self, payload=b"", **kwargs):
@@ -224,27 +233,6 @@ class Frame(object):
self.header = FrameHeader(**kwargs)
@classmethod
- def default(cls, message, from_client=False):
- """
- Construct a basic websocket frame from some default values.
- Creates a non-fragmented text frame.
- """
- if from_client:
- mask_bit = 1
- masking_key = os.urandom(4)
- else:
- mask_bit = 0
- masking_key = None
-
- return cls(
- message,
- fin=1, # final frame
- opcode=OPCODE.TEXT, # text
- mask=mask_bit,
- masking_key=masking_key,
- )
-
- @classmethod
def from_bytes(cls, bytestring):
"""
Construct a websocket frame from an in-memory bytestring
@@ -258,17 +246,13 @@ class Frame(object):
ret = ret + "\nPayload:\n" + strutils.bytes_to_escaped_str(self.payload)
return ret
- def human_readable(self):
- warnings.warn("Frame.to_bytes is deprecated, use bytes(frame) instead.", DeprecationWarning)
- return repr(self)
-
def __bytes__(self):
"""
Serialize the frame to wire format. Returns a string.
"""
b = bytes(self.header)
if self.header.masking_key:
- b += protocol.Masker(self.header.masking_key)(self.payload)
+ b += Masker(self.header.masking_key)(self.payload)
else:
b += self.payload
return b
@@ -276,15 +260,6 @@ class Frame(object):
if six.PY2:
__str__ = __bytes__
- def to_bytes(self):
- warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
- return bytes(self)
-
- def to_file(self, writer):
- warnings.warn("Frame.to_file is deprecated, use wfile.write(bytes(frame)) instead.", DeprecationWarning)
- writer.write(bytes(self))
- writer.flush()
-
@classmethod
def from_file(cls, fp):
"""
@@ -297,20 +272,11 @@ class Frame(object):
payload = fp.safe_read(header.payload_length)
if header.mask == 1 and header.masking_key:
- payload = protocol.Masker(header.masking_key)(payload)
+ payload = Masker(header.masking_key)(payload)
- return cls(
- payload,
- fin=header.fin,
- opcode=header.opcode,
- mask=header.mask,
- payload_length=header.payload_length,
- masking_key=header.masking_key,
- rsv1=header.rsv1,
- rsv2=header.rsv2,
- rsv3=header.rsv3,
- length_code=header.length_code
- )
+ frame = cls(payload)
+ frame.header = header
+ return frame
def __eq__(self, other):
if isinstance(other, Frame):
diff --git a/netlib/websockets/masker.py b/netlib/websockets/masker.py
new file mode 100644
index 00000000..bd39ed6a
--- /dev/null
+++ b/netlib/websockets/masker.py
@@ -0,0 +1,33 @@
+from __future__ import absolute_import
+
+import six
+
+
+class Masker(object):
+ """
+ Data sent from the server must be masked to prevent malicious clients
+ from sending data over the wire in predictable patterns.
+
+ Servers do not have to mask data they send to the client.
+ https://tools.ietf.org/html/rfc6455#section-5.3
+ """
+
+ def __init__(self, key):
+ self.key = key
+ self.offset = 0
+
+ def mask(self, offset, data):
+ result = bytearray(data)
+ for i in range(len(data)):
+ if six.PY2:
+ result[i] ^= ord(self.key[offset % 4])
+ else:
+ result[i] ^= self.key[offset % 4]
+ offset += 1
+ result = bytes(result)
+ return result
+
+ def __call__(self, data):
+ ret = self.mask(self.offset, data)
+ self.offset += len(ret)
+ return ret
diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py
deleted file mode 100644
index af0eef7d..00000000
--- a/netlib/websockets/protocol.py
+++ /dev/null
@@ -1,112 +0,0 @@
-"""
-Colleciton of utility functions that implement small portions of the RFC6455
-WebSockets Protocol Useful for building WebSocket clients and servers.
-
-Emphassis is on readabilty, simplicity and modularity, not performance or
-completeness
-
-This is a work in progress and does not yet contain all the utilites need to
-create fully complient client/servers #
-Spec: https://tools.ietf.org/html/rfc6455
-
-The magic sha that websocket servers must know to prove they understand
-RFC6455
-"""
-
-from __future__ import absolute_import
-import base64
-import hashlib
-import os
-
-import six
-
-from netlib import http, strutils
-
-websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
-VERSION = "13"
-
-
-class Masker(object):
-
- """
- Data sent from the server must be masked to prevent malicious clients
- from sending data over the wire in predictable patterns
-
- Servers do not have to mask data they send to the client.
- https://tools.ietf.org/html/rfc6455#section-5.3
- """
-
- def __init__(self, key):
- self.key = key
- self.offset = 0
-
- def mask(self, offset, data):
- result = bytearray(data)
- if six.PY2:
- for i in range(len(data)):
- result[i] ^= ord(self.key[offset % 4])
- offset += 1
- result = str(result)
- else:
-
- for i in range(len(data)):
- result[i] ^= self.key[offset % 4]
- offset += 1
- result = bytes(result)
- return result
-
- def __call__(self, data):
- ret = self.mask(self.offset, data)
- self.offset += len(ret)
- return ret
-
-
-class WebsocketsProtocol(object):
-
- def __init__(self):
- pass
-
- @classmethod
- def client_handshake_headers(self, key=None, version=VERSION):
- """
- Create the headers for a valid HTTP upgrade request. If Key is not
- specified, it is generated, and can be found in sec-websocket-key in
- the returned header set.
-
- Returns an instance of http.Headers
- """
- if not key:
- key = base64.b64encode(os.urandom(16)).decode('ascii')
- return http.Headers(
- sec_websocket_key=key,
- sec_websocket_version=version,
- connection="Upgrade",
- upgrade="websocket",
- )
-
- @classmethod
- def server_handshake_headers(self, key):
- """
- The server response is a valid HTTP 101 response.
- """
- return http.Headers(
- sec_websocket_accept=self.create_server_nonce(key),
- connection="Upgrade",
- upgrade="websocket"
- )
-
- @classmethod
- def check_client_handshake(self, headers):
- if headers.get("upgrade") != "websocket":
- return
- return headers.get("sec-websocket-key")
-
- @classmethod
- def check_server_handshake(self, headers):
- if headers.get("upgrade") != "websocket":
- return
- return headers.get("sec-websocket-accept")
-
- @classmethod
- def create_server_nonce(self, client_nonce):
- return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + websockets_magic).digest())
diff --git a/netlib/websockets/utils.py b/netlib/websockets/utils.py
new file mode 100644
index 00000000..aa0d39a1
--- /dev/null
+++ b/netlib/websockets/utils.py
@@ -0,0 +1,90 @@
+"""
+Collection of WebSockets Protocol utility functions (RFC6455)
+Spec: https://tools.ietf.org/html/rfc6455
+"""
+
+from __future__ import absolute_import
+
+import base64
+import hashlib
+import os
+
+from netlib import http, strutils
+
+MAGIC = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
+VERSION = "13"
+
+
+def client_handshake_headers(version=None, key=None, protocol=None, extensions=None):
+ """
+ Create the headers for a valid HTTP upgrade request. If Key is not
+ specified, it is generated, and can be found in sec-websocket-key in
+ the returned header set.
+
+ Returns an instance of http.Headers
+ """
+ if version is None:
+ version = VERSION
+ if key is None:
+ key = base64.b64encode(os.urandom(16)).decode('ascii')
+ h = http.Headers(
+ connection="upgrade",
+ upgrade="websocket",
+ sec_websocket_version=version,
+ sec_websocket_key=key,
+ )
+ if protocol is not None:
+ h['sec-websocket-protocol'] = protocol
+ if extensions is not None:
+ h['sec-websocket-extensions'] = extensions
+ return h
+
+
+def server_handshake_headers(client_key, protocol=None, extensions=None):
+ """
+ The server response is a valid HTTP 101 response.
+
+ Returns an instance of http.Headers
+ """
+ h = http.Headers(
+ connection="upgrade",
+ upgrade="websocket",
+ sec_websocket_accept=create_server_nonce(client_key),
+ )
+ if protocol is not None:
+ h['sec-websocket-protocol'] = protocol
+ if extensions is not None:
+ h['sec-websocket-extensions'] = extensions
+ return h
+
+
+def check_handshake(headers):
+ return (
+ "upgrade" in headers.get("connection", "").lower() and
+ headers.get("upgrade", "").lower() == "websocket" and
+ (headers.get("sec-websocket-key") is not None or headers.get("sec-websocket-accept") is not None)
+ )
+
+
+def create_server_nonce(client_nonce):
+ return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + MAGIC).digest())
+
+
+def check_client_version(headers):
+ return headers.get("sec-websocket-version", "") == VERSION
+
+
+def get_extensions(headers):
+ return headers.get("sec-websocket-extensions", None)
+
+
+def get_protocol(headers):
+ return headers.get("sec-websocket-protocol", None)
+
+
+def get_client_key(headers):
+ return headers.get("sec-websocket-key", None)
+
+
+def get_server_accept(headers):
+ return headers.get("sec-websocket-accept", None)
diff --git a/pathod/language/http.py b/pathod/language/http.py
index fdc5bba6..46027ca3 100644
--- a/pathod/language/http.py
+++ b/pathod/language/http.py
@@ -198,7 +198,7 @@ class Response(_HTTPMessage):
1,
StatusCode(101)
)
- headers = netlib.websockets.WebsocketsProtocol.server_handshake_headers(
+ headers = netlib.websockets.server_handshake_headers(
settings.websocket_key
)
for i in headers.fields:
@@ -310,7 +310,7 @@ class Request(_HTTPMessage):
1,
Method("get")
)
- for i in netlib.websockets.WebsocketsProtocol.client_handshake_headers().fields:
+ for i in netlib.websockets.client_handshake_headers().fields:
if not get_header(i[0], self.headers):
tokens.append(
Header(
diff --git a/pathod/pathoc.py b/pathod/pathoc.py
index 5831ba3e..a8923013 100644
--- a/pathod/pathoc.py
+++ b/pathod/pathoc.py
@@ -139,7 +139,7 @@ class WebsocketFrameReader(basethread.BaseThread):
except exceptions.TcpDisconnect:
return
self.frames_queue.put(frm)
- log("<< %s" % frm.header.human_readable())
+ log("<< %s" % repr(frm.header))
if self.ws_read_limit is not None:
self.ws_read_limit -= 1
starttime = time.time()
diff --git a/pathod/pathod.py b/pathod/pathod.py
index 7087cba6..bd0feb73 100644
--- a/pathod/pathod.py
+++ b/pathod/pathod.py
@@ -173,12 +173,13 @@ class PathodHandler(tcp.BaseHandler):
retlog["cipher"] = self.get_current_cipher()
m = utils.MemBool()
- websocket_key = websockets.WebsocketsProtocol.check_client_handshake(headers)
- self.settings.websocket_key = websocket_key
+
+ valid_websockets_handshake = websockets.check_handshake(headers)
+ self.settings.websocket_key = websockets.get_client_key(headers)
# If this is a websocket initiation, we respond with a proper
# server response, unless over-ridden.
- if websocket_key:
+ if valid_websockets_handshake:
anchor_gen = language.parse_pathod("ws")
else:
anchor_gen = None
@@ -225,7 +226,7 @@ class PathodHandler(tcp.BaseHandler):
spec,
lg
)
- if nexthandler and websocket_key:
+ if nexthandler and valid_websockets_handshake:
self.protocol = protocols.websockets.WebsocketsProtocol(self)
return self.protocol.handle_websocket, retlog
else:
diff --git a/pathod/protocols/websockets.py b/pathod/protocols/websockets.py
index a34e75e8..df83461a 100644
--- a/pathod/protocols/websockets.py
+++ b/pathod/protocols/websockets.py
@@ -20,7 +20,7 @@ class WebsocketsProtocol:
lg("Error reading websocket frame: %s" % e)
return None, None
ended = time.time()
- lg(frm.human_readable())
+ lg(repr(frm))
retlog = dict(
type="inbound",
protocol="websockets",
diff --git a/test/mitmproxy/protocol/__init__.py b/test/mitmproxy/protocol/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/test/mitmproxy/protocol/__init__.py
diff --git a/test/mitmproxy/test_protocol_http1.py b/test/mitmproxy/protocol/test_http1.py
index cf7bd598..7d04c56b 100644
--- a/test/mitmproxy/test_protocol_http1.py
+++ b/test/mitmproxy/protocol/test_http1.py
@@ -1,7 +1,9 @@
+from __future__ import (absolute_import, print_function, division)
+
from netlib.http import http1
from netlib.tcp import TCPClient
from netlib.tutils import treq
-from . import tutils, tservers
+from .. import tutils, tservers
class TestHTTPFlow(object):
diff --git a/test/mitmproxy/test_protocol_http2.py b/test/mitmproxy/protocol/test_http2.py
index 873c89c3..1eabebf1 100644
--- a/test/mitmproxy/test_protocol_http2.py
+++ b/test/mitmproxy/protocol/test_http2.py
@@ -13,11 +13,11 @@ from mitmproxy import options
from mitmproxy.proxy.config import ProxyConfig
import netlib
-from ..netlib import tservers as netlib_tservers
+from ...netlib import tservers as netlib_tservers
from netlib.exceptions import HttpException
from netlib.http.http2 import framereader
-from . import tservers
+from .. import tservers
import logging
logging.getLogger("hyper.packages.hpack.hpack").setLevel(logging.WARNING)
diff --git a/test/mitmproxy/protocol/test_websockets.py b/test/mitmproxy/protocol/test_websockets.py
new file mode 100644
index 00000000..e7e2684f
--- /dev/null
+++ b/test/mitmproxy/protocol/test_websockets.py
@@ -0,0 +1,299 @@
+from __future__ import absolute_import, print_function, division
+
+import pytest
+import os
+import tempfile
+import traceback
+
+from mitmproxy import options
+from mitmproxy.proxy.config import ProxyConfig
+
+import netlib
+from netlib import http
+from ...netlib import tservers as netlib_tservers
+from .. import tservers
+
+from netlib import websockets
+
+
+class _WebSocketsServerBase(netlib_tservers.ServerTestBase):
+
+ class handler(netlib.tcp.BaseHandler):
+
+ def handle(self):
+ try:
+ request = http.http1.read_request(self.rfile)
+ assert websockets.check_handshake(request.headers)
+
+ response = http.Response(
+ "HTTP/1.1",
+ 101,
+ reason=http.status_codes.RESPONSES.get(101),
+ headers=http.Headers(
+ connection='upgrade',
+ upgrade='websocket',
+ sec_websocket_accept=b'',
+ ),
+ content=b'',
+ )
+ self.wfile.write(http.http1.assemble_response(response))
+ self.wfile.flush()
+
+ self.server.handle_websockets(self.rfile, self.wfile)
+ except:
+ traceback.print_exc()
+
+
+class _WebSocketsTestBase(object):
+
+ @classmethod
+ def setup_class(cls):
+ opts = cls.get_options()
+ cls.config = ProxyConfig(opts)
+
+ tmaster = tservers.TestMaster(opts, cls.config)
+ tmaster.start_app(options.APP_HOST, options.APP_PORT)
+ cls.proxy = tservers.ProxyThread(tmaster)
+ cls.proxy.start()
+
+ @classmethod
+ def teardown_class(cls):
+ cls.proxy.shutdown()
+
+ @classmethod
+ def get_options(cls):
+ opts = options.Options(
+ listen_port=0,
+ no_upstream_cert=False,
+ ssl_insecure=True
+ )
+ opts.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy")
+ return opts
+
+ @property
+ def master(self):
+ return self.proxy.tmaster
+
+ def setup(self):
+ self.master.clear_log()
+ self.master.state.clear()
+ self.server.server.handle_websockets = self.handle_websockets
+
+ def _setup_connection(self):
+ client = netlib.tcp.TCPClient(("127.0.0.1", self.proxy.port))
+ client.connect()
+
+ request = http.Request(
+ "authority",
+ "CONNECT",
+ "",
+ "localhost",
+ self.server.server.address.port,
+ "",
+ "HTTP/1.1",
+ content=b'')
+ client.wfile.write(http.http1.assemble_request(request))
+ client.wfile.flush()
+
+ response = http.http1.read_response(client.rfile, request)
+
+ if self.ssl:
+ client.convert_to_ssl()
+ assert client.ssl_established
+
+ request = http.Request(
+ "relative",
+ "GET",
+ "http",
+ "localhost",
+ self.server.server.address.port,
+ "/ws",
+ "HTTP/1.1",
+ headers=http.Headers(
+ connection="upgrade",
+ upgrade="websocket",
+ sec_websocket_version="13",
+ sec_websocket_key="1234",
+ ),
+ content=b'')
+ client.wfile.write(http.http1.assemble_request(request))
+ client.wfile.flush()
+
+ response = http.http1.read_response(client.rfile, request)
+ assert websockets.check_handshake(response.headers)
+
+ return client
+
+
+class _WebSocketsTest(_WebSocketsTestBase, _WebSocketsServerBase):
+
+ @classmethod
+ def setup_class(cls):
+ _WebSocketsTestBase.setup_class()
+ _WebSocketsServerBase.setup_class(ssl=cls.ssl)
+
+ @classmethod
+ def teardown_class(cls):
+ _WebSocketsTestBase.teardown_class()
+ _WebSocketsServerBase.teardown_class()
+
+
+class TestSimple(_WebSocketsTest):
+
+ @classmethod
+ def handle_websockets(cls, rfile, wfile):
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar')))
+ wfile.flush()
+
+ frame = websockets.Frame.from_file(rfile)
+ wfile.write(bytes(frame))
+ wfile.flush()
+
+ def test_simple(self):
+ client = self._setup_connection()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.payload == b'server-foobar'
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar')))
+ client.wfile.flush()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.payload == b'client-foobar'
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ client.wfile.flush()
+
+
+class TestSimpleTLS(_WebSocketsTest):
+ ssl = True
+
+ @classmethod
+ def handle_websockets(cls, rfile, wfile):
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar')))
+ wfile.flush()
+
+ frame = websockets.Frame.from_file(rfile)
+ wfile.write(bytes(frame))
+ wfile.flush()
+
+ def test_simple_tls(self):
+ client = self._setup_connection()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.payload == b'server-foobar'
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar')))
+ client.wfile.flush()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.payload == b'client-foobar'
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ client.wfile.flush()
+
+
+class TestPing(_WebSocketsTest):
+
+ @classmethod
+ def handle_websockets(cls, rfile, wfile):
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
+ wfile.flush()
+
+ frame = websockets.Frame.from_file(rfile)
+ assert frame.header.opcode == websockets.OPCODE.PONG
+ assert frame.payload == b'foobar'
+
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received')))
+ wfile.flush()
+
+ def test_ping(self):
+ client = self._setup_connection()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.header.opcode == websockets.OPCODE.PING
+ assert frame.payload == b'foobar'
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
+ client.wfile.flush()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.header.opcode == websockets.OPCODE.TEXT
+ assert frame.payload == b'pong-received'
+
+
+class TestPong(_WebSocketsTest):
+
+ @classmethod
+ def handle_websockets(cls, rfile, wfile):
+ frame = websockets.Frame.from_file(rfile)
+ assert frame.header.opcode == websockets.OPCODE.PING
+ assert frame.payload == b'foobar'
+
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
+ wfile.flush()
+
+ def test_pong(self):
+ client = self._setup_connection()
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
+ client.wfile.flush()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.header.opcode == websockets.OPCODE.PONG
+ assert frame.payload == b'foobar'
+
+
+class TestClose(_WebSocketsTest):
+
+ @classmethod
+ def handle_websockets(cls, rfile, wfile):
+ frame = websockets.Frame.from_file(rfile)
+ wfile.write(bytes(frame))
+ wfile.flush()
+
+ with pytest.raises(netlib.exceptions.TcpDisconnect):
+ websockets.Frame.from_file(rfile)
+
+ def test_close(self):
+ client = self._setup_connection()
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ client.wfile.flush()
+
+ with pytest.raises(netlib.exceptions.TcpDisconnect):
+ websockets.Frame.from_file(client.rfile)
+
+ def test_close_payload_1(self):
+ client = self._setup_connection()
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
+ client.wfile.flush()
+
+ with pytest.raises(netlib.exceptions.TcpDisconnect):
+ websockets.Frame.from_file(client.rfile)
+
+ def test_close_payload_2(self):
+ client = self._setup_connection()
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
+ client.wfile.flush()
+
+ with pytest.raises(netlib.exceptions.TcpDisconnect):
+ websockets.Frame.from_file(client.rfile)
+
+
+class TestInvalidFrame(_WebSocketsTest):
+
+ @classmethod
+ def handle_websockets(cls, rfile, wfile):
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=15, payload=b'foobar')))
+ wfile.flush()
+
+ def test_invalid_frame(self):
+ client = self._setup_connection()
+
+ # with pytest.raises(netlib.exceptions.TcpDisconnect):
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.header.opcode == 15
+ assert frame.payload == b'foobar'
diff --git a/test/netlib/websockets/test_frame.py b/test/netlib/websockets/test_frame.py
new file mode 100644
index 00000000..cce39454
--- /dev/null
+++ b/test/netlib/websockets/test_frame.py
@@ -0,0 +1,164 @@
+import os
+import codecs
+import pytest
+
+from netlib import websockets
+from netlib import tutils
+
+
+class TestFrameHeader(object):
+
+ @pytest.mark.parametrize("input,expected", [
+ (0, '0100'),
+ (125, '017D'),
+ (126, '017E007E'),
+ (127, '017E007F'),
+ (142, '017E008E'),
+ (65534, '017EFFFE'),
+ (65535, '017EFFFF'),
+ (65536, '017F0000000000010000'),
+ (8589934591, '017F00000001FFFFFFFF'),
+ (2 ** 64 - 1, '017FFFFFFFFFFFFFFFFF'),
+ ])
+ def test_serialization_length(self, input, expected):
+ h = websockets.FrameHeader(
+ opcode=websockets.OPCODE.TEXT,
+ payload_length=input,
+ )
+ assert bytes(h) == codecs.decode(expected, 'hex')
+
+ def test_serialization_too_large(self):
+ h = websockets.FrameHeader(
+ payload_length=2 ** 64 + 1,
+ )
+ with pytest.raises(ValueError):
+ bytes(h)
+
+ @pytest.mark.parametrize("input,expected", [
+ ('0100', 0),
+ ('017D', 125),
+ ('017E007E', 126),
+ ('017E007F', 127),
+ ('017E008E', 142),
+ ('017EFFFE', 65534),
+ ('017EFFFF', 65535),
+ ('017F0000000000010000', 65536),
+ ('017F00000001FFFFFFFF', 8589934591),
+ ('017FFFFFFFFFFFFFFFFF', 2 ** 64 - 1),
+ ])
+ def test_deserialization_length(self, input, expected):
+ h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex')))
+ assert h.payload_length == expected
+
+ @pytest.mark.parametrize("input,expected", [
+ ('0100', (False, None)),
+ ('018000000000', (True, '00000000')),
+ ('018012345678', (True, '12345678')),
+ ])
+ def test_deserialization_masking(self, input, expected):
+ h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex')))
+ assert h.mask == expected[0]
+ if h.mask:
+ assert h.masking_key == codecs.decode(expected[1], 'hex')
+
+ def test_equality(self):
+ h = websockets.FrameHeader(mask=True, masking_key=b'1234')
+ h2 = websockets.FrameHeader(mask=True, masking_key=b'1234')
+ assert h == h2
+
+ h = websockets.FrameHeader(fin=True)
+ h2 = websockets.FrameHeader(fin=False)
+ assert h != h2
+
+ assert h != 'foobar'
+
+ def test_roundtrip(self):
+ def round(*args, **kwargs):
+ h = websockets.FrameHeader(*args, **kwargs)
+ h2 = websockets.FrameHeader.from_file(tutils.treader(bytes(h)))
+ assert h == h2
+
+ round()
+ round(fin=True)
+ round(rsv1=True)
+ round(rsv2=True)
+ round(rsv3=True)
+ round(payload_length=1)
+ round(payload_length=100)
+ round(payload_length=1000)
+ round(payload_length=10000)
+ round(opcode=websockets.OPCODE.PING)
+ round(masking_key=b"test")
+
+ def test_human_readable(self):
+ f = websockets.FrameHeader(
+ masking_key=b"test",
+ fin=True,
+ payload_length=10
+ )
+ assert repr(f)
+
+ f = websockets.FrameHeader()
+ assert repr(f)
+
+ def test_funky(self):
+ f = websockets.FrameHeader(masking_key=b"test", mask=False)
+ raw = bytes(f)
+ f2 = websockets.FrameHeader.from_file(tutils.treader(raw))
+ assert not f2.mask
+
+ def test_violations(self):
+ tutils.raises("opcode", websockets.FrameHeader, opcode=17)
+ tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x")
+
+ def test_automask(self):
+ f = websockets.FrameHeader(mask=True)
+ assert f.masking_key
+
+ f = websockets.FrameHeader(masking_key=b"foob")
+ assert f.mask
+
+ f = websockets.FrameHeader(masking_key=b"foob", mask=0)
+ assert not f.mask
+ assert f.masking_key
+
+
+class TestFrame(object):
+ def test_equality(self):
+ f = websockets.Frame(payload=b'1234')
+ f2 = websockets.Frame(payload=b'1234')
+ assert f == f2
+
+ assert f != b'1234'
+
+ def test_roundtrip(self):
+ def round(*args, **kwargs):
+ f = websockets.Frame(*args, **kwargs)
+ raw = bytes(f)
+ f2 = websockets.Frame.from_file(tutils.treader(raw))
+ assert f == f2
+ round(b"test")
+ round(b"test", fin=1)
+ round(b"test", rsv1=1)
+ round(b"test", opcode=websockets.OPCODE.PING)
+ round(b"test", masking_key=b"test")
+
+ def test_human_readable(self):
+ f = websockets.Frame()
+ assert repr(f)
+
+ f = websockets.Frame(b"foobar")
+ assert "foobar" in repr(f)
+
+ @pytest.mark.parametrize("masked", [True, False])
+ @pytest.mark.parametrize("length", [100, 50000, 150000])
+ def test_serialization_bijection(self, masked, length):
+ frame = websockets.Frame(
+ os.urandom(length),
+ fin=True,
+ opcode=websockets.OPCODE.TEXT,
+ mask=int(masked),
+ masking_key=(os.urandom(4) if masked else None)
+ )
+ serialized = bytes(frame)
+ assert frame == websockets.Frame.from_bytes(serialized)
diff --git a/test/netlib/websockets/test_masker.py b/test/netlib/websockets/test_masker.py
new file mode 100644
index 00000000..528fce71
--- /dev/null
+++ b/test/netlib/websockets/test_masker.py
@@ -0,0 +1,23 @@
+import codecs
+import pytest
+
+from netlib import websockets
+
+
+class TestMasker(object):
+
+ @pytest.mark.parametrize("input,expected", [
+ ([b"a"], '00'),
+ ([b"four"], '070d1616'),
+ ([b"fourf"], '070d161607'),
+ ([b"fourfive"], '070d1616070b1501'),
+ ([b"a", b"aasdfasdfa", b"asdf"], '000302170504021705040205120605'),
+ ([b"a" * 50, b"aasdfasdfa", b"asdf"], '00030205000302050003020500030205000302050003020500030205000302050003020500030205000302050003020500030205120605051206050500110702'), # noqa
+ ])
+ def test_masker(self, input, expected):
+ m = websockets.Masker(b"abcd")
+ data = b"".join([m(t) for t in input])
+ assert data == codecs.decode(expected, 'hex')
+
+ data = websockets.Masker(b"abcd")(data)
+ assert data == b"".join(input)
diff --git a/test/netlib/websockets/test_utils.py b/test/netlib/websockets/test_utils.py
new file mode 100644
index 00000000..34765e04
--- /dev/null
+++ b/test/netlib/websockets/test_utils.py
@@ -0,0 +1,105 @@
+import pytest
+
+from netlib import http
+from netlib import websockets
+
+
+class TestUtils(object):
+
+ def test_client_handshake_headers(self):
+ h = websockets.client_handshake_headers(version='42')
+ assert h['sec-websocket-version'] == '42'
+
+ h = websockets.client_handshake_headers(key='some-key')
+ assert h['sec-websocket-key'] == 'some-key'
+
+ h = websockets.client_handshake_headers(protocol='foobar')
+ assert h['sec-websocket-protocol'] == 'foobar'
+
+ h = websockets.client_handshake_headers(extensions='foo; bar')
+ assert h['sec-websocket-extensions'] == 'foo; bar'
+
+ def test_server_handshake_headers(self):
+ h = websockets.server_handshake_headers('some-key')
+ assert h['sec-websocket-accept'] == '8iILEZtcVdtFD7MDlPKip9ec9nw='
+ assert 'sec-websocket-protocol' not in h
+ assert 'sec-websocket-extensions' not in h
+
+ h = websockets.server_handshake_headers('some-key', 'foobar', 'foo; bar')
+ assert h['sec-websocket-accept'] == '8iILEZtcVdtFD7MDlPKip9ec9nw='
+ assert h['sec-websocket-protocol'] == 'foobar'
+ assert h['sec-websocket-extensions'] == 'foo; bar'
+
+ @pytest.mark.parametrize("input,expected", [
+ ([(b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-key', b'foobar')], True),
+ ([(b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-accept', b'foobar')], True),
+ ([(b'Connection', b'UpgRaDe'), (b'Upgrade', b'WebSocKeT'), (b'Sec-WebSockeT-KeY', b'foobar')], True),
+ ([(b'Connection', b'UpgRaDe'), (b'Upgrade', b'WebSocKeT'), (b'Sec-WebSockeT-AccePt', b'foobar')], True),
+ ([(b'connection', b'foo'), (b'upgrade', b'bar'), (b'sec-websocket-key', b'foobar')], False),
+ ([(b'connection', b'upgrade'), (b'upgrade', b'websocket')], False),
+ ([(b'connection', b'upgrade'), (b'sec-websocket-key', b'foobar')], False),
+ ([(b'upgrade', b'websocket'), (b'sec-websocket-key', b'foobar')], False),
+ ([], False),
+ ])
+ def test_check_handshake(self, input, expected):
+ h = http.Headers(input)
+ assert websockets.check_handshake(h) == expected
+
+ @pytest.mark.parametrize("input,expected", [
+ ([(b'sec-websocket-version', b'13')], True),
+ ([(b'Sec-WebSockeT-VerSion', b'13')], True),
+ ([(b'sec-websocket-version', b'9')], False),
+ ([(b'sec-websocket-version', b'42')], False),
+ ([(b'sec-websocket-version', b'')], False),
+ ([], False),
+ ])
+ def test_check_client_version(self, input, expected):
+ h = http.Headers(input)
+ assert websockets.check_client_version(h) == expected
+
+ @pytest.mark.parametrize("input,expected", [
+ ('foobar', b'AzhRPA4TNwR6I/riJheN0TfR7+I='),
+ (b'foobar', b'AzhRPA4TNwR6I/riJheN0TfR7+I='),
+ ])
+ def test_create_server_nonce(self, input, expected):
+ assert websockets.create_server_nonce(input) == expected
+
+ @pytest.mark.parametrize("input,expected", [
+ ([(b'sec-websocket-extensions', b'foo; bar')], 'foo; bar'),
+ ([(b'Sec-WebSockeT-ExteNsionS', b'foo; bar')], 'foo; bar'),
+ ([(b'sec-websocket-extensions', b'')], ''),
+ ([], None),
+ ])
+ def test_get_extensions(self, input, expected):
+ h = http.Headers(input)
+ assert websockets.get_extensions(h) == expected
+
+ @pytest.mark.parametrize("input,expected", [
+ ([(b'sec-websocket-protocol', b'foobar')], 'foobar'),
+ ([(b'Sec-WebSockeT-ProTocoL', b'foobar')], 'foobar'),
+ ([(b'sec-websocket-protocol', b'')], ''),
+ ([], None),
+ ])
+ def test_get_protocol(self, input, expected):
+ h = http.Headers(input)
+ assert websockets.get_protocol(h) == expected
+
+ @pytest.mark.parametrize("input,expected", [
+ ([(b'sec-websocket-key', b'foobar')], 'foobar'),
+ ([(b'Sec-WebSockeT-KeY', b'foobar')], 'foobar'),
+ ([(b'sec-websocket-key', b'')], ''),
+ ([], None),
+ ])
+ def test_get_client_key(self, input, expected):
+ h = http.Headers(input)
+ assert websockets.get_client_key(h) == expected
+
+ @pytest.mark.parametrize("input,expected", [
+ ([(b'sec-websocket-accept', b'foobar')], 'foobar'),
+ ([(b'Sec-WebSockeT-AccepT', b'foobar')], 'foobar'),
+ ([(b'sec-websocket-accept', b'')], ''),
+ ([], None),
+ ])
+ def test_get_server_accept(self, input, expected):
+ h = http.Headers(input)
+ assert websockets.get_server_accept(h) == expected
diff --git a/test/netlib/websockets/test_websockets.py b/test/netlib/websockets/test_websockets.py
deleted file mode 100644
index 50fa26e6..00000000
--- a/test/netlib/websockets/test_websockets.py
+++ /dev/null
@@ -1,269 +0,0 @@
-import os
-
-from netlib.http.http1 import read_response, read_request
-
-from netlib import tcp
-from netlib import tutils
-from netlib import websockets
-from netlib.http import status_codes
-from netlib.tutils import treq
-from netlib import exceptions
-
-from .. import tservers
-
-
-class WebSocketsEchoHandler(tcp.BaseHandler):
-
- def __init__(self, connection, address, server):
- super(WebSocketsEchoHandler, self).__init__(
- connection, address, server
- )
- self.protocol = websockets.WebsocketsProtocol()
- self.handshake_done = False
-
- def handle(self):
- while True:
- if not self.handshake_done:
- self.handshake()
- else:
- self.read_next_message()
-
- def read_next_message(self):
- frame = websockets.Frame.from_file(self.rfile)
- self.on_message(frame.payload)
-
- def send_message(self, message):
- frame = websockets.Frame.default(message, from_client=False)
- frame.to_file(self.wfile)
-
- def handshake(self):
-
- req = read_request(self.rfile)
- key = self.protocol.check_client_handshake(req.headers)
-
- preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
- self.wfile.write(preamble.encode() + b"\r\n")
- headers = self.protocol.server_handshake_headers(key)
- self.wfile.write(str(headers) + "\r\n")
- self.wfile.flush()
- self.handshake_done = True
-
- def on_message(self, message):
- if message is not None:
- self.send_message(message)
-
-
-class WebSocketsClient(tcp.TCPClient):
-
- def __init__(self, address, source_address=None):
- super(WebSocketsClient, self).__init__(address, source_address)
- self.protocol = websockets.WebsocketsProtocol()
- self.client_nonce = None
-
- def connect(self):
- super(WebSocketsClient, self).connect()
-
- preamble = b'GET / HTTP/1.1'
- self.wfile.write(preamble + b"\r\n")
- headers = self.protocol.client_handshake_headers()
- self.client_nonce = headers["sec-websocket-key"].encode("ascii")
- self.wfile.write(bytes(headers) + b"\r\n")
- self.wfile.flush()
-
- resp = read_response(self.rfile, treq(method=b"GET"))
- server_nonce = self.protocol.check_server_handshake(resp.headers)
-
- if not server_nonce == self.protocol.create_server_nonce(self.client_nonce):
- self.close()
-
- def read_next_message(self):
- return websockets.Frame.from_file(self.rfile).payload
-
- def send_message(self, message):
- frame = websockets.Frame.default(message, from_client=True)
- frame.to_file(self.wfile)
-
-
-class TestWebSockets(tservers.ServerTestBase):
- handler = WebSocketsEchoHandler
-
- def __init__(self):
- self.protocol = websockets.WebsocketsProtocol()
-
- def random_bytes(self, n=100):
- return os.urandom(n)
-
- def echo(self, msg):
- client = WebSocketsClient(("127.0.0.1", self.port))
- client.connect()
- client.send_message(msg)
- response = client.read_next_message()
- assert response == msg
-
- def test_simple_echo(self):
- self.echo(b"hello I'm the client")
-
- def test_frame_sizes(self):
- # length can fit in the the 7 bit payload length
- small_msg = self.random_bytes(100)
- # 50kb, sligthly larger than can fit in a 7 bit int
- medium_msg = self.random_bytes(50000)
- # 150kb, slightly larger than can fit in a 16 bit int
- large_msg = self.random_bytes(150000)
-
- self.echo(small_msg)
- self.echo(medium_msg)
- self.echo(large_msg)
-
- def test_default_builder(self):
- """
- default builder should always generate valid frames
- """
- msg = self.random_bytes()
- assert websockets.Frame.default(msg, from_client=True)
- assert websockets.Frame.default(msg, from_client=False)
-
- def test_serialization_bijection(self):
- """
- Ensure that various frame types can be serialized/deserialized back
- and forth between to_bytes() and from_bytes()
- """
- for is_client in [True, False]:
- for num_bytes in [100, 50000, 150000]:
- frame = websockets.Frame.default(
- self.random_bytes(num_bytes), is_client
- )
- frame2 = websockets.Frame.from_bytes(
- frame.to_bytes()
- )
- assert frame == frame2
-
- bytes = b'\x81\x03cba'
- assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes
-
- def test_check_server_handshake(self):
- headers = self.protocol.server_handshake_headers("key")
- assert self.protocol.check_server_handshake(headers)
- headers["Upgrade"] = "not_websocket"
- assert not self.protocol.check_server_handshake(headers)
-
- def test_check_client_handshake(self):
- headers = self.protocol.client_handshake_headers("key")
- assert self.protocol.check_client_handshake(headers) == "key"
- headers["Upgrade"] = "not_websocket"
- assert not self.protocol.check_client_handshake(headers)
-
-
-class BadHandshakeHandler(WebSocketsEchoHandler):
-
- def handshake(self):
-
- client_hs = read_request(self.rfile)
- self.protocol.check_client_handshake(client_hs.headers)
-
- preamble = 'HTTP/1.1 101 %s\r\n' % status_codes.RESPONSES.get(101)
- self.wfile.write(preamble.encode())
- headers = self.protocol.server_handshake_headers(b"malformed key")
- self.wfile.write(bytes(headers) + b"\r\n")
- self.wfile.flush()
- self.handshake_done = True
-
-
-class TestBadHandshake(tservers.ServerTestBase):
-
- """
- Ensure that the client disconnects if the server handshake is malformed
- """
- handler = BadHandshakeHandler
-
- def test(self):
- with tutils.raises(exceptions.TcpDisconnect):
- client = WebSocketsClient(("127.0.0.1", self.port))
- client.connect()
- client.send_message(b"hello")
-
-
-class TestFrameHeader:
-
- def test_roundtrip(self):
- def round(*args, **kwargs):
- f = websockets.FrameHeader(*args, **kwargs)
- f2 = websockets.FrameHeader.from_file(tutils.treader(bytes(f)))
- assert f == f2
- round()
- round(fin=1)
- round(rsv1=1)
- round(rsv2=1)
- round(rsv3=1)
- round(payload_length=1)
- round(payload_length=100)
- round(payload_length=1000)
- round(payload_length=10000)
- round(opcode=websockets.OPCODE.PING)
- round(masking_key=b"test")
-
- def test_human_readable(self):
- f = websockets.FrameHeader(
- masking_key=b"test",
- fin=True,
- payload_length=10
- )
- assert repr(f)
- f = websockets.FrameHeader()
- assert repr(f)
-
- def test_funky(self):
- f = websockets.FrameHeader(masking_key=b"test", mask=False)
- raw = bytes(f)
- f2 = websockets.FrameHeader.from_file(tutils.treader(raw))
- assert not f2.mask
-
- def test_violations(self):
- tutils.raises("opcode", websockets.FrameHeader, opcode=17)
- tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x")
-
- def test_automask(self):
- f = websockets.FrameHeader(mask=True)
- assert f.masking_key
-
- f = websockets.FrameHeader(masking_key=b"foob")
- assert f.mask
-
- f = websockets.FrameHeader(masking_key=b"foob", mask=0)
- assert not f.mask
- assert f.masking_key
-
-
-class TestFrame:
-
- def test_roundtrip(self):
- def round(*args, **kwargs):
- f = websockets.Frame(*args, **kwargs)
- raw = bytes(f)
- f2 = websockets.Frame.from_file(tutils.treader(raw))
- assert f == f2
- round(b"test")
- round(b"test", fin=1)
- round(b"test", rsv1=1)
- round(b"test", opcode=websockets.OPCODE.PING)
- round(b"test", masking_key=b"test")
-
- def test_human_readable(self):
- f = websockets.Frame()
- assert repr(f)
-
-
-def test_masker():
- tests = [
- [b"a"],
- [b"four"],
- [b"fourf"],
- [b"fourfive"],
- [b"a", b"aasdfasdfa", b"asdf"],
- [b"a" * 50, b"aasdfasdfa", b"asdf"],
- ]
- for i in tests:
- m = websockets.Masker(b"abcd")
- data = b"".join([m(t) for t in i])
- data2 = websockets.Masker(b"abcd")(data)
- assert data2 == b"".join(i)