aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorThomas Kriechbaumer <thomas@kriechbaumer.name>2017-08-12 14:06:10 +0200
committerThomas Kriechbaumer <thomas@kriechbaumer.name>2017-12-12 22:09:46 +0100
commit130021b76d781f0ebb43928aa8083c8b8d560882 (patch)
tree7891a1c87b82231af01c0edb9d7842d73ba5449d
parent8e9194c2b4b8c1b82832cdab1b364f3300e2d3fd (diff)
downloadmitmproxy-130021b76d781f0ebb43928aa8083c8b8d560882.tar.gz
mitmproxy-130021b76d781f0ebb43928aa8083c8b8d560882.tar.bz2
mitmproxy-130021b76d781f0ebb43928aa8083c8b8d560882.zip
prepare WebSocket stack to move to wsproto
-rw-r--r--mitmproxy/proxy/protocol/websocket.py245
-rw-r--r--test/mitmproxy/proxy/protocol/test_websocket.py26
2 files changed, 154 insertions, 117 deletions
diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py
index 19546eb2..d1abd134 100644
--- a/mitmproxy/proxy/protocol/websocket.py
+++ b/mitmproxy/proxy/protocol/websocket.py
@@ -3,9 +3,14 @@ import socket
import struct
from OpenSSL import SSL
+from wsproto import events
+from wsproto.connection import ConnectionType, WSConnection
+from wsproto.extensions import PerMessageDeflate
+
from mitmproxy import exceptions
from mitmproxy import flow
from mitmproxy.proxy.protocol import base
+from mitmproxy.net import http
from mitmproxy.net import tcp
from mitmproxy.net import websockets
from mitmproxy.websocket import WebSocketFlow, WebSocketMessage
@@ -44,108 +49,139 @@ class WebSocketLayer(base.Layer):
self.client_frame_buffer = []
self.server_frame_buffer = []
- def _handle_frame(self, frame, source_conn, other_conn, is_server):
- if frame.header.opcode & 0x8 == 0:
- return self._handle_data_frame(frame, source_conn, other_conn, is_server)
- elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG):
- return self._handle_ping_pong(frame, source_conn, other_conn, is_server)
- elif frame.header.opcode == websockets.OPCODE.CLOSE:
- return self._handle_close(frame, source_conn, other_conn, is_server)
- else:
- return self._handle_unknown_frame(frame, source_conn, other_conn, is_server)
-
- def _handle_data_frame(self, frame, source_conn, other_conn, is_server):
-
- fb = self.server_frame_buffer if is_server else self.client_frame_buffer
- fb.append(frame)
-
- if frame.header.fin:
- payload = b''.join(f.payload for f in fb)
- original_chunk_sizes = [len(f.payload) for f in fb]
- message_type = fb[0].header.opcode
- compressed_message = fb[0].header.rsv1
- fb.clear()
-
- websocket_message = WebSocketMessage(message_type, not is_server, payload)
- length = len(websocket_message.content)
- self.flow.messages.append(websocket_message)
- self.channel.ask("websocket_message", self.flow)
-
- if not self.flow.stream:
- def get_chunk(payload):
- if len(payload) == length:
- # message has the same length, we can reuse the same sizes
- pos = 0
- for s in original_chunk_sizes:
- yield payload[pos:pos + s]
- pos += s
- else:
- # just re-chunk everything into 4kB frames
- # header len = 4 bytes without masking key and 8 bytes with masking key
- chunk_size = 4092 if is_server else 4088
- chunks = range(0, len(payload), chunk_size)
- for i in chunks:
- yield payload[i:i + chunk_size]
-
- frms = [
- websockets.Frame(
- payload=chunk,
- opcode=frame.header.opcode,
- mask=(False if is_server else 1),
- masking_key=(b'' if is_server else os.urandom(4)))
- for chunk in get_chunk(websocket_message.content)
- ]
-
- if len(frms) > 0:
- frms[-1].header.fin = True
- else:
- frms.append(websockets.Frame(
- fin=True,
- opcode=websockets.OPCODE.CONTINUE,
- mask=(False if is_server else 1),
- masking_key=(b'' if is_server else os.urandom(4))))
-
- frms[0].header.opcode = message_type
- frms[0].header.rsv1 = compressed_message
-
- for frm in frms:
- other_conn.send(bytes(frm))
-
- else:
- other_conn.send(bytes(frame))
-
- elif self.flow.stream:
- other_conn.send(bytes(frame))
+ self.connections = {} # type: Dict[object, WSConnection]
+
+ extensions = []
+ if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers:
+ if PerMessageDeflate.name in handshake_flow.response.headers['Sec-WebSocket-Extensions']:
+ extensions = [PerMessageDeflate.name]
+
+ self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER,
+ extensions=extensions)
+ self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT,
+ host=handshake_flow.request.host,
+ resource=handshake_flow.request.path,
+ extensions=extensions)
+
+ data = self.connections[self.server_conn].bytes_to_send()
+ self.connections[self.client_conn].receive_bytes(data)
+
+ event = next(self.connections[self.client_conn].events())
+ assert isinstance(event, events.ConnectionRequested)
+
+ self.connections[self.client_conn].accept(event)
+ self.connections[self.server_conn].receive_bytes(self.connections[self.client_conn].bytes_to_send())
+ assert isinstance(next(self.connections[self.server_conn].events()), events.ConnectionEstablished)
+
+ def _handle_event(self, event, source_conn, other_conn, is_server):
+ if isinstance(event, events.DataReceived):
+ return self._handle_data_received(event, source_conn, other_conn, is_server)
+ elif isinstance(event, events.PingReceived):
+ return self._handle_ping_received(event, source_conn, other_conn, is_server)
+ elif isinstance(event, events.PongReceived):
+ return self._handle_pong_received(event, source_conn, other_conn, is_server)
+ elif isinstance(event, events.ConnectionFailed):
+ return self._handle_connection_closed(event, source_conn, other_conn, is_server)
+ elif isinstance(event, events.ConnectionFailed):
+ return self._handle_connection_failed(event)
+
+ # fail-safe for unhandled events
+ return True
+
+ def _handle_data_received(self, event, source_conn, other_conn, is_server):
+ return True
+ def _handle_ping_received(self, event, source_conn, other_conn, is_server):
+ # PING is automatically answered with a PONG by wsproto
+ # TODO: log this PING and its payload
+ self.connections[other_conn].ping(event.payload)
+ other_conn.send(self.connections[other_conn].bytes_to_send())
return True
- def _handle_ping_pong(self, frame, source_conn, other_conn, is_server):
- # just forward the ping/pong to the other side
- other_conn.send(bytes(frame))
+ def _handle_pong_received(self, event, source_conn, other_conn, is_server):
+ # TODO: log this PONG and its payload
+ self.connections[other_conn].pong(event.payload)
+ other_conn.send(self.connections[other_conn].bytes_to_send())
return True
- def _handle_close(self, frame, source_conn, other_conn, is_server):
+ def _handle_connection_closed(self, event, source_conn, other_conn, is_server):
self.flow.close_sender = "server" if is_server else "client"
- if len(frame.payload) >= 2:
- code, = struct.unpack('!H', frame.payload[:2])
- self.flow.close_code = code
- self.flow.close_message = websockets.CLOSE_REASON.get_name(code, default='unknown status code')
- if len(frame.payload) > 2:
- self.flow.close_reason = frame.payload[2:]
+ self.flow.close_code = event.code
+ self.flow.close_reason = event.reason
- other_conn.send(bytes(frame))
+ print(self.connections[other_conn])
+ self.connections[other_conn].close(event.code, event.reason)
# initiate close handshake
return False
- def _handle_unknown_frame(self, frame, source_conn, other_conn, is_server):
- # unknown frame - just forward it
- other_conn.send(bytes(frame))
-
- sender = "server" if is_server else "client"
- self.log("Unknown WebSocket frame received from {}".format(sender), "info", [repr(frame)])
-
- return True
+ def _handle_connection_failed(self, event):
+ raise exceptions.TcpException(repr(event))
+
+ # def _handle_data_frame(self, frame, source_conn, other_conn, is_server):
+ #
+ # fb = self.server_frame_buffer if is_server else self.client_frame_buffer
+ # fb.append(frame)
+ #
+ # if frame.header.fin:
+ # payload = b''.join(f.payload for f in fb)
+ # original_chunk_sizes = [len(f.payload) for f in fb]
+ # message_type = fb[0].header.opcode
+ # compressed_message = fb[0].header.rsv1
+ # fb.clear()
+ #
+ # websocket_message = WebSocketMessage(message_type, not is_server, payload)
+ # length = len(websocket_message.content)
+ # self.flow.messages.append(websocket_message)
+ # self.channel.ask("websocket_message", self.flow)
+ #
+ # if not self.flow.stream:
+ # def get_chunk(payload):
+ # if len(payload) == length:
+ # # message has the same length, we can reuse the same sizes
+ # pos = 0
+ # for s in original_chunk_sizes:
+ # yield payload[pos:pos + s]
+ # pos += s
+ # else:
+ # # just re-chunk everything into 4kB frames
+ # # header len = 4 bytes without masking key and 8 bytes with masking key
+ # chunk_size = 4092 if is_server else 4088
+ # chunks = range(0, len(payload), chunk_size)
+ # for i in chunks:
+ # yield payload[i:i + chunk_size]
+ #
+ # frms = [
+ # websockets.Frame(
+ # payload=chunk,
+ # opcode=frame.header.opcode,
+ # mask=(False if is_server else 1),
+ # masking_key=(b'' if is_server else os.urandom(4)))
+ # for chunk in get_chunk(websocket_message.content)
+ # ]
+ #
+ # if len(frms) > 0:
+ # frms[-1].header.fin = True
+ # else:
+ # frms.append(websockets.Frame(
+ # fin=True,
+ # opcode=websockets.OPCODE.CONTINUE,
+ # mask=(False if is_server else 1),
+ # masking_key=(b'' if is_server else os.urandom(4))))
+ #
+ # frms[0].header.opcode = message_type
+ # frms[0].header.rsv1 = compressed_message
+ #
+ # for frm in frms:
+ # other_conn.send(bytes(frm))
+ #
+ # else:
+ # other_conn.send(bytes(frame))
+ #
+ # elif self.flow.stream:
+ # other_conn.send(bytes(frame))
+ #
+ # return True
def __call__(self):
self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
@@ -153,27 +189,28 @@ class WebSocketLayer(base.Layer):
self.handshake_flow.metadata['websocket_flow'] = self.flow.id
self.channel.ask("websocket_start", self.flow)
- client = self.client_conn.connection
- server = self.server_conn.connection
- conns = [client, server]
+ conns = [c.connection for c in self.connections.keys()]
close_received = False
try:
while not self.channel.should_exit.is_set():
r = tcp.ssl_read_select(conns, 0.1)
for conn in r:
- source_conn = self.client_conn if conn == client else self.server_conn
- other_conn = self.server_conn if conn == client else self.client_conn
- is_server = (conn == self.server_conn.connection)
+ source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn
+ other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn
+ is_server = (source_conn == self.server_conn)
frame = websockets.Frame.from_file(source_conn.rfile)
-
- cont = self._handle_frame(frame, source_conn, other_conn, is_server)
- if not cont:
- if close_received:
- return
- else:
- close_received = True
+ self.connections[source_conn].receive_bytes(bytes(frame))
+ source_conn.send(self.connections[source_conn].bytes_to_send())
+
+ for event in self.connections[source_conn].events():
+ print('is_server:', is_server, 'event:', event)
+ if not self._handle_event(event, source_conn, other_conn, is_server):
+ if close_received:
+ break
+ else:
+ close_received = True
except (socket.error, exceptions.TcpException, SSL.Error) as e:
s = 'server' if is_server else 'client'
self.flow.error = flow.Error("WebSocket connection closed unexpectedly by {}: {}".format(s, repr(e)))
diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py
index 460d85f8..14dd7405 100644
--- a/test/mitmproxy/proxy/protocol/test_websocket.py
+++ b/test/mitmproxy/proxy/protocol/test_websocket.py
@@ -164,19 +164,19 @@ class TestSimple(_WebSocketTest):
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'server-foobar'
- self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'self.client-foobar'
- self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'\xde\xad\xbe\xef'
- self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
self.client.wfile.flush()
assert len(self.master.state.flows) == 2
@@ -213,13 +213,13 @@ class TestSimpleTLS(_WebSocketTest):
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'server-foobar'
- self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'self.client-foobar'
- self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
self.client.wfile.flush()
@@ -234,7 +234,7 @@ class TestPing(_WebSocketTest):
assert frame.header.opcode == websockets.OPCODE.PONG
assert frame.payload == b'foobar'
- wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received')))
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=b'done')))
wfile.flush()
def test_ping(self):
@@ -244,12 +244,12 @@ class TestPing(_WebSocketTest):
assert frame.header.opcode == websockets.OPCODE.PING
assert frame.payload == b'foobar'
- self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile)
- assert frame.header.opcode == websockets.OPCODE.TEXT
- assert frame.payload == b'pong-received'
+ assert frame.header.opcode == websockets.OPCODE.PONG
+ assert frame.payload == b'done'
class TestPong(_WebSocketTest):
@@ -266,7 +266,7 @@ class TestPong(_WebSocketTest):
def test_pong(self):
self.setup_connection()
- self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile)
@@ -289,7 +289,7 @@ class TestClose(_WebSocketTest):
def test_close(self):
self.setup_connection()
- self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
self.client.wfile.flush()
websockets.Frame.from_file(self.client.rfile)
@@ -299,7 +299,7 @@ class TestClose(_WebSocketTest):
def test_close_payload_1(self):
self.setup_connection()
- self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
self.client.wfile.flush()
websockets.Frame.from_file(self.client.rfile)
@@ -309,7 +309,7 @@ class TestClose(_WebSocketTest):
def test_close_payload_2(self):
self.setup_connection()
- self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
self.client.wfile.flush()
websockets.Frame.from_file(self.client.rfile)