aboutsummaryrefslogtreecommitdiffstats
path: root/mitmproxy/proxy
diff options
context:
space:
mode:
authorUjjwal Verma <ujjwalverma1111@gmail.com>2017-08-17 21:12:07 +0530
committerThomas Kriechbaumer <thomas@kriechbaumer.name>2017-12-12 22:09:46 +0100
commit5214f544e7b690dea2a45cb4cda44bbffec9a77e (patch)
treecaaa4c597403214f39a855d2b511341d54095a9e /mitmproxy/proxy
parent130021b76d781f0ebb43928aa8083c8b8d560882 (diff)
downloadmitmproxy-5214f544e7b690dea2a45cb4cda44bbffec9a77e.tar.gz
mitmproxy-5214f544e7b690dea2a45cb4cda44bbffec9a77e.tar.bz2
mitmproxy-5214f544e7b690dea2a45cb4cda44bbffec9a77e.zip
Use wsproto for websockets
Diffstat (limited to 'mitmproxy/proxy')
-rw-r--r--mitmproxy/proxy/protocol/websocket.py159
1 files changed, 71 insertions, 88 deletions
diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py
index d1abd134..54d8120d 100644
--- a/mitmproxy/proxy/protocol/websocket.py
+++ b/mitmproxy/proxy/protocol/websocket.py
@@ -1,19 +1,18 @@
-import os
import socket
-import struct
from OpenSSL import SSL
from wsproto import events
from wsproto.connection import ConnectionType, WSConnection
from wsproto.extensions import PerMessageDeflate
+from wsproto.frame_protocol import Opcode
from mitmproxy import exceptions
from mitmproxy import flow
from mitmproxy.proxy.protocol import base
-from mitmproxy.net import http
from mitmproxy.net import tcp
from mitmproxy.net import websockets
from mitmproxy.websocket import WebSocketFlow, WebSocketMessage
+from mitmproxy.utils import strutils
class WebSocketLayer(base.Layer):
@@ -54,14 +53,16 @@ class WebSocketLayer(base.Layer):
extensions = []
if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers:
if PerMessageDeflate.name in handshake_flow.response.headers['Sec-WebSocket-Extensions']:
- extensions = [PerMessageDeflate.name]
-
+ extensions = [PerMessageDeflate()]
self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER,
extensions=extensions)
self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT,
host=handshake_flow.request.host,
resource=handshake_flow.request.path,
extensions=extensions)
+ if extensions:
+ for conn in self.connections.values():
+ conn.extensions[0].finalize(conn, handshake_flow.response.headers['Sec-WebSocket-Extensions'])
data = self.connections[self.server_conn].bytes_to_send()
self.connections[self.client_conn].receive_bytes(data)
@@ -80,28 +81,78 @@ class WebSocketLayer(base.Layer):
return self._handle_ping_received(event, source_conn, other_conn, is_server)
elif isinstance(event, events.PongReceived):
return self._handle_pong_received(event, source_conn, other_conn, is_server)
- elif isinstance(event, events.ConnectionFailed):
+ elif isinstance(event, events.ConnectionClosed):
return self._handle_connection_closed(event, source_conn, other_conn, is_server)
- elif isinstance(event, events.ConnectionFailed):
- return self._handle_connection_failed(event)
# fail-safe for unhandled events
- return True
+ return True # pragma: no cover
def _handle_data_received(self, event, source_conn, other_conn, is_server):
+ fb = self.server_frame_buffer if is_server else self.client_frame_buffer
+ fb.append(event.data)
+
+ if event.message_finished:
+ original_chunk_sizes = [len(f) for f in fb]
+ message_type = Opcode.TEXT if isinstance(event, events.TextReceived) else Opcode.BINARY
+ if message_type == Opcode.TEXT:
+ payload = ''.join(fb)
+ else:
+ payload = b''.join(fb)
+ fb.clear()
+
+ websocket_message = WebSocketMessage(message_type, not is_server, payload)
+ length = len(websocket_message.content)
+ self.flow.messages.append(websocket_message)
+ self.channel.ask("websocket_message", self.flow)
+
+ if not self.flow.stream:
+ def get_chunk(payload):
+ if len(payload) == length:
+ # message has the same length, we can reuse the same sizes
+ pos = 0
+ for s in original_chunk_sizes:
+ yield (payload[pos:pos + s], True if pos + s == length else False)
+ pos += s
+ else:
+ # just re-chunk everything into 4kB frames
+ # header len = 4 bytes without masking key and 8 bytes with masking key
+ chunk_size = 4092 if is_server else 4088
+ chunks = range(0, len(payload), chunk_size)
+ for i in chunks:
+ yield (payload[i:i + chunk_size], True if i + chunk_size >= len(payload) else False)
+
+ for chunk, final in get_chunk(websocket_message.content):
+ self.connections[other_conn].send_data(chunk, final)
+ other_conn.send(self.connections[other_conn].bytes_to_send())
+
+ else:
+ self.connections[other_conn].send_data(event.data, event.message_finished)
+ other_conn.send(self.connections[other_conn].bytes_to_send())
+
+ elif self.flow.stream:
+ self.connections[other_conn].send_data(event.data, event.message_finished)
+ other_conn.send(self.connections[other_conn].bytes_to_send())
+
return True
def _handle_ping_received(self, event, source_conn, other_conn, is_server):
# PING is automatically answered with a PONG by wsproto
- # TODO: log this PING and its payload
- self.connections[other_conn].ping(event.payload)
+ self.connections[other_conn].ping()
other_conn.send(self.connections[other_conn].bytes_to_send())
+ source_conn.send(self.connections[source_conn].bytes_to_send())
+ self.log(
+ "Ping Received from {}".format("server" if is_server else "client"),
+ "info",
+ [strutils.bytes_to_escaped_str(bytes(event.payload))]
+ )
return True
def _handle_pong_received(self, event, source_conn, other_conn, is_server):
- # TODO: log this PONG and its payload
- self.connections[other_conn].pong(event.payload)
- other_conn.send(self.connections[other_conn].bytes_to_send())
+ self.log(
+ "Pong Received from {}".format("server" if is_server else "client"),
+ "info",
+ [strutils.bytes_to_escaped_str(bytes(event.payload))]
+ )
return True
def _handle_connection_closed(self, event, source_conn, other_conn, is_server):
@@ -109,80 +160,12 @@ class WebSocketLayer(base.Layer):
self.flow.close_code = event.code
self.flow.close_reason = event.reason
- print(self.connections[other_conn])
self.connections[other_conn].close(event.code, event.reason)
+ other_conn.send(self.connections[other_conn].bytes_to_send())
+ source_conn.send(self.connections[source_conn].bytes_to_send())
- # initiate close handshake
return False
- def _handle_connection_failed(self, event):
- raise exceptions.TcpException(repr(event))
-
- # def _handle_data_frame(self, frame, source_conn, other_conn, is_server):
- #
- # fb = self.server_frame_buffer if is_server else self.client_frame_buffer
- # fb.append(frame)
- #
- # if frame.header.fin:
- # payload = b''.join(f.payload for f in fb)
- # original_chunk_sizes = [len(f.payload) for f in fb]
- # message_type = fb[0].header.opcode
- # compressed_message = fb[0].header.rsv1
- # fb.clear()
- #
- # websocket_message = WebSocketMessage(message_type, not is_server, payload)
- # length = len(websocket_message.content)
- # self.flow.messages.append(websocket_message)
- # self.channel.ask("websocket_message", self.flow)
- #
- # if not self.flow.stream:
- # def get_chunk(payload):
- # if len(payload) == length:
- # # message has the same length, we can reuse the same sizes
- # pos = 0
- # for s in original_chunk_sizes:
- # yield payload[pos:pos + s]
- # pos += s
- # else:
- # # just re-chunk everything into 4kB frames
- # # header len = 4 bytes without masking key and 8 bytes with masking key
- # chunk_size = 4092 if is_server else 4088
- # chunks = range(0, len(payload), chunk_size)
- # for i in chunks:
- # yield payload[i:i + chunk_size]
- #
- # frms = [
- # websockets.Frame(
- # payload=chunk,
- # opcode=frame.header.opcode,
- # mask=(False if is_server else 1),
- # masking_key=(b'' if is_server else os.urandom(4)))
- # for chunk in get_chunk(websocket_message.content)
- # ]
- #
- # if len(frms) > 0:
- # frms[-1].header.fin = True
- # else:
- # frms.append(websockets.Frame(
- # fin=True,
- # opcode=websockets.OPCODE.CONTINUE,
- # mask=(False if is_server else 1),
- # masking_key=(b'' if is_server else os.urandom(4))))
- #
- # frms[0].header.opcode = message_type
- # frms[0].header.rsv1 = compressed_message
- #
- # for frm in frms:
- # other_conn.send(bytes(frm))
- #
- # else:
- # other_conn.send(bytes(frame))
- #
- # elif self.flow.stream:
- # other_conn.send(bytes(frame))
- #
- # return True
-
def __call__(self):
self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
self.flow.metadata['websocket_handshake'] = self.handshake_flow.id
@@ -204,12 +187,12 @@ class WebSocketLayer(base.Layer):
self.connections[source_conn].receive_bytes(bytes(frame))
source_conn.send(self.connections[source_conn].bytes_to_send())
+ if close_received:
+ return
+
for event in self.connections[source_conn].events():
- print('is_server:', is_server, 'event:', event)
if not self._handle_event(event, source_conn, other_conn, is_server):
- if close_received:
- break
- else:
+ if not close_received:
close_received = True
except (socket.error, exceptions.TcpException, SSL.Error) as e:
s = 'server' if is_server else 'client'