From 1847cf175c8a45359eef08f5cf2bfc414b059dbe Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 8 Feb 2017 15:00:31 +0100 Subject: websockets, tcp, version: coverage++ --- mitmproxy/websocket.py | 100 +++++++++++++++++++++++++++---------------------- 1 file changed, 56 insertions(+), 44 deletions(-) (limited to 'mitmproxy/websocket.py') diff --git a/mitmproxy/websocket.py b/mitmproxy/websocket.py index 6e998a52..25a82878 100644 --- a/mitmproxy/websocket.py +++ b/mitmproxy/websocket.py @@ -1,63 +1,38 @@ import time - -from typing import List +from typing import List, Optional from mitmproxy import flow from mitmproxy.http import HTTPFlow from mitmproxy.net import websockets -from mitmproxy.utils import strutils from mitmproxy.types import serializable +from mitmproxy.utils import strutils class WebSocketMessage(serializable.Serializable): - - def __init__(self, flow, from_client, content, timestamp=None): - self.flow = flow - self.content = content + def __init__(self, type: int, from_client: bool, content: bytes, timestamp: Optional[int]=None): + self.type = type self.from_client = from_client - self.timestamp = timestamp or time.time() + self.content = content + self.timestamp = timestamp or time.time() # type: int @classmethod def from_state(cls, state): return cls(*state) def get_state(self): - return self.from_client, self.content, self.timestamp + return self.type, self.from_client, self.content, self.timestamp def set_state(self, state): - self.from_client = state.pop("from_client") - self.content = state.pop("content") - self.timestamp = state.pop("timestamp") - - @property - def info(self): - return "{client} {direction} WebSocket {type} message {direction} {server}{endpoint}".format( - type=self.type, - client=repr(self.flow.client_conn.address), - server=repr(self.flow.server_conn.address), - direction="->" if self.from_client else "<-", - endpoint=self.flow.handshake_flow.request.path, - ) - - -class WebSocketBinaryMessage(WebSocketMessage): - - type = 'binary' + self.type, self.from_client, self.content, self.timestamp = state def __repr__(self): - return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content)) - - -class WebSocketTextMessage(WebSocketMessage): - - type = 'text' - - def __repr__(self): - return "text message: {}".format(repr(self.content)) + if self.type == websockets.OPCODE.TEXT: + return "text message: {}".format(repr(self.content)) + else: + return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content)) class WebSocketFlow(flow.Flow): - """ A WebsocketFlow is a simplified representation of a Websocket session. """ @@ -70,18 +45,55 @@ class WebSocketFlow(flow.Flow): self.close_message = '(message missing)' self.close_reason = 'unknown status code' self.handshake_flow = handshake_flow - self.client_key = websockets.get_client_key(self.handshake_flow.request.headers) - self.client_protocol = websockets.get_protocol(self.handshake_flow.request.headers) - self.client_extensions = websockets.get_extensions(self.handshake_flow.request.headers) - self.server_accept = websockets.get_server_accept(self.handshake_flow.response.headers) - self.server_protocol = websockets.get_protocol(self.handshake_flow.response.headers) - self.server_extensions = websockets.get_extensions(self.handshake_flow.response.headers) _stateobject_attributes = flow.Flow._stateobject_attributes.copy() _stateobject_attributes.update( messages=List[WebSocketMessage], + close_sender=str, + close_code=str, + close_message=str, + close_reason=str, handshake_flow=HTTPFlow, ) + @classmethod + def from_state(cls, state): + f = cls(None, None, None) + f.set_state(state) + return f + def __repr__(self): - return "WebSocketFlow ({} messages)".format(len(self.messages)) + return "".format(len(self.messages)) + + @property + def client_key(self): + return websockets.get_client_key(self.handshake_flow.request.headers) + + @property + def client_protocol(self): + return websockets.get_protocol(self.handshake_flow.request.headers) + + @property + def client_extensions(self): + return websockets.get_extensions(self.handshake_flow.request.headers) + + @property + def server_accept(self): + return websockets.get_server_accept(self.handshake_flow.response.headers) + + @property + def server_protocol(self): + return websockets.get_protocol(self.handshake_flow.response.headers) + + @property + def server_extensions(self): + return websockets.get_extensions(self.handshake_flow.response.headers) + + def message_info(self, message: WebSocketMessage) -> str: + return "{client} {direction} WebSocket {type} message {direction} {server}{endpoint}".format( + type=message.type, + client=repr(self.client_conn.address), + server=repr(self.server_conn.address), + direction="->" if message.from_client else "<-", + endpoint=self.handshake_flow.request.path, + ) -- cgit v1.2.3