aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorThomas Kriechbaumer <thomas@kriechbaumer.name>2017-03-07 18:32:56 +0100
committerThomas Kriechbaumer <thomas@kriechbaumer.name>2017-03-09 19:08:59 +0100
commitb1dd86d7ae7d6692eb3004a5e4b8bf6504af2635 (patch)
tree390188404fb81c409ab2ab17eb99cd2726655835
parent98b589385519eb6b27f8be89bb1ba45940d45245 (diff)
downloadmitmproxy-b1dd86d7ae7d6692eb3004a5e4b8bf6504af2635.tar.gz
mitmproxy-b1dd86d7ae7d6692eb3004a5e4b8bf6504af2635.tar.bz2
mitmproxy-b1dd86d7ae7d6692eb3004a5e4b8bf6504af2635.zip
make websocket flows serializable
fixes #2113
-rw-r--r--mitmproxy/flow.py2
-rw-r--r--mitmproxy/proxy/protocol/websocket.py1
-rw-r--r--mitmproxy/stateobject.py8
-rw-r--r--mitmproxy/test/tflow.py1
-rw-r--r--mitmproxy/websocket.py51
-rw-r--r--test/mitmproxy/test_stateobject.py22
-rw-r--r--test/mitmproxy/test_websocket.py15
7 files changed, 69 insertions, 31 deletions
diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py
index 5ef957c9..cc5f0aed 100644
--- a/mitmproxy/flow.py
+++ b/mitmproxy/flow.py
@@ -78,7 +78,7 @@ class Flow(stateobject.StateObject):
self._backup = None # type: typing.Optional[Flow]
self.reply = None # type: typing.Optional[controller.Reply]
self.marked = False # type: bool
- self.metadata = dict() # type: typing.Dict[str, str]
+ self.metadata = dict() # type: typing.Dict[str, typing.Any]
_stateobject_attributes = dict(
id=str,
diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py
index e170f19d..1f898ed2 100644
--- a/mitmproxy/proxy/protocol/websocket.py
+++ b/mitmproxy/proxy/protocol/websocket.py
@@ -140,7 +140,6 @@ class WebSocketLayer(base.Layer):
def __call__(self):
self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
- self.flow.metadata['websocket_handshake'] = self.handshake_flow
self.handshake_flow.metadata['websocket_flow'] = self.flow
self.channel.ask("websocket_start", self.flow)
diff --git a/mitmproxy/stateobject.py b/mitmproxy/stateobject.py
index 1ab744a5..14159001 100644
--- a/mitmproxy/stateobject.py
+++ b/mitmproxy/stateobject.py
@@ -39,6 +39,14 @@ class StateObject(serializable.Serializable):
state[attr] = val.get_state()
elif _is_list(cls):
state[attr] = [x.get_state() for x in val]
+ elif isinstance(val, dict):
+ s = {}
+ for k, v in val.items():
+ if hasattr(v, "get_state"):
+ s[k] = v.get_state()
+ else:
+ s[k] = v
+ state[attr] = s
else:
state[attr] = val
return state
diff --git a/mitmproxy/test/tflow.py b/mitmproxy/test/tflow.py
index fd665055..7fbe1727 100644
--- a/mitmproxy/test/tflow.py
+++ b/mitmproxy/test/tflow.py
@@ -70,6 +70,7 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None,
handshake_flow.response = resp
f = websocket.WebSocketFlow(client_conn, server_conn, handshake_flow)
+ handshake_flow.metadata['websocket_flow'] = f
if messages is True:
messages = [
diff --git a/mitmproxy/websocket.py b/mitmproxy/websocket.py
index 25a82878..5d76aafc 100644
--- a/mitmproxy/websocket.py
+++ b/mitmproxy/websocket.py
@@ -2,7 +2,6 @@ import time
from typing import List, Optional
from mitmproxy import flow
-from mitmproxy.http import HTTPFlow
from mitmproxy.net import websockets
from mitmproxy.types import serializable
from mitmproxy.utils import strutils
@@ -44,6 +43,22 @@ class WebSocketFlow(flow.Flow):
self.close_code = '(status code missing)'
self.close_message = '(message missing)'
self.close_reason = 'unknown status code'
+
+ if handshake_flow:
+ self.client_key = websockets.get_client_key(handshake_flow.request.headers)
+ self.client_protocol = websockets.get_protocol(handshake_flow.request.headers)
+ self.client_extensions = websockets.get_extensions(handshake_flow.request.headers)
+ self.server_accept = websockets.get_server_accept(handshake_flow.response.headers)
+ self.server_protocol = websockets.get_protocol(handshake_flow.response.headers)
+ self.server_extensions = websockets.get_extensions(handshake_flow.response.headers)
+ else:
+ self.client_key = ''
+ self.client_protocol = ''
+ self.client_extensions = ''
+ self.server_accept = ''
+ self.server_protocol = ''
+ self.server_extensions = ''
+
self.handshake_flow = handshake_flow
_stateobject_attributes = flow.Flow._stateobject_attributes.copy()
@@ -53,7 +68,15 @@ class WebSocketFlow(flow.Flow):
close_code=str,
close_message=str,
close_reason=str,
- handshake_flow=HTTPFlow,
+ client_key=str,
+ client_protocol=str,
+ client_extensions=str,
+ server_accept=str,
+ server_protocol=str,
+ server_extensions=str,
+ # Do not include handshake_flow, to prevent recursive serialization!
+ # Since mitmproxy-console currently only displays HTTPFlows,
+ # dumping the handshake_flow will include the WebSocketFlow too.
)
@classmethod
@@ -65,30 +88,6 @@ class WebSocketFlow(flow.Flow):
def __repr__(self):
return "<WebSocketFlow ({} messages)>".format(len(self.messages))
- @property
- def client_key(self):
- return websockets.get_client_key(self.handshake_flow.request.headers)
-
- @property
- def client_protocol(self):
- return websockets.get_protocol(self.handshake_flow.request.headers)
-
- @property
- def client_extensions(self):
- return websockets.get_extensions(self.handshake_flow.request.headers)
-
- @property
- def server_accept(self):
- return websockets.get_server_accept(self.handshake_flow.response.headers)
-
- @property
- def server_protocol(self):
- return websockets.get_protocol(self.handshake_flow.response.headers)
-
- @property
- def server_extensions(self):
- return websockets.get_extensions(self.handshake_flow.response.headers)
-
def message_info(self, message: WebSocketMessage) -> str:
return "{client} {direction} WebSocket {type} message {direction} {server}{endpoint}".format(
type=message.type,
diff --git a/test/mitmproxy/test_stateobject.py b/test/mitmproxy/test_stateobject.py
index 7b8e30d0..d8c7a8e9 100644
--- a/test/mitmproxy/test_stateobject.py
+++ b/test/mitmproxy/test_stateobject.py
@@ -26,10 +26,12 @@ class Container(StateObject):
def __init__(self):
self.child = None
self.children = None
+ self.dictionary = None
_stateobject_attributes = dict(
child=Child,
children=List[Child],
+ dictionary=dict,
)
@classmethod
@@ -62,12 +64,30 @@ def test_container_list():
a.children = [Child(42), Child(44)]
assert a.get_state() == {
"child": None,
- "children": [{"x": 42}, {"x": 44}]
+ "children": [{"x": 42}, {"x": 44}],
+ "dictionary": None,
}
copy = a.copy()
assert len(copy.children) == 2
assert copy.children is not a.children
assert copy.children[0] is not a.children[0]
+ assert Container.from_state(a.get_state())
+
+
+def test_container_dict():
+ a = Container()
+ a.dictionary = dict()
+ a.dictionary['foo'] = 'bar'
+ a.dictionary['bar'] = Child(44)
+ assert a.get_state() == {
+ "child": None,
+ "children": None,
+ "dictionary": {'bar': {'x': 44}, 'foo': 'bar'},
+ }
+ copy = a.copy()
+ assert len(copy.dictionary) == 2
+ assert copy.dictionary is not a.dictionary
+ assert copy.dictionary['bar'] is not a.dictionary['bar']
def test_too_much_state():
diff --git a/test/mitmproxy/test_websocket.py b/test/mitmproxy/test_websocket.py
index f2963390..62f69e2d 100644
--- a/test/mitmproxy/test_websocket.py
+++ b/test/mitmproxy/test_websocket.py
@@ -1,5 +1,7 @@
+import io
import pytest
+from mitmproxy.contrib import tnetstring
from mitmproxy import flowfilter
from mitmproxy.test import tflow
@@ -14,8 +16,6 @@ class TestWebSocketFlow:
b = f2.get_state()
del a["id"]
del b["id"]
- del a["handshake_flow"]["id"]
- del b["handshake_flow"]["id"]
assert a == b
assert not f == f2
assert f is not f2
@@ -60,3 +60,14 @@ class TestWebSocketFlow:
assert 'WebSocketFlow' in repr(f)
assert 'binary message: ' in repr(f.messages[0])
assert 'text message: ' in repr(f.messages[1])
+
+ def test_serialize(self):
+ b = io.BytesIO()
+ d = tflow.twebsocketflow().get_state()
+ tnetstring.dump(d, b)
+ assert b.getvalue()
+
+ b = io.BytesIO()
+ d = tflow.twebsocketflow().handshake_flow.get_state()
+ tnetstring.dump(d, b)
+ assert b.getvalue()