aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorThomas Kriechbaumer <thomas@kriechbaumer.name>2016-11-13 17:50:51 +0100
committerThomas Kriechbaumer <thomas@kriechbaumer.name>2016-11-23 10:18:45 +0100
commit3d8f3d4c239c0b5da5bd5fcc3fddd0fed72815d3 (patch)
treed8f11f44845a47c3183a362c8b10514aec7dc628
parentffb3988dc9ef3f7f8137b913edb7986e148e0dc4 (diff)
downloadmitmproxy-3d8f3d4c239c0b5da5bd5fcc3fddd0fed72815d3.tar.gz
mitmproxy-3d8f3d4c239c0b5da5bd5fcc3fddd0fed72815d3.tar.bz2
mitmproxy-3d8f3d4c239c0b5da5bd5fcc3fddd0fed72815d3.zip
add WebSocket flows and messages
-rw-r--r--mitmproxy/addons/dumper.py18
-rw-r--r--mitmproxy/events.py16
-rw-r--r--mitmproxy/io.py2
-rw-r--r--mitmproxy/master.py16
-rw-r--r--mitmproxy/proxy/protocol/websocket.py146
-rw-r--r--mitmproxy/tcp.py4
-rw-r--r--mitmproxy/tools/console/master.py7
-rw-r--r--mitmproxy/websocket.py83
8 files changed, 238 insertions, 54 deletions
diff --git a/mitmproxy/addons/dumper.py b/mitmproxy/addons/dumper.py
index 89a9eab8..68d59b2d 100644
--- a/mitmproxy/addons/dumper.py
+++ b/mitmproxy/addons/dumper.py
@@ -223,6 +223,21 @@ class Dumper:
if self.match(f):
self.echo_flow(f)
+ def websocket_error(self, f):
+ self.echo(
+ "Error in WebSocket connection to {}: {}".format(
+ repr(f.server_conn.address), f.error
+ ),
+ fg="red"
+ )
+
+ def websocket_message(self, f):
+ if self.match(f):
+ message = f.messages[-1]
+ self.echo(message.info)
+ if self.flow_detail >= 3:
+ self._echo_message(message)
+
def tcp_error(self, f):
self.echo(
"Error in TCP connection to {}: {}".format(
@@ -240,4 +255,5 @@ class Dumper:
server=repr(f.server_conn.address),
direction=direction,
))
- self._echo_message(message)
+ if self.flow_detail >= 3:
+ self._echo_message(message)
diff --git a/mitmproxy/events.py b/mitmproxy/events.py
index f9475768..f144b412 100644
--- a/mitmproxy/events.py
+++ b/mitmproxy/events.py
@@ -1,6 +1,7 @@
from mitmproxy import controller
from mitmproxy import http
from mitmproxy import tcp
+from mitmproxy import websocket
Events = frozenset([
"clientconnect",
@@ -24,6 +25,10 @@ Events = frozenset([
"resume",
"websocket_handshake",
+ "websocket_start",
+ "websocket_message",
+ "websocket_error",
+ "websocket_end",
"next_layer",
@@ -45,6 +50,17 @@ def event_sequence(f):
yield "response", f
if f.error:
yield "error", f
+ elif isinstance(f, websocket.WebSocketFlow):
+ messages = f.messages
+ f.messages = []
+ f.reply = controller.DummyReply()
+ yield "websocket_start", f
+ while messages:
+ f.messages.append(messages.pop(0))
+ yield "websocket_message", f
+ if f.error:
+ yield "websocket_error", f
+ yield "websocket_end", f
elif isinstance(f, tcp.TCPFlow):
messages = f.messages
f.messages = []
diff --git a/mitmproxy/io.py b/mitmproxy/io.py
index 27ffa036..ad2f00c4 100644
--- a/mitmproxy/io.py
+++ b/mitmproxy/io.py
@@ -4,12 +4,14 @@ from mitmproxy import exceptions
from mitmproxy import flowfilter
from mitmproxy import http
from mitmproxy import tcp
+from mitmproxy import websocket
from mitmproxy.contrib import tnetstring
from mitmproxy import io_compat
FLOW_TYPES = dict(
http=http.HTTPFlow,
+ websocket=websocket.WebSocketFlow,
tcp=tcp.TCPFlow,
)
diff --git a/mitmproxy/master.py b/mitmproxy/master.py
index 55eb74e5..7f114096 100644
--- a/mitmproxy/master.py
+++ b/mitmproxy/master.py
@@ -284,6 +284,22 @@ class Master:
pass
@controller.handler
+ def websocket_start(self, flow):
+ pass
+
+ @controller.handler
+ def websocket_message(self, flow):
+ pass
+
+ @controller.handler
+ def websocket_error(self, flow):
+ pass
+
+ @controller.handler
+ def websocket_end(self, flow):
+ pass
+
+ @controller.handler
def tcp_start(self, flow):
pass
diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py
index 47628013..31521882 100644
--- a/mitmproxy/proxy/protocol/websocket.py
+++ b/mitmproxy/proxy/protocol/websocket.py
@@ -1,18 +1,23 @@
+import os
import socket
import struct
from OpenSSL import SSL
+
+
from mitmproxy import exceptions
+from mitmproxy import flow
from mitmproxy.proxy.protocol import base
from mitmproxy.utils import strutils
from mitmproxy.net import tcp
from mitmproxy.net import websockets
+from mitmproxy.websocket import WebSocketFlow, WebSocketBinaryMessage, WebSocketTextMessage
class WebSocketLayer(base.Layer):
"""
- WebSocket layer to intercept, modify, and forward WebSocket connections
+ WebSocket layer to intercept, modify, and forward WebSocket messages.
- Only version 13 is supported (as specified in RFC6455)
+ Only version 13 is supported (as specified in RFC6455).
Only HTTP/1.1-initiated connections are supported.
The client starts by sending an Upgrade-request.
@@ -29,65 +34,106 @@ class WebSocketLayer(base.Layer):
This layer is transparent to any negotiated extensions.
This layer is transparent to any negotiated subprotocols.
Only raw frames are forwarded to the other endpoint.
+
+ WebSocket messages are stored in a WebSocketFlow.
"""
- def __init__(self, ctx, flow):
+ def __init__(self, ctx, handshake_flow):
super().__init__(ctx)
- self._flow = flow
+ self.handshake_flow = handshake_flow
+ self.flow = None # type: WebSocketFlow
- self.client_key = websockets.get_client_key(self._flow.request.headers)
- self.client_protocol = websockets.get_protocol(self._flow.request.headers)
- self.client_extensions = websockets.get_extensions(self._flow.request.headers)
-
- self.server_accept = websockets.get_server_accept(self._flow.response.headers)
- self.server_protocol = websockets.get_protocol(self._flow.response.headers)
- self.server_extensions = websockets.get_extensions(self._flow.response.headers)
+ self.client_frame_buffer = []
+ self.server_frame_buffer = []
def _handle_frame(self, frame, source_conn, other_conn, is_server):
- sender = "server" if is_server else "client"
- self.log(
- "WebSocket frame received from {}".format(sender),
- "debug",
- [repr(frame)]
- )
+ # sender = "server" if is_server else "client"
+ # self.log(
+ # "WebSocket frame received from {}".format(sender),
+ # "debug",
+ # [repr(frame)]
+ # )
if frame.header.opcode & 0x8 == 0:
- self.log(
- "{direction} websocket {direction} {server}".format(
- server=repr(self.server_conn.address),
- direction="<-" if is_server else "->",
- ),
- "info",
- strutils.bytes_to_escaped_str(frame.payload, keep_spacing=True).splitlines()
- )
- # forward the data frame to the other side
- other_conn.send(bytes(frame))
+ return self._handle_data_frame(frame, source_conn, other_conn, is_server)
elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG):
- # just forward the ping/pong to the other side
- other_conn.send(bytes(frame))
+ return self._handle_ping_pong(frame, source_conn, other_conn, is_server)
elif frame.header.opcode == websockets.OPCODE.CLOSE:
- code = '(status code missing)'
- msg = None
- reason = '(message missing)'
- if len(frame.payload) >= 2:
- code, = struct.unpack('!H', frame.payload[:2])
- msg = websockets.CLOSE_REASON.get_name(code, default='unknown status code')
- if len(frame.payload) > 2:
- reason = frame.payload[2:]
- self.log("WebSocket connection closed by {}: {} {}, {}".format(sender, code, msg, reason), "info")
-
- other_conn.send(bytes(frame))
- # close the connection
- return False
+ return self._handle_close(frame, source_conn, other_conn, is_server)
else:
- self.log("Unknown WebSocket frame received from {}".format(sender), "info", [repr(frame)])
- # unknown frame - just forward it
- other_conn.send(bytes(frame))
+ return self._handle_unknown_frame(frame, source_conn, other_conn, is_server)
+
+ def _handle_data_frame(self, frame, source_conn, other_conn, is_server):
+ fb = self.server_frame_buffer if is_server else self.client_frame_buffer
+ fb.append(frame)
+
+ if frame.header.fin:
+ if frame.header.opcode == websockets.OPCODE.TEXT:
+ t = WebSocketTextMessage
+ else:
+ t = WebSocketBinaryMessage
+
+ payload = b''.join(f.payload for f in fb)
+ fb.clear()
+
+ websocket_message = t(self.flow, not is_server, payload)
+ self.flow.messages.append(websocket_message)
+ self.channel.ask("websocket_message", self.flow)
+
+ # chunk payload into multiple 10kB frames, and send them
+ payload = websocket_message.content
+ chunk_size = 10240 # 10kB
+ chunks = range(0, len(payload), chunk_size)
+ frms = [
+ websockets.Frame(
+ payload=payload[i:i + chunk_size],
+ opcode=frame.header.opcode,
+ mask=(False if is_server else 1),
+ masking_key=(b'' if is_server else os.urandom(4))) for i in chunks
+ ]
+ frms[-1].header.fin = 1
+
+ for frm in frms:
+ other_conn.send(bytes(frm))
+
+ return True
+
+ def _handle_ping_pong(self, frame, source_conn, other_conn, is_server):
+ # just forward the ping/pong to the other side
+ other_conn.send(bytes(frame))
+ return True
+
+ def _handle_close(self, frame, source_conn, other_conn, is_server):
+ code = '(status code missing)'
+ msg = None
+ reason = '(message missing)'
+ if len(frame.payload) >= 2:
+ code, = struct.unpack('!H', frame.payload[:2])
+ msg = websockets.CLOSE_REASON.get_name(code, default='unknown status code')
+ if len(frame.payload) > 2:
+ reason = frame.payload[2:]
+
+ other_conn.send(bytes(frame))
+
+ sender = "server" if is_server else "client"
+ self.log("WebSocket connection closed by {}: {} {}, {}".format(sender, code, msg, reason), "info")
+
+ # close the connection
+ return False
+
+ def _handle_unknown_frame(self, frame, source_conn, other_conn, is_server):
+ # unknown frame - just forward it
+ other_conn.send(bytes(frame))
+
+ sender = "server" if is_server else "client"
+ self.log("Unknown WebSocket frame received from {}".format(sender), "info", [repr(frame)])
- # continue the connection
return True
def __call__(self):
+ self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
+ self.channel.ask("websocket_start", self.flow)
+
client = self.client_conn.connection
server = self.server_conn.connection
conns = [client, server]
@@ -105,7 +151,7 @@ class WebSocketLayer(base.Layer):
if not self._handle_frame(frame, source_conn, other_conn, is_server):
return
except (socket.error, exceptions.TcpException, SSL.Error) as e:
- self.log("WebSocket connection closed unexpectedly by {}: {}".format(
- "server" if is_server else "client", repr(e)), "info")
- except Exception as e: # pragma: no cover
- raise exceptions.ProtocolException("Error in WebSocket connection: {}".format(repr(e)))
+ self.flow.error = flow.Error("WebSocket connection closed unexpectedly: {}".format(repr(e)))
+ self.channel.tell("websocket_error", self.flow)
+ finally:
+ self.channel.tell("websocket_end", self.flow)
diff --git a/mitmproxy/tcp.py b/mitmproxy/tcp.py
index d73be98d..3f10f82b 100644
--- a/mitmproxy/tcp.py
+++ b/mitmproxy/tcp.py
@@ -11,9 +11,7 @@ class TCPMessage(serializable.Serializable):
def __init__(self, from_client, content, timestamp=None):
self.content = content
self.from_client = from_client
- if timestamp is None:
- timestamp = time.time()
- self.timestamp = timestamp
+ self.timestamp = timestamp or time.time()
@classmethod
def from_state(cls, state):
diff --git a/mitmproxy/tools/console/master.py b/mitmproxy/tools/console/master.py
index f8850404..99c61825 100644
--- a/mitmproxy/tools/console/master.py
+++ b/mitmproxy/tools/console/master.py
@@ -447,6 +447,13 @@ class ConsoleMaster(master.Master):
# Handlers
@controller.handler
+ def websocket_message(self, f):
+ super().websocket_message(f)
+ message = f.messages[-1]
+ self.add_log(message.info, "info")
+ self.add_log(strutils.bytes_to_escaped_str(message.content), "debug")
+
+ @controller.handler
def tcp_message(self, f):
super().tcp_message(f)
message = f.messages[-1]
diff --git a/mitmproxy/websocket.py b/mitmproxy/websocket.py
new file mode 100644
index 00000000..eed943cd
--- /dev/null
+++ b/mitmproxy/websocket.py
@@ -0,0 +1,83 @@
+import time
+
+from typing import List
+
+from mitmproxy import flow
+from mitmproxy.http import HTTPFlow
+from mitmproxy.net import websockets
+from mitmproxy.utils import strutils
+from mitmproxy.types import serializable
+
+
+class WebSocketMessage(serializable.Serializable):
+
+ def __init__(self, flow, from_client, content, timestamp=None):
+ self.flow = flow
+ self.content = content
+ self.from_client = from_client
+ self.timestamp = timestamp or time.time()
+
+ @classmethod
+ def from_state(cls, state):
+ return cls(*state)
+
+ def get_state(self):
+ return self.from_client, self.content, self.timestamp
+
+ def set_state(self, state):
+ self.from_client = state.pop("from_client")
+ self.content = state.pop("content")
+ self.timestamp = state.pop("timestamp")
+
+ @property
+ def info(self):
+ return "{client} {direction} WebSocket {type} message {direction} {server}{endpoint}".format(
+ type=self.type,
+ client=repr(self.flow.client_conn.address),
+ server=repr(self.flow.server_conn.address),
+ direction="->" if self.from_client else "<-",
+ endpoint=self.flow.handshake_flow.request.path,
+ )
+
+
+class WebSocketBinaryMessage(WebSocketMessage):
+
+ type = 'binary'
+
+ def __repr__(self):
+ return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content))
+
+class WebSocketTextMessage(WebSocketMessage):
+
+ type = 'text'
+
+ def __repr__(self):
+ return "text message: {}".format(repr(self.content))
+
+
+class WebSocketFlow(flow.Flow):
+
+ """
+ A WebsocketFlow is a simplified representation of a Websocket session.
+ """
+
+ def __init__(self, client_conn, server_conn, handshake_flow, live=None):
+ super().__init__("websocket", client_conn, server_conn, live)
+ self.messages = [] # type: List[WebSocketMessage]
+ self.handshake_flow = handshake_flow
+ self.client_key = websockets.get_client_key(self.handshake_flow.request.headers)
+ self.client_protocol = websockets.get_protocol(self.handshake_flow.request.headers)
+ self.client_extensions = websockets.get_extensions(self.handshake_flow.request.headers)
+ self.server_accept = websockets.get_server_accept(self.handshake_flow.response.headers)
+ self.server_protocol = websockets.get_protocol(self.handshake_flow.response.headers)
+ self.server_extensions = websockets.get_extensions(self.handshake_flow.response.headers)
+
+
+ _stateobject_attributes = flow.Flow._stateobject_attributes.copy()
+ _stateobject_attributes.update(
+ messages=List[WebSocketMessage],
+ handshake_flow=HTTPFlow,
+ )
+
+ def __repr__(self):
+ return "<WebSocketFlow ({} messages)>".format(len(self.messages))