diff options
-rw-r--r-- | mitmproxy/addons/dumper.py | 2 | ||||
-rw-r--r-- | mitmproxy/proxy/protocol/websocket.py | 9 | ||||
-rw-r--r-- | mitmproxy/tcp.py | 6 | ||||
-rw-r--r-- | mitmproxy/test/tflow.py | 5 | ||||
-rw-r--r-- | mitmproxy/version.py | 2 | ||||
-rw-r--r-- | mitmproxy/websocket.py | 100 | ||||
-rw-r--r-- | test/mitmproxy/addons/test_stickycookie.py | 1 | ||||
-rw-r--r-- | test/mitmproxy/protocol/test_websocket.py | 11 | ||||
-rw-r--r-- | test/mitmproxy/test_flow.py | 97 | ||||
-rw-r--r-- | tox.ini | 4 |
10 files changed, 171 insertions, 66 deletions
diff --git a/mitmproxy/addons/dumper.py b/mitmproxy/addons/dumper.py index 12b0c34b..222f1167 100644 --- a/mitmproxy/addons/dumper.py +++ b/mitmproxy/addons/dumper.py @@ -238,7 +238,7 @@ class Dumper: def websocket_message(self, f): if self.match(f): message = f.messages[-1] - self.echo(message.info) + self.echo(f.message_info(message)) if self.flow_detail >= 3: self._echo_message(message) diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py index d0b12540..e170f19d 100644 --- a/mitmproxy/proxy/protocol/websocket.py +++ b/mitmproxy/proxy/protocol/websocket.py @@ -8,7 +8,7 @@ from mitmproxy import flow from mitmproxy.proxy.protocol import base from mitmproxy.net import tcp from mitmproxy.net import websockets -from mitmproxy.websocket import WebSocketFlow, WebSocketBinaryMessage, WebSocketTextMessage +from mitmproxy.websocket import WebSocketFlow, WebSocketMessage class WebSocketLayer(base.Layer): @@ -65,12 +65,7 @@ class WebSocketLayer(base.Layer): compressed_message = fb[0].header.rsv1 fb.clear() - if message_type == websockets.OPCODE.TEXT: - t = WebSocketTextMessage - else: - t = WebSocketBinaryMessage - - websocket_message = t(self.flow, not is_server, payload) + websocket_message = WebSocketMessage(message_type, not is_server, payload) length = len(websocket_message.content) self.flow.messages.append(websocket_message) self.channel.ask("websocket_message", self.flow) diff --git a/mitmproxy/tcp.py b/mitmproxy/tcp.py index 3f10f82b..067fbfe3 100644 --- a/mitmproxy/tcp.py +++ b/mitmproxy/tcp.py @@ -9,8 +9,8 @@ from mitmproxy.types import serializable class TCPMessage(serializable.Serializable): def __init__(self, from_client, content, timestamp=None): - self.content = content self.from_client = from_client + self.content = content self.timestamp = timestamp or time.time() @classmethod @@ -21,9 +21,7 @@ class TCPMessage(serializable.Serializable): return self.from_client, self.content, self.timestamp def set_state(self, state): - self.from_client = state.pop("from_client") - self.content = state.pop("content") - self.timestamp = state.pop("timestamp") + self.from_client, self.content, self.timestamp = state def __repr__(self): return "{direction} {content}".format( diff --git a/mitmproxy/test/tflow.py b/mitmproxy/test/tflow.py index a5670538..6d330840 100644 --- a/mitmproxy/test/tflow.py +++ b/mitmproxy/test/tflow.py @@ -1,3 +1,4 @@ +from mitmproxy.net import websockets from mitmproxy.test import tutils from mitmproxy import tcp from mitmproxy import websocket @@ -70,8 +71,8 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None, if messages is True: messages = [ - websocket.WebSocketBinaryMessage(f, True, b"hello binary"), - websocket.WebSocketTextMessage(f, False, "hello text".encode()), + websocket.WebSocketMessage(websockets.OPCODE.BINARY, True, b"hello binary"), + websocket.WebSocketMessage(websockets.OPCODE.TEXT, False, "hello text".encode()), ] if err is True: err = terr() diff --git a/mitmproxy/version.py b/mitmproxy/version.py index 22382c94..a5faf511 100644 --- a/mitmproxy/version.py +++ b/mitmproxy/version.py @@ -3,5 +3,5 @@ VERSION = ".".join(str(i) for i in IVERSION) PATHOD = "pathod " + VERSION MITMPROXY = "mitmproxy " + VERSION -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover print(VERSION) diff --git a/mitmproxy/websocket.py b/mitmproxy/websocket.py index 6e998a52..25a82878 100644 --- a/mitmproxy/websocket.py +++ b/mitmproxy/websocket.py @@ -1,63 +1,38 @@ import time - -from typing import List +from typing import List, Optional from mitmproxy import flow from mitmproxy.http import HTTPFlow from mitmproxy.net import websockets -from mitmproxy.utils import strutils from mitmproxy.types import serializable +from mitmproxy.utils import strutils class WebSocketMessage(serializable.Serializable): - - def __init__(self, flow, from_client, content, timestamp=None): - self.flow = flow - self.content = content + def __init__(self, type: int, from_client: bool, content: bytes, timestamp: Optional[int]=None): + self.type = type self.from_client = from_client - self.timestamp = timestamp or time.time() + self.content = content + self.timestamp = timestamp or time.time() # type: int @classmethod def from_state(cls, state): return cls(*state) def get_state(self): - return self.from_client, self.content, self.timestamp + return self.type, self.from_client, self.content, self.timestamp def set_state(self, state): - self.from_client = state.pop("from_client") - self.content = state.pop("content") - self.timestamp = state.pop("timestamp") - - @property - def info(self): - return "{client} {direction} WebSocket {type} message {direction} {server}{endpoint}".format( - type=self.type, - client=repr(self.flow.client_conn.address), - server=repr(self.flow.server_conn.address), - direction="->" if self.from_client else "<-", - endpoint=self.flow.handshake_flow.request.path, - ) - - -class WebSocketBinaryMessage(WebSocketMessage): - - type = 'binary' + self.type, self.from_client, self.content, self.timestamp = state def __repr__(self): - return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content)) - - -class WebSocketTextMessage(WebSocketMessage): - - type = 'text' - - def __repr__(self): - return "text message: {}".format(repr(self.content)) + if self.type == websockets.OPCODE.TEXT: + return "text message: {}".format(repr(self.content)) + else: + return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content)) class WebSocketFlow(flow.Flow): - """ A WebsocketFlow is a simplified representation of a Websocket session. """ @@ -70,18 +45,55 @@ class WebSocketFlow(flow.Flow): self.close_message = '(message missing)' self.close_reason = 'unknown status code' self.handshake_flow = handshake_flow - self.client_key = websockets.get_client_key(self.handshake_flow.request.headers) - self.client_protocol = websockets.get_protocol(self.handshake_flow.request.headers) - self.client_extensions = websockets.get_extensions(self.handshake_flow.request.headers) - self.server_accept = websockets.get_server_accept(self.handshake_flow.response.headers) - self.server_protocol = websockets.get_protocol(self.handshake_flow.response.headers) - self.server_extensions = websockets.get_extensions(self.handshake_flow.response.headers) _stateobject_attributes = flow.Flow._stateobject_attributes.copy() _stateobject_attributes.update( messages=List[WebSocketMessage], + close_sender=str, + close_code=str, + close_message=str, + close_reason=str, handshake_flow=HTTPFlow, ) + @classmethod + def from_state(cls, state): + f = cls(None, None, None) + f.set_state(state) + return f + def __repr__(self): - return "WebSocketFlow ({} messages)".format(len(self.messages)) + return "<WebSocketFlow ({} messages)>".format(len(self.messages)) + + @property + def client_key(self): + return websockets.get_client_key(self.handshake_flow.request.headers) + + @property + def client_protocol(self): + return websockets.get_protocol(self.handshake_flow.request.headers) + + @property + def client_extensions(self): + return websockets.get_extensions(self.handshake_flow.request.headers) + + @property + def server_accept(self): + return websockets.get_server_accept(self.handshake_flow.response.headers) + + @property + def server_protocol(self): + return websockets.get_protocol(self.handshake_flow.response.headers) + + @property + def server_extensions(self): + return websockets.get_extensions(self.handshake_flow.response.headers) + + def message_info(self, message: WebSocketMessage) -> str: + return "{client} {direction} WebSocket {type} message {direction} {server}{endpoint}".format( + type=message.type, + client=repr(self.client_conn.address), + server=repr(self.server_conn.address), + direction="->" if message.from_client else "<-", + endpoint=self.handshake_flow.request.path, + ) diff --git a/test/mitmproxy/addons/test_stickycookie.py b/test/mitmproxy/addons/test_stickycookie.py index 157f2959..9092e09b 100644 --- a/test/mitmproxy/addons/test_stickycookie.py +++ b/test/mitmproxy/addons/test_stickycookie.py @@ -39,7 +39,6 @@ class TestStickyCookie: assert "cookie" not in f.request.headers f = f.copy() - f.reply.acked = False sc.request(f) assert f.request.headers["cookie"] == "foo=bar" diff --git a/test/mitmproxy/protocol/test_websocket.py b/test/mitmproxy/protocol/test_websocket.py index e42250e0..73ee8b35 100644 --- a/test/mitmproxy/protocol/test_websocket.py +++ b/test/mitmproxy/protocol/test_websocket.py @@ -179,16 +179,15 @@ class TestSimple(_WebSocketTest): assert isinstance(self.master.state.flows[1], WebSocketFlow) assert len(self.master.state.flows[1].messages) == 5 assert self.master.state.flows[1].messages[0].content == b'server-foobar' - assert self.master.state.flows[1].messages[0].type == 'text' + assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT assert self.master.state.flows[1].messages[1].content == b'client-foobar' - assert self.master.state.flows[1].messages[1].type == 'text' + assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT assert self.master.state.flows[1].messages[2].content == b'client-foobar' - assert self.master.state.flows[1].messages[2].type == 'text' + assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef' - assert self.master.state.flows[1].messages[3].type == 'binary' + assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef' - assert self.master.state.flows[1].messages[4].type == 'binary' - assert [m.info for m in self.master.state.flows[1].messages] + assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY class TestSimpleTLS(_WebSocketTest): diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index 371474ff..65e6845f 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -10,6 +10,7 @@ from mitmproxy.exceptions import FlowReadException, Kill from mitmproxy import flow from mitmproxy import http from mitmproxy import connections +from mitmproxy import tcp from mitmproxy.proxy import ProxyConfig from mitmproxy.proxy.server import DummyServer from mitmproxy import master @@ -156,8 +157,99 @@ class TestHTTPFlow: assert f.response.raw_content == b"abarb" +class TestWebSocketFlow: + + def test_copy(self): + f = tflow.twebsocketflow() + f.get_state() + f2 = f.copy() + a = f.get_state() + b = f2.get_state() + del a["id"] + del b["id"] + del a["handshake_flow"]["id"] + del b["handshake_flow"]["id"] + assert a == b + assert not f == f2 + assert f is not f2 + + assert f.client_key == f2.client_key + assert f.client_protocol == f2.client_protocol + assert f.client_extensions == f2.client_extensions + assert f.server_accept == f2.server_accept + assert f.server_protocol == f2.server_protocol + assert f.server_extensions == f2.server_extensions + assert f.messages is not f2.messages + assert f.handshake_flow is not f2.handshake_flow + + for m in f.messages: + m2 = m.copy() + m2.set_state(m2.get_state()) + assert m is not m2 + assert m.get_state() == m2.get_state() + + f = tflow.twebsocketflow(err=True) + f2 = f.copy() + assert f is not f2 + assert f.handshake_flow is not f2.handshake_flow + assert f.error.get_state() == f2.error.get_state() + assert f.error is not f2.error + + def test_match(self): + f = tflow.twebsocketflow() + assert not flowfilter.match("~b nonexistent", f) + assert flowfilter.match(None, f) + assert not flowfilter.match("~b nonexistent", f) + + f = tflow.twebsocketflow(err=True) + assert flowfilter.match("~e", f) + + with pytest.raises(ValueError): + flowfilter.match("~", f) + + def test_repr(self): + f = tflow.twebsocketflow() + assert 'WebSocketFlow' in repr(f) + assert 'binary message: ' in repr(f.messages[0]) + assert 'text message: ' in repr(f.messages[1]) + + class TestTCPFlow: + def test_copy(self): + f = tflow.ttcpflow() + f.get_state() + f2 = f.copy() + a = f.get_state() + b = f2.get_state() + del a["id"] + del b["id"] + assert a == b + assert not f == f2 + assert f is not f2 + + assert f.messages is not f2.messages + + for m in f.messages: + assert m.get_state() + m2 = m.copy() + assert not m == m2 + assert m is not m2 + + a = m.get_state() + b = m2.get_state() + assert a == b + + m = tcp.TCPMessage(False, 'foo') + m.set_state(f.messages[0].get_state()) + assert m.timestamp == f.messages[0].timestamp + + f = tflow.ttcpflow(err=True) + f2 = f.copy() + assert f is not f2 + assert f.error.get_state() == f2.error.get_state() + assert f.error is not f2.error + def test_match(self): f = tflow.ttcpflow() assert not flowfilter.match("~b nonexistent", f) @@ -170,6 +262,11 @@ class TestTCPFlow: with pytest.raises(ValueError): flowfilter.match("~", f) + def test_repr(self): + f = tflow.ttcpflow() + assert 'TCPFlow' in repr(f) + assert '-> ' in repr(f.messages[0]) + class TestSerialize: @@ -27,6 +27,10 @@ commands = --full-cov=mitmproxy/io.py \ --full-cov=mitmproxy/log.py \ --full-cov=mitmproxy/options.py \ + --full-cov=mitmproxy/stateobject.py \ + --full-cov=mitmproxy/version.py \ + --full-cov=mitmproxy/tcp.py \ + --full-cov=mitmproxy/websocket.py \ --full-cov=pathod/ --no-full-cov=pathod/pathoc.py --no-full-cov=pathod/pathod.py --no-full-cov=pathod/test.py --no-full-cov=pathod/protocols/http2.py \ {posargs} {env:CI_COMMANDS:python -c ""} |