aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorThomas Kriechbaumer <thomas@kriechbaumer.name>2016-08-16 18:31:50 +0200
committerThomas Kriechbaumer <thomas@kriechbaumer.name>2016-09-01 09:57:18 +0200
commite5b0dae7e9ef8d2ce62fc263c377c76546190825 (patch)
treed3b7257cfb0a0e66e2e7176f9d32bc16d820e618
parentd12515f84b32b3157fa99ac3c3a7a7318f9626ba (diff)
downloadmitmproxy-e5b0dae7e9ef8d2ce62fc263c377c76546190825.tar.gz
mitmproxy-e5b0dae7e9ef8d2ce62fc263c377c76546190825.tar.bz2
mitmproxy-e5b0dae7e9ef8d2ce62fc263c377c76546190825.zip
add websockets support to mitmproxy
-rw-r--r--mitmproxy/protocol/http.py20
-rw-r--r--mitmproxy/protocol/websockets.py140
-rw-r--r--pathod/language/http.py4
-rw-r--r--pathod/pathoc.py2
-rw-r--r--pathod/pathod.py9
-rw-r--r--pathod/protocols/websockets.py2
-rw-r--r--test/mitmproxy/protocol/test_websockets.py297
7 files changed, 465 insertions, 9 deletions
diff --git a/mitmproxy/protocol/http.py b/mitmproxy/protocol/http.py
index d81fc8ca..fbb52c92 100644
--- a/mitmproxy/protocol/http.py
+++ b/mitmproxy/protocol/http.py
@@ -7,12 +7,15 @@ import traceback
import h2.exceptions
import six
-import netlib.exceptions
from mitmproxy import exceptions
from mitmproxy import models
from mitmproxy.protocol import base
+from .websockets import WebSocketsLayer
+
+import netlib.exceptions
from netlib import http
from netlib import tcp
+from netlib import websockets
class _HttpTransmissionLayer(base.Layer):
@@ -189,6 +192,21 @@ class HttpLayer(base.Layer):
self.process_request_hook(flow)
try:
+ # WebSockets
+ if websockets.check_handshake(request.headers):
+ if websockets.check_client_version(request.headers):
+ layer = WebSocketsLayer(self, request)
+ layer()
+ return
+ else:
+ # we only support RFC6455 with WebSockets version 13
+ self.send_response(models.make_error_response(
+ 400,
+ http.status_codes.RESPONSES.get(400),
+ http.Headers(sec_websocket_version="13")
+ ))
+ return
+
if not flow.response:
self.establish_server_connection(
flow.request.host,
diff --git a/mitmproxy/protocol/websockets.py b/mitmproxy/protocol/websockets.py
new file mode 100644
index 00000000..05eaa537
--- /dev/null
+++ b/mitmproxy/protocol/websockets.py
@@ -0,0 +1,140 @@
+from __future__ import absolute_import, print_function, division
+
+import socket
+import struct
+
+from OpenSSL import SSL
+
+from mitmproxy import exceptions
+from mitmproxy import models
+from mitmproxy.protocol import base
+
+import netlib.exceptions
+from netlib import tcp
+from netlib import http
+from netlib import websockets
+
+
+class WebSocketsLayer(base.Layer):
+ """
+ WebSockets layer to intercept, modify, and forward WebSockets connections
+
+ Only version 13 is supported (as specified in RFC6455)
+ Only HTTP/1.1-initiated connections are supported.
+
+ The client starts by sending an Upgrade-request.
+ In order to determine the handshake and negotiate the correct protocol
+ and extensions, the Upgrade-request is forwarded to the server.
+ The response from the server is then parsed and negotiated settings are extracted.
+ Finally the handshake is completed by forwarding the server-response to the client.
+ After that, only WebSockets frames are exchanged.
+
+ PING/PONG frames pass through and must be answered by the other endpoint.
+
+ CLOSE frames are forwarded before this WebSocketsLayer terminates.
+
+ This layer is transparent to any negotiated extensions.
+ This layer is transparent to any negotiated subprotocols.
+ Only raw frames are forwarded to the other endpoint.
+ """
+
+ def __init__(self, ctx, request):
+ super(WebSocketsLayer, self).__init__(ctx)
+ self._request = request
+
+ self.client_key = websockets.get_client_key(self._request.headers)
+ self.client_protocol = websockets.get_protocol(self._request.headers)
+ self.client_extensions = websockets.get_extensions(self._request.headers)
+
+ self.server_accept = None
+ self.server_protocol = None
+ self.server_extensions = None
+
+ def _initiate_server_conn(self):
+ self.establish_server_connection(
+ self._request.host,
+ self._request.port,
+ self._request.scheme,
+ )
+
+ self.server_conn.send(netlib.http.http1.assemble_request(self._request))
+ response = netlib.http.http1.read_response(self.server_conn.rfile, self._request, body_size_limit=None)
+
+ if not websockets.check_handshake(response.headers):
+ raise exceptions.ProtocolException("Establishing WebSockets connection with server failed: {}".format(response.headers))
+
+ self.server_accept = websockets.get_server_accept(response.headers)
+ self.server_protocol = websockets.get_protocol(response.headers)
+ self.server_extensions = websockets.get_extensions(response.headers)
+
+ def _complete_handshake(self):
+ headers = websockets.server_handshake_headers(self.client_key, self.server_protocol, self.server_extensions)
+ self.send_response(models.HTTPResponse(
+ self._request.http_version,
+ 101,
+ http.status_codes.RESPONSES.get(101),
+ headers,
+ b"",
+ ))
+
+ def _handle_frame(self, frame, source_conn, other_conn, is_server):
+ self.log(
+ "WebSockets Frame received from {}".format("server" if is_server else "client"),
+ "debug",
+ [repr(frame)]
+ )
+
+ if frame.header.opcode & 0x8 == 0:
+ # forward the data frame to the other side
+ other_conn.send(bytes(frame))
+ self.log("WebSockets frame received by {}: {}".format(is_server, frame), "debug")
+ elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG):
+ # just forward the ping/pong to the other side
+ other_conn.send(bytes(frame))
+ elif frame.header.opcode == websockets.OPCODE.CLOSE:
+ other_conn.send(bytes(frame))
+
+ code = '(status code missing)'
+ msg = None
+ reason = '(message missing)'
+ if len(frame.payload) >= 2:
+ code, = struct.unpack('!H', frame.payload[:2])
+ msg = websockets.CLOSE_REASON.get_name(code, default='unknown status code')
+ if len(frame.payload) > 2:
+ reason = frame.payload[2:]
+ self.log("WebSockets connection closed: {} {}, {}".format(code, msg, reason), "info")
+
+ # close the connection
+ return False
+ else:
+ # unknown frame - just forward it
+ other_conn.send(bytes(frame))
+
+ # continue the connection
+ return True
+
+ def __call__(self):
+ self._initiate_server_conn()
+ self._complete_handshake()
+
+ client = self.client_conn.connection
+ server = self.server_conn.connection
+ conns = [client, server]
+
+ try:
+ while not self.channel.should_exit.is_set():
+ r = tcp.ssl_read_select(conns, 1)
+ for conn in r:
+ source_conn = self.client_conn if conn == client else self.server_conn
+ other_conn = self.server_conn if conn == client else self.client_conn
+ is_server = (conn == self.server_conn.connection)
+
+ frame = websockets.Frame.from_file(source_conn.rfile)
+
+ if not self._handle_frame(frame, source_conn, other_conn, is_server):
+ return
+ except (socket.error, netlib.exceptions.TcpException, SSL.Error) as e:
+ self.log("WebSockets connection closed unexpectedly by {}: {}".format(
+ "server" if is_server else "client", repr(e)), "info")
+ except Exception as e: # pragma: no cover
+ raise exceptions.ProtocolException("Error in WebSockets connection: {}".format(repr(e)))
diff --git a/pathod/language/http.py b/pathod/language/http.py
index fdc5bba6..46027ca3 100644
--- a/pathod/language/http.py
+++ b/pathod/language/http.py
@@ -198,7 +198,7 @@ class Response(_HTTPMessage):
1,
StatusCode(101)
)
- headers = netlib.websockets.WebsocketsProtocol.server_handshake_headers(
+ headers = netlib.websockets.server_handshake_headers(
settings.websocket_key
)
for i in headers.fields:
@@ -310,7 +310,7 @@ class Request(_HTTPMessage):
1,
Method("get")
)
- for i in netlib.websockets.WebsocketsProtocol.client_handshake_headers().fields:
+ for i in netlib.websockets.client_handshake_headers().fields:
if not get_header(i[0], self.headers):
tokens.append(
Header(
diff --git a/pathod/pathoc.py b/pathod/pathoc.py
index 5831ba3e..a8923013 100644
--- a/pathod/pathoc.py
+++ b/pathod/pathoc.py
@@ -139,7 +139,7 @@ class WebsocketFrameReader(basethread.BaseThread):
except exceptions.TcpDisconnect:
return
self.frames_queue.put(frm)
- log("<< %s" % frm.header.human_readable())
+ log("<< %s" % repr(frm.header))
if self.ws_read_limit is not None:
self.ws_read_limit -= 1
starttime = time.time()
diff --git a/pathod/pathod.py b/pathod/pathod.py
index 7087cba6..bd0feb73 100644
--- a/pathod/pathod.py
+++ b/pathod/pathod.py
@@ -173,12 +173,13 @@ class PathodHandler(tcp.BaseHandler):
retlog["cipher"] = self.get_current_cipher()
m = utils.MemBool()
- websocket_key = websockets.WebsocketsProtocol.check_client_handshake(headers)
- self.settings.websocket_key = websocket_key
+
+ valid_websockets_handshake = websockets.check_handshake(headers)
+ self.settings.websocket_key = websockets.get_client_key(headers)
# If this is a websocket initiation, we respond with a proper
# server response, unless over-ridden.
- if websocket_key:
+ if valid_websockets_handshake:
anchor_gen = language.parse_pathod("ws")
else:
anchor_gen = None
@@ -225,7 +226,7 @@ class PathodHandler(tcp.BaseHandler):
spec,
lg
)
- if nexthandler and websocket_key:
+ if nexthandler and valid_websockets_handshake:
self.protocol = protocols.websockets.WebsocketsProtocol(self)
return self.protocol.handle_websocket, retlog
else:
diff --git a/pathod/protocols/websockets.py b/pathod/protocols/websockets.py
index a34e75e8..df83461a 100644
--- a/pathod/protocols/websockets.py
+++ b/pathod/protocols/websockets.py
@@ -20,7 +20,7 @@ class WebsocketsProtocol:
lg("Error reading websocket frame: %s" % e)
return None, None
ended = time.time()
- lg(frm.human_readable())
+ lg(repr(frm))
retlog = dict(
type="inbound",
protocol="websockets",
diff --git a/test/mitmproxy/protocol/test_websockets.py b/test/mitmproxy/protocol/test_websockets.py
new file mode 100644
index 00000000..cc478c0b
--- /dev/null
+++ b/test/mitmproxy/protocol/test_websockets.py
@@ -0,0 +1,297 @@
+import pytest
+import os
+import tempfile
+import traceback
+
+from mitmproxy import options
+from mitmproxy.proxy.config import ProxyConfig
+
+import netlib
+from netlib import http
+from ...netlib import tservers as netlib_tservers
+from .. import tservers
+
+from netlib import websockets
+
+
+class _WebSocketsServerBase(netlib_tservers.ServerTestBase):
+
+ class handler(netlib.tcp.BaseHandler):
+
+ def handle(self):
+ try:
+ request = http.http1.read_request(self.rfile)
+ assert websockets.check_handshake(request.headers)
+
+ response = http.Response(
+ "HTTP/1.1",
+ 101,
+ reason=http.status_codes.RESPONSES.get(101),
+ headers=http.Headers(
+ connection='upgrade',
+ upgrade='websocket',
+ sec_websocket_accept=b'',
+ ),
+ content=b'',
+ )
+ self.wfile.write(http.http1.assemble_response(response))
+ self.wfile.flush()
+
+ self.server.handle_websockets(self.rfile, self.wfile)
+ except:
+ traceback.print_exc()
+
+
+class _WebSocketsTestBase(object):
+
+ @classmethod
+ def setup_class(cls):
+ opts = cls.get_options()
+ cls.config = ProxyConfig(opts)
+
+ tmaster = tservers.TestMaster(opts, cls.config)
+ tmaster.start_app(options.APP_HOST, options.APP_PORT)
+ cls.proxy = tservers.ProxyThread(tmaster)
+ cls.proxy.start()
+
+ @classmethod
+ def teardown_class(cls):
+ cls.proxy.shutdown()
+
+ @classmethod
+ def get_options(cls):
+ opts = options.Options(
+ listen_port=0,
+ no_upstream_cert=False,
+ ssl_insecure=True
+ )
+ opts.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy")
+ return opts
+
+ @property
+ def master(self):
+ return self.proxy.tmaster
+
+ def setup(self):
+ self.master.clear_log()
+ self.master.state.clear()
+ self.server.server.handle_websockets = self.handle_websockets
+
+ def _setup_connection(self):
+ client = netlib.tcp.TCPClient(("127.0.0.1", self.proxy.port))
+ client.connect()
+
+ request = http.Request(
+ "authority",
+ "CONNECT",
+ "",
+ "localhost",
+ self.server.server.address.port,
+ "",
+ "HTTP/1.1",
+ content=b'')
+ client.wfile.write(http.http1.assemble_request(request))
+ client.wfile.flush()
+
+ response = http.http1.read_response(client.rfile, request)
+
+ if self.ssl:
+ client.convert_to_ssl()
+ assert client.ssl_established
+
+ request = http.Request(
+ "relative",
+ "GET",
+ "http",
+ "localhost",
+ self.server.server.address.port,
+ "/ws",
+ "HTTP/1.1",
+ headers=http.Headers(
+ connection="upgrade",
+ upgrade="websocket",
+ sec_websocket_version="13",
+ sec_websocket_key="1234",
+ ),
+ content=b'')
+ client.wfile.write(http.http1.assemble_request(request))
+ client.wfile.flush()
+
+ response = http.http1.read_response(client.rfile, request)
+ assert websockets.check_handshake(response.headers)
+
+ return client
+
+
+class _WebSocketsTest(_WebSocketsTestBase, _WebSocketsServerBase):
+
+ @classmethod
+ def setup_class(cls):
+ _WebSocketsTestBase.setup_class()
+ _WebSocketsServerBase.setup_class(ssl=cls.ssl)
+
+ @classmethod
+ def teardown_class(cls):
+ _WebSocketsTestBase.teardown_class()
+ _WebSocketsServerBase.teardown_class()
+
+
+class TestSimple(_WebSocketsTest):
+
+ @classmethod
+ def handle_websockets(cls, rfile, wfile):
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar')))
+ wfile.flush()
+
+ frame = websockets.Frame.from_file(rfile)
+ wfile.write(bytes(frame))
+ wfile.flush()
+
+ def test_simple(self):
+ client = self._setup_connection()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.payload == b'server-foobar'
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar')))
+ client.wfile.flush()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.payload == b'client-foobar'
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ client.wfile.flush()
+
+
+class TestSimpleTLS(_WebSocketsTest):
+ ssl = True
+
+ @classmethod
+ def handle_websockets(cls, rfile, wfile):
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar')))
+ wfile.flush()
+
+ frame = websockets.Frame.from_file(rfile)
+ wfile.write(bytes(frame))
+ wfile.flush()
+
+ def test_simple_tls(self):
+ client = self._setup_connection()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.payload == b'server-foobar'
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar')))
+ client.wfile.flush()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.payload == b'client-foobar'
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ client.wfile.flush()
+
+
+class TestPing(_WebSocketsTest):
+
+ @classmethod
+ def handle_websockets(cls, rfile, wfile):
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
+ wfile.flush()
+
+ frame = websockets.Frame.from_file(rfile)
+ assert frame.header.opcode == websockets.OPCODE.PONG
+ assert frame.payload == b'foobar'
+
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received')))
+ wfile.flush()
+
+ def test_ping(self):
+ client = self._setup_connection()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.header.opcode == websockets.OPCODE.PING
+ assert frame.payload == b'foobar'
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
+ client.wfile.flush()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.header.opcode == websockets.OPCODE.TEXT
+ assert frame.payload == b'pong-received'
+
+
+class TestPong(_WebSocketsTest):
+
+ @classmethod
+ def handle_websockets(cls, rfile, wfile):
+ frame = websockets.Frame.from_file(rfile)
+ assert frame.header.opcode == websockets.OPCODE.PING
+ assert frame.payload == b'foobar'
+
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
+ wfile.flush()
+
+ def test_pong(self):
+ client = self._setup_connection()
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
+ client.wfile.flush()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.header.opcode == websockets.OPCODE.PONG
+ assert frame.payload == b'foobar'
+
+
+class TestClose(_WebSocketsTest):
+
+ @classmethod
+ def handle_websockets(cls, rfile, wfile):
+ frame = websockets.Frame.from_file(rfile)
+ wfile.write(bytes(frame))
+ wfile.flush()
+
+ with pytest.raises(netlib.exceptions.TcpDisconnect):
+ websockets.Frame.from_file(rfile)
+
+ def test_close(self):
+ client = self._setup_connection()
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ client.wfile.flush()
+
+ with pytest.raises(netlib.exceptions.TcpDisconnect):
+ websockets.Frame.from_file(client.rfile)
+
+ def test_close_payload_1(self):
+ client = self._setup_connection()
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
+ client.wfile.flush()
+
+ with pytest.raises(netlib.exceptions.TcpDisconnect):
+ websockets.Frame.from_file(client.rfile)
+
+ def test_close_payload_2(self):
+ client = self._setup_connection()
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
+ client.wfile.flush()
+
+ with pytest.raises(netlib.exceptions.TcpDisconnect):
+ websockets.Frame.from_file(client.rfile)
+
+
+class TestInvalidFrame(_WebSocketsTest):
+
+ @classmethod
+ def handle_websockets(cls, rfile, wfile):
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=15, payload=b'foobar')))
+ wfile.flush()
+
+ def test_invalid_frame(self):
+ client = self._setup_connection()
+
+ # with pytest.raises(netlib.exceptions.TcpDisconnect):
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.header.opcode == 15
+ assert frame.payload == b'foobar'