aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mitmproxy/addons/dumper.py2
-rw-r--r--mitmproxy/proxy/protocol/websocket.py159
-rw-r--r--mitmproxy/tools/console/consoleaddons.py2
-rw-r--r--setup.cfg9
-rw-r--r--test/mitmproxy/proxy/protocol/test_websocket.py122
5 files changed, 185 insertions, 109 deletions
diff --git a/mitmproxy/addons/dumper.py b/mitmproxy/addons/dumper.py
index 54526d5b..48bc8118 100644
--- a/mitmproxy/addons/dumper.py
+++ b/mitmproxy/addons/dumper.py
@@ -234,6 +234,8 @@ class Dumper:
message = f.messages[-1]
self.echo(f.message_info(message))
if ctx.options.flow_detail >= 3:
+ message = message.from_state(message.get_state())
+ message.content = message.content.encode() if isinstance(message.content, str) else message.content
self._echo_message(message)
def websocket_end(self, f):
diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py
index d1abd134..54d8120d 100644
--- a/mitmproxy/proxy/protocol/websocket.py
+++ b/mitmproxy/proxy/protocol/websocket.py
@@ -1,19 +1,18 @@
-import os
import socket
-import struct
from OpenSSL import SSL
from wsproto import events
from wsproto.connection import ConnectionType, WSConnection
from wsproto.extensions import PerMessageDeflate
+from wsproto.frame_protocol import Opcode
from mitmproxy import exceptions
from mitmproxy import flow
from mitmproxy.proxy.protocol import base
-from mitmproxy.net import http
from mitmproxy.net import tcp
from mitmproxy.net import websockets
from mitmproxy.websocket import WebSocketFlow, WebSocketMessage
+from mitmproxy.utils import strutils
class WebSocketLayer(base.Layer):
@@ -54,14 +53,16 @@ class WebSocketLayer(base.Layer):
extensions = []
if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers:
if PerMessageDeflate.name in handshake_flow.response.headers['Sec-WebSocket-Extensions']:
- extensions = [PerMessageDeflate.name]
-
+ extensions = [PerMessageDeflate()]
self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER,
extensions=extensions)
self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT,
host=handshake_flow.request.host,
resource=handshake_flow.request.path,
extensions=extensions)
+ if extensions:
+ for conn in self.connections.values():
+ conn.extensions[0].finalize(conn, handshake_flow.response.headers['Sec-WebSocket-Extensions'])
data = self.connections[self.server_conn].bytes_to_send()
self.connections[self.client_conn].receive_bytes(data)
@@ -80,28 +81,78 @@ class WebSocketLayer(base.Layer):
return self._handle_ping_received(event, source_conn, other_conn, is_server)
elif isinstance(event, events.PongReceived):
return self._handle_pong_received(event, source_conn, other_conn, is_server)
- elif isinstance(event, events.ConnectionFailed):
+ elif isinstance(event, events.ConnectionClosed):
return self._handle_connection_closed(event, source_conn, other_conn, is_server)
- elif isinstance(event, events.ConnectionFailed):
- return self._handle_connection_failed(event)
# fail-safe for unhandled events
- return True
+ return True # pragma: no cover
def _handle_data_received(self, event, source_conn, other_conn, is_server):
+ fb = self.server_frame_buffer if is_server else self.client_frame_buffer
+ fb.append(event.data)
+
+ if event.message_finished:
+ original_chunk_sizes = [len(f) for f in fb]
+ message_type = Opcode.TEXT if isinstance(event, events.TextReceived) else Opcode.BINARY
+ if message_type == Opcode.TEXT:
+ payload = ''.join(fb)
+ else:
+ payload = b''.join(fb)
+ fb.clear()
+
+ websocket_message = WebSocketMessage(message_type, not is_server, payload)
+ length = len(websocket_message.content)
+ self.flow.messages.append(websocket_message)
+ self.channel.ask("websocket_message", self.flow)
+
+ if not self.flow.stream:
+ def get_chunk(payload):
+ if len(payload) == length:
+ # message has the same length, we can reuse the same sizes
+ pos = 0
+ for s in original_chunk_sizes:
+ yield (payload[pos:pos + s], True if pos + s == length else False)
+ pos += s
+ else:
+ # just re-chunk everything into 4kB frames
+ # header len = 4 bytes without masking key and 8 bytes with masking key
+ chunk_size = 4092 if is_server else 4088
+ chunks = range(0, len(payload), chunk_size)
+ for i in chunks:
+ yield (payload[i:i + chunk_size], True if i + chunk_size >= len(payload) else False)
+
+ for chunk, final in get_chunk(websocket_message.content):
+ self.connections[other_conn].send_data(chunk, final)
+ other_conn.send(self.connections[other_conn].bytes_to_send())
+
+ else:
+ self.connections[other_conn].send_data(event.data, event.message_finished)
+ other_conn.send(self.connections[other_conn].bytes_to_send())
+
+ elif self.flow.stream:
+ self.connections[other_conn].send_data(event.data, event.message_finished)
+ other_conn.send(self.connections[other_conn].bytes_to_send())
+
return True
def _handle_ping_received(self, event, source_conn, other_conn, is_server):
# PING is automatically answered with a PONG by wsproto
- # TODO: log this PING and its payload
- self.connections[other_conn].ping(event.payload)
+ self.connections[other_conn].ping()
other_conn.send(self.connections[other_conn].bytes_to_send())
+ source_conn.send(self.connections[source_conn].bytes_to_send())
+ self.log(
+ "Ping Received from {}".format("server" if is_server else "client"),
+ "info",
+ [strutils.bytes_to_escaped_str(bytes(event.payload))]
+ )
return True
def _handle_pong_received(self, event, source_conn, other_conn, is_server):
- # TODO: log this PONG and its payload
- self.connections[other_conn].pong(event.payload)
- other_conn.send(self.connections[other_conn].bytes_to_send())
+ self.log(
+ "Pong Received from {}".format("server" if is_server else "client"),
+ "info",
+ [strutils.bytes_to_escaped_str(bytes(event.payload))]
+ )
return True
def _handle_connection_closed(self, event, source_conn, other_conn, is_server):
@@ -109,80 +160,12 @@ class WebSocketLayer(base.Layer):
self.flow.close_code = event.code
self.flow.close_reason = event.reason
- print(self.connections[other_conn])
self.connections[other_conn].close(event.code, event.reason)
+ other_conn.send(self.connections[other_conn].bytes_to_send())
+ source_conn.send(self.connections[source_conn].bytes_to_send())
- # initiate close handshake
return False
- def _handle_connection_failed(self, event):
- raise exceptions.TcpException(repr(event))
-
- # def _handle_data_frame(self, frame, source_conn, other_conn, is_server):
- #
- # fb = self.server_frame_buffer if is_server else self.client_frame_buffer
- # fb.append(frame)
- #
- # if frame.header.fin:
- # payload = b''.join(f.payload for f in fb)
- # original_chunk_sizes = [len(f.payload) for f in fb]
- # message_type = fb[0].header.opcode
- # compressed_message = fb[0].header.rsv1
- # fb.clear()
- #
- # websocket_message = WebSocketMessage(message_type, not is_server, payload)
- # length = len(websocket_message.content)
- # self.flow.messages.append(websocket_message)
- # self.channel.ask("websocket_message", self.flow)
- #
- # if not self.flow.stream:
- # def get_chunk(payload):
- # if len(payload) == length:
- # # message has the same length, we can reuse the same sizes
- # pos = 0
- # for s in original_chunk_sizes:
- # yield payload[pos:pos + s]
- # pos += s
- # else:
- # # just re-chunk everything into 4kB frames
- # # header len = 4 bytes without masking key and 8 bytes with masking key
- # chunk_size = 4092 if is_server else 4088
- # chunks = range(0, len(payload), chunk_size)
- # for i in chunks:
- # yield payload[i:i + chunk_size]
- #
- # frms = [
- # websockets.Frame(
- # payload=chunk,
- # opcode=frame.header.opcode,
- # mask=(False if is_server else 1),
- # masking_key=(b'' if is_server else os.urandom(4)))
- # for chunk in get_chunk(websocket_message.content)
- # ]
- #
- # if len(frms) > 0:
- # frms[-1].header.fin = True
- # else:
- # frms.append(websockets.Frame(
- # fin=True,
- # opcode=websockets.OPCODE.CONTINUE,
- # mask=(False if is_server else 1),
- # masking_key=(b'' if is_server else os.urandom(4))))
- #
- # frms[0].header.opcode = message_type
- # frms[0].header.rsv1 = compressed_message
- #
- # for frm in frms:
- # other_conn.send(bytes(frm))
- #
- # else:
- # other_conn.send(bytes(frame))
- #
- # elif self.flow.stream:
- # other_conn.send(bytes(frame))
- #
- # return True
-
def __call__(self):
self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
self.flow.metadata['websocket_handshake'] = self.handshake_flow.id
@@ -204,12 +187,12 @@ class WebSocketLayer(base.Layer):
self.connections[source_conn].receive_bytes(bytes(frame))
source_conn.send(self.connections[source_conn].bytes_to_send())
+ if close_received:
+ return
+
for event in self.connections[source_conn].events():
- print('is_server:', is_server, 'event:', event)
if not self._handle_event(event, source_conn, other_conn, is_server):
- if close_received:
- break
- else:
+ if not close_received:
close_received = True
except (socket.error, exceptions.TcpException, SSL.Error) as e:
s = 'server' if is_server else 'client'
diff --git a/mitmproxy/tools/console/consoleaddons.py b/mitmproxy/tools/console/consoleaddons.py
index 1bda219f..8233d45e 100644
--- a/mitmproxy/tools/console/consoleaddons.py
+++ b/mitmproxy/tools/console/consoleaddons.py
@@ -49,7 +49,7 @@ class UnsupportedLog:
def websocket_message(self, f):
message = f.messages[-1]
signals.add_log(f.message_info(message), "info")
- signals.add_log(strutils.bytes_to_escaped_str(message.content), "debug")
+ signals.add_log(message.content if isinstance(message.content, str) else strutils.bytes_to_escaped_str(message.content), "debug")
def websocket_end(self, f):
signals.add_log("WebSocket connection closed by {}: {} {}, {}".format(
diff --git a/setup.cfg b/setup.cfg
index eaabfa12..fd31d15b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -21,7 +21,13 @@ exclude_lines =
[tool:full_coverage]
exclude =
- mitmproxy/proxy/protocol/
+ mitmproxy/proxy/protocol/base.py
+ mitmproxy/proxy/protocol/http.py
+ mitmproxy/proxy/protocol/http1.py
+ mitmproxy/proxy/protocol/http2.py
+ mitmproxy/proxy/protocol/http_replay.py
+ mitmproxy/proxy/protocol/rawtcp.py
+ mitmproxy/proxy/protocol/tls.py
mitmproxy/proxy/root_context.py
mitmproxy/proxy/server.py
mitmproxy/tools/
@@ -64,7 +70,6 @@ exclude =
mitmproxy/proxy/protocol/http_replay.py
mitmproxy/proxy/protocol/rawtcp.py
mitmproxy/proxy/protocol/tls.py
- mitmproxy/proxy/protocol/websocket.py
mitmproxy/proxy/root_context.py
mitmproxy/proxy/server.py
mitmproxy/stateobject.py
diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py
index 14dd7405..a7acdc4d 100644
--- a/test/mitmproxy/proxy/protocol/test_websocket.py
+++ b/test/mitmproxy/proxy/protocol/test_websocket.py
@@ -1,5 +1,6 @@
import pytest
import os
+import struct
import tempfile
import traceback
@@ -33,6 +34,7 @@ class _WebSocketServerBase(net_tservers.ServerTestBase):
connection='upgrade',
upgrade='websocket',
sec_websocket_accept=b'',
+ sec_websocket_extensions='permessage-deflate' if "permessage-deflate" in request.headers.values() else ''
),
content=b'',
)
@@ -80,7 +82,7 @@ class _WebSocketTestBase:
if self.client:
self.client.close()
- def setup_connection(self):
+ def setup_connection(self, extension=False):
self.client = tcp.TCPClient(("127.0.0.1", self.proxy.port))
self.client.connect()
@@ -115,6 +117,7 @@ class _WebSocketTestBase:
upgrade="websocket",
sec_websocket_version="13",
sec_websocket_key="1234",
+ sec_websocket_extensions="permessage-deflate" if extension else ""
),
content=b'')
self.client.wfile.write(http.http1.assemble_request(request))
@@ -145,11 +148,11 @@ class TestSimple(_WebSocketTest):
wfile.flush()
frame = websockets.Frame.from_file(rfile)
- wfile.write(bytes(frame))
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload)))
wfile.flush()
frame = websockets.Frame.from_file(rfile)
- wfile.write(bytes(frame))
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload)))
wfile.flush()
@pytest.mark.parametrize('streaming', [True, False])
@@ -183,17 +186,40 @@ class TestSimple(_WebSocketTest):
assert isinstance(self.master.state.flows[0], HTTPFlow)
assert isinstance(self.master.state.flows[1], WebSocketFlow)
assert len(self.master.state.flows[1].messages) == 5
- assert self.master.state.flows[1].messages[0].content == b'server-foobar'
+ assert self.master.state.flows[1].messages[0].content == 'server-foobar'
assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT
- assert self.master.state.flows[1].messages[1].content == b'self.client-foobar'
+ assert self.master.state.flows[1].messages[1].content == 'self.client-foobar'
assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT
- assert self.master.state.flows[1].messages[2].content == b'self.client-foobar'
+ assert self.master.state.flows[1].messages[2].content == 'self.client-foobar'
assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT
assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef'
assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY
assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef'
assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY
+ def test_change_payload(self):
+ class Addon:
+ def websocket_message(self, f):
+ f.messages[-1].content = "foo"
+
+ self.master.addons.add(Addon())
+ self.setup_connection()
+
+ frame = websockets.Frame.from_file(self.client.rfile)
+ assert frame.payload == b'foo'
+
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
+ self.client.wfile.flush()
+
+ frame = websockets.Frame.from_file(self.client.rfile)
+ assert frame.payload == b'foo'
+
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
+ self.client.wfile.flush()
+
+ frame = websockets.Frame.from_file(self.client.rfile)
+ assert frame.payload == b'foo'
+
class TestSimpleTLS(_WebSocketTest):
ssl = True
@@ -204,7 +230,7 @@ class TestSimpleTLS(_WebSocketTest):
wfile.flush()
frame = websockets.Frame.from_file(rfile)
- wfile.write(bytes(frame))
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload)))
wfile.flush()
def test_simple_tls(self):
@@ -237,19 +263,21 @@ class TestPing(_WebSocketTest):
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=b'done')))
wfile.flush()
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ wfile.flush()
+ websockets.Frame.from_file(rfile)
+
def test_ping(self):
self.setup_connection()
frame = websockets.Frame.from_file(self.client.rfile)
- assert frame.header.opcode == websockets.OPCODE.PING
- assert frame.payload == b'foobar'
-
- self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
+ websockets.Frame.from_file(self.client.rfile)
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
self.client.wfile.flush()
+ assert frame.header.opcode == websockets.OPCODE.PING
+ assert frame.payload == b'' # We don't send payload to other end
- frame = websockets.Frame.from_file(self.client.rfile)
- assert frame.header.opcode == websockets.OPCODE.PONG
- assert frame.payload == b'done'
+ assert self.master.has_log("Pong Received from server", "info")
class TestPong(_WebSocketTest):
@@ -258,11 +286,15 @@ class TestPong(_WebSocketTest):
def handle_websockets(cls, rfile, wfile):
frame = websockets.Frame.from_file(rfile)
assert frame.header.opcode == websockets.OPCODE.PING
- assert frame.payload == b'foobar'
+ assert frame.payload == b''
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
wfile.flush()
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ wfile.flush()
+ websockets.Frame.from_file(rfile)
+
def test_pong(self):
self.setup_connection()
@@ -270,8 +302,13 @@ class TestPong(_WebSocketTest):
self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile)
+ websockets.Frame.from_file(self.client.rfile)
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
+ self.client.wfile.flush()
+
assert frame.header.opcode == websockets.OPCODE.PONG
assert frame.payload == b'foobar'
+ assert self.master.has_log("Pong Received from server", "info")
class TestClose(_WebSocketTest):
@@ -279,7 +316,7 @@ class TestClose(_WebSocketTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
frame = websockets.Frame.from_file(rfile)
- wfile.write(bytes(frame))
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload)))
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
wfile.flush()
@@ -329,8 +366,9 @@ class TestInvalidFrame(_WebSocketTest):
# with pytest.raises(exceptions.TcpDisconnect):
frame = websockets.Frame.from_file(self.client.rfile)
- assert frame.header.opcode == 15
- assert frame.payload == b'foobar'
+ code, = struct.unpack('!H', frame.payload[:2])
+ assert code == 1002
+ assert frame.payload[2:].startswith(b'Invalid opcode')
class TestStreaming(_WebSocketTest):
@@ -360,3 +398,51 @@ class TestStreaming(_WebSocketTest):
assert frame
assert self.master.state.flows[1].messages == [] # Message not appended as the final frame isn't received
+
+
+class TestExtension(_WebSocketTest):
+
+ @classmethod
+ def handle_websockets(cls, rfile, wfile):
+ wfile.write(b'\xc1\x0f*N-*K-\xd2M\xcb\xcfOJ,\x02\x00')
+ wfile.flush()
+
+ frame = websockets.Frame.from_file(rfile)
+ assert frame.header.rsv1
+ wfile.write(b'\xc1\nJ\xce\xc9L\xcd+\x81r\x00\x00')
+ wfile.flush()
+
+ frame = websockets.Frame.from_file(rfile)
+ assert frame.header.rsv1
+ wfile.write(b'\xc2\x07\xba\xb7v\xdf{\x00\x00')
+ wfile.flush()
+
+ def test_extension(self):
+ self.setup_connection(True)
+
+ frame = websockets.Frame.from_file(self.client.rfile)
+ assert frame.header.rsv1
+
+ self.client.wfile.write(b'\xc1\x8fQ\xb7vX\x1by\xbf\x14\x9c\x9c\xa7\x15\x9ax9\x12}\xb5v')
+ self.client.wfile.flush()
+
+ frame = websockets.Frame.from_file(self.client.rfile)
+ assert frame.header.rsv1
+
+ self.client.wfile.write(b'\xc2\x87\xeb\xbb\x0csQ\x0cz\xac\x90\xbb\x0c')
+ self.client.wfile.flush()
+
+ frame = websockets.Frame.from_file(self.client.rfile)
+ assert frame.header.rsv1
+
+ assert len(self.master.state.flows[1].messages) == 5
+ assert self.master.state.flows[1].messages[0].content == 'server-foobar'
+ assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT
+ assert self.master.state.flows[1].messages[1].content == 'client-foobar'
+ assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT
+ assert self.master.state.flows[1].messages[2].content == 'client-foobar'
+ assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT
+ assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef'
+ assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY
+ assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef'
+ assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY