diff options
author | Thomas Kriechbaumer <Kriechi@users.noreply.github.com> | 2017-03-10 21:15:46 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-03-10 21:15:46 +0100 |
commit | e9746c51820aadcbb2ddbfa8fc963d9023dd6991 (patch) | |
tree | 0017afe09e38a7f589f605a1c5140abd49d99486 | |
parent | c39b65c06b263e7ed5686df511efb04bd0232c95 (diff) | |
parent | 49e0f2384891d8ab33844a0a5c9a57981eed5085 (diff) | |
download | mitmproxy-e9746c51820aadcbb2ddbfa8fc963d9023dd6991.tar.gz mitmproxy-e9746c51820aadcbb2ddbfa8fc963d9023dd6991.tar.bz2 mitmproxy-e9746c51820aadcbb2ddbfa8fc963d9023dd6991.zip |
Merge pull request #2114 from mitmproxy/fix-websocket-serialization
make websocket flows serializable
-rw-r--r-- | mitmproxy/flow.py | 2 | ||||
-rw-r--r-- | mitmproxy/proxy/protocol/websocket.py | 4 | ||||
-rw-r--r-- | mitmproxy/stateobject.py | 8 | ||||
-rw-r--r-- | mitmproxy/test/tflow.py | 1 | ||||
-rw-r--r-- | mitmproxy/websocket.py | 51 | ||||
-rw-r--r-- | test/mitmproxy/test_stateobject.py | 22 | ||||
-rw-r--r-- | test/mitmproxy/test_websocket.py | 15 |
7 files changed, 71 insertions, 32 deletions
diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py index 5ef957c9..cc5f0aed 100644 --- a/mitmproxy/flow.py +++ b/mitmproxy/flow.py @@ -78,7 +78,7 @@ class Flow(stateobject.StateObject): self._backup = None # type: typing.Optional[Flow] self.reply = None # type: typing.Optional[controller.Reply] self.marked = False # type: bool - self.metadata = dict() # type: typing.Dict[str, str] + self.metadata = dict() # type: typing.Dict[str, typing.Any] _stateobject_attributes = dict( id=str, diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py index e170f19d..373c6479 100644 --- a/mitmproxy/proxy/protocol/websocket.py +++ b/mitmproxy/proxy/protocol/websocket.py @@ -140,8 +140,8 @@ class WebSocketLayer(base.Layer): def __call__(self): self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self) - self.flow.metadata['websocket_handshake'] = self.handshake_flow - self.handshake_flow.metadata['websocket_flow'] = self.flow + self.flow.metadata['websocket_handshake'] = self.handshake_flow.id + self.handshake_flow.metadata['websocket_flow'] = self.flow.id self.channel.ask("websocket_start", self.flow) client = self.client_conn.connection diff --git a/mitmproxy/stateobject.py b/mitmproxy/stateobject.py index 1ab744a5..14159001 100644 --- a/mitmproxy/stateobject.py +++ b/mitmproxy/stateobject.py @@ -39,6 +39,14 @@ class StateObject(serializable.Serializable): state[attr] = val.get_state() elif _is_list(cls): state[attr] = [x.get_state() for x in val] + elif isinstance(val, dict): + s = {} + for k, v in val.items(): + if hasattr(v, "get_state"): + s[k] = v.get_state() + else: + s[k] = v + state[attr] = s else: state[attr] = val return state diff --git a/mitmproxy/test/tflow.py b/mitmproxy/test/tflow.py index fd665055..7fbe1727 100644 --- a/mitmproxy/test/tflow.py +++ b/mitmproxy/test/tflow.py @@ -70,6 +70,7 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None, handshake_flow.response = resp f = websocket.WebSocketFlow(client_conn, server_conn, handshake_flow) + handshake_flow.metadata['websocket_flow'] = f if messages is True: messages = [ diff --git a/mitmproxy/websocket.py b/mitmproxy/websocket.py index 25a82878..5d76aafc 100644 --- a/mitmproxy/websocket.py +++ b/mitmproxy/websocket.py @@ -2,7 +2,6 @@ import time from typing import List, Optional from mitmproxy import flow -from mitmproxy.http import HTTPFlow from mitmproxy.net import websockets from mitmproxy.types import serializable from mitmproxy.utils import strutils @@ -44,6 +43,22 @@ class WebSocketFlow(flow.Flow): self.close_code = '(status code missing)' self.close_message = '(message missing)' self.close_reason = 'unknown status code' + + if handshake_flow: + self.client_key = websockets.get_client_key(handshake_flow.request.headers) + self.client_protocol = websockets.get_protocol(handshake_flow.request.headers) + self.client_extensions = websockets.get_extensions(handshake_flow.request.headers) + self.server_accept = websockets.get_server_accept(handshake_flow.response.headers) + self.server_protocol = websockets.get_protocol(handshake_flow.response.headers) + self.server_extensions = websockets.get_extensions(handshake_flow.response.headers) + else: + self.client_key = '' + self.client_protocol = '' + self.client_extensions = '' + self.server_accept = '' + self.server_protocol = '' + self.server_extensions = '' + self.handshake_flow = handshake_flow _stateobject_attributes = flow.Flow._stateobject_attributes.copy() @@ -53,7 +68,15 @@ class WebSocketFlow(flow.Flow): close_code=str, close_message=str, close_reason=str, - handshake_flow=HTTPFlow, + client_key=str, + client_protocol=str, + client_extensions=str, + server_accept=str, + server_protocol=str, + server_extensions=str, + # Do not include handshake_flow, to prevent recursive serialization! + # Since mitmproxy-console currently only displays HTTPFlows, + # dumping the handshake_flow will include the WebSocketFlow too. ) @classmethod @@ -65,30 +88,6 @@ class WebSocketFlow(flow.Flow): def __repr__(self): return "<WebSocketFlow ({} messages)>".format(len(self.messages)) - @property - def client_key(self): - return websockets.get_client_key(self.handshake_flow.request.headers) - - @property - def client_protocol(self): - return websockets.get_protocol(self.handshake_flow.request.headers) - - @property - def client_extensions(self): - return websockets.get_extensions(self.handshake_flow.request.headers) - - @property - def server_accept(self): - return websockets.get_server_accept(self.handshake_flow.response.headers) - - @property - def server_protocol(self): - return websockets.get_protocol(self.handshake_flow.response.headers) - - @property - def server_extensions(self): - return websockets.get_extensions(self.handshake_flow.response.headers) - def message_info(self, message: WebSocketMessage) -> str: return "{client} {direction} WebSocket {type} message {direction} {server}{endpoint}".format( type=message.type, diff --git a/test/mitmproxy/test_stateobject.py b/test/mitmproxy/test_stateobject.py index 7b8e30d0..d8c7a8e9 100644 --- a/test/mitmproxy/test_stateobject.py +++ b/test/mitmproxy/test_stateobject.py @@ -26,10 +26,12 @@ class Container(StateObject): def __init__(self): self.child = None self.children = None + self.dictionary = None _stateobject_attributes = dict( child=Child, children=List[Child], + dictionary=dict, ) @classmethod @@ -62,12 +64,30 @@ def test_container_list(): a.children = [Child(42), Child(44)] assert a.get_state() == { "child": None, - "children": [{"x": 42}, {"x": 44}] + "children": [{"x": 42}, {"x": 44}], + "dictionary": None, } copy = a.copy() assert len(copy.children) == 2 assert copy.children is not a.children assert copy.children[0] is not a.children[0] + assert Container.from_state(a.get_state()) + + +def test_container_dict(): + a = Container() + a.dictionary = dict() + a.dictionary['foo'] = 'bar' + a.dictionary['bar'] = Child(44) + assert a.get_state() == { + "child": None, + "children": None, + "dictionary": {'bar': {'x': 44}, 'foo': 'bar'}, + } + copy = a.copy() + assert len(copy.dictionary) == 2 + assert copy.dictionary is not a.dictionary + assert copy.dictionary['bar'] is not a.dictionary['bar'] def test_too_much_state(): diff --git a/test/mitmproxy/test_websocket.py b/test/mitmproxy/test_websocket.py index f2963390..62f69e2d 100644 --- a/test/mitmproxy/test_websocket.py +++ b/test/mitmproxy/test_websocket.py @@ -1,5 +1,7 @@ +import io import pytest +from mitmproxy.contrib import tnetstring from mitmproxy import flowfilter from mitmproxy.test import tflow @@ -14,8 +16,6 @@ class TestWebSocketFlow: b = f2.get_state() del a["id"] del b["id"] - del a["handshake_flow"]["id"] - del b["handshake_flow"]["id"] assert a == b assert not f == f2 assert f is not f2 @@ -60,3 +60,14 @@ class TestWebSocketFlow: assert 'WebSocketFlow' in repr(f) assert 'binary message: ' in repr(f.messages[0]) assert 'text message: ' in repr(f.messages[1]) + + def test_serialize(self): + b = io.BytesIO() + d = tflow.twebsocketflow().get_state() + tnetstring.dump(d, b) + assert b.getvalue() + + b = io.BytesIO() + d = tflow.twebsocketflow().handshake_flow.get_state() + tnetstring.dump(d, b) + assert b.getvalue() |