diff options
-rw-r--r-- | docs/scripting/api.rst | 12 | ||||
-rw-r--r-- | docs/scripting/events.rst | 4 | ||||
-rw-r--r-- | mitmproxy/addonmanager.py | 2 | ||||
-rw-r--r-- | mitmproxy/command.py | 16 | ||||
-rw-r--r-- | mitmproxy/controller.py | 8 | ||||
-rw-r--r-- | mitmproxy/flow.py | 19 | ||||
-rw-r--r-- | mitmproxy/proxy/protocol/websocket.py | 9 | ||||
-rw-r--r-- | mitmproxy/utils/human.py | 2 | ||||
-rw-r--r-- | mitmproxy/websocket.py | 37 | ||||
-rw-r--r-- | test/mitmproxy/proxy/protocol/test_websocket.py | 19 | ||||
-rw-r--r-- | test/mitmproxy/test_websocket.py | 21 | ||||
-rw-r--r-- | test/mitmproxy/tools/console/test_defaultkeys.py | 23 | ||||
-rw-r--r-- | test/mitmproxy/utils/test_human.py | 2 |
13 files changed, 137 insertions, 37 deletions
diff --git a/docs/scripting/api.rst b/docs/scripting/api.rst index e82afef4..368b9ba8 100644 --- a/docs/scripting/api.rst +++ b/docs/scripting/api.rst @@ -10,6 +10,9 @@ API - `mitmproxy.http.HTTPRequest <#mitmproxy.http.HTTPRequest>`_ - `mitmproxy.http.HTTPResponse <#mitmproxy.http.HTTPResponse>`_ - `mitmproxy.http.HTTPFlow <#mitmproxy.http.HTTPFlow>`_ +- WebSocket + - `mitmproxy.websocket.WebSocketFlow <#mitmproxy.websocket.WebSocketFlow>`_ + - `mitmproxy.websocket.WebSocketMessage <#mitmproxy.websocket.WebSocketMessage>`_ - Logging - `mitmproxy.log.Log <#mitmproxy.controller.Log>`_ - `mitmproxy.log.LogEntry <#mitmproxy.controller.LogEntry>`_ @@ -33,6 +36,15 @@ HTTP .. autoclass:: mitmproxy.http.HTTPFlow :inherited-members: +WebSocket +--------- + +.. autoclass:: mitmproxy.websocket.WebSocketFlow + :inherited-members: + +.. autoclass:: mitmproxy.websocket.WebSocketMessage + :inherited-members: + Logging -------- diff --git a/docs/scripting/events.rst b/docs/scripting/events.rst index 9e84dacf..4d74b220 100644 --- a/docs/scripting/events.rst +++ b/docs/scripting/events.rst @@ -187,8 +187,8 @@ are issued, only new WebSocket messages are called. - Called when a WebSocket message is received from the client or server. The sender and receiver are identifiable. The most recent message will be - ``flow.messages[-1]``. The message is user-modifiable. Currently there are - two types of messages, corresponding to the BINARY and TEXT frame types. + ``flow.messages[-1]``. The message is user-modifiable and is killable. + A message is either of TEXT or BINARY type. *flow* A ``models.WebSocketFlow`` object. diff --git a/mitmproxy/addonmanager.py b/mitmproxy/addonmanager.py index 70cfda30..37c501ee 100644 --- a/mitmproxy/addonmanager.py +++ b/mitmproxy/addonmanager.py @@ -230,7 +230,7 @@ class AddonManager: self.trigger(name, message) - if message.reply.state != "taken": + if message.reply.state == "start": message.reply.take() if not message.reply.has_message: message.reply.ack() diff --git a/mitmproxy/command.py b/mitmproxy/command.py index e1e56d3a..7bb2bf8e 100644 --- a/mitmproxy/command.py +++ b/mitmproxy/command.py @@ -76,11 +76,7 @@ class Command: ret = " -> " + ret return "%s %s%s" % (self.path, params, ret) - def call(self, args: typing.Sequence[str]) -> typing.Any: - """ - Call the command with a list of arguments. At this point, all - arguments are strings. - """ + def prepare_args(self, args: typing.Sequence[str]) -> typing.List[typing.Any]: verify_arg_signature(self.func, list(args), {}) remainder = [] # type: typing.Sequence[str] @@ -92,6 +88,14 @@ class Command: for arg, paramtype in zip(args, self.paramtypes): pargs.append(parsearg(self.manager, arg, paramtype)) pargs.extend(remainder) + return pargs + + def call(self, args: typing.Sequence[str]) -> typing.Any: + """ + Call the command with a list of arguments. At this point, all + arguments are strings. + """ + pargs = self.prepare_args(args) with self.manager.master.handlecontext(): ret = self.func(*pargs) @@ -121,7 +125,7 @@ ParseResult = typing.NamedTuple( class CommandManager(mitmproxy.types._CommandBase): def __init__(self, master): self.master = master - self.commands = {} + self.commands = {} # type: typing.Dict[str, Command] def collect_commands(self, addon): for i in dir(addon): diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py index 63117ef0..f39c1b24 100644 --- a/mitmproxy/controller.py +++ b/mitmproxy/controller.py @@ -105,16 +105,16 @@ class Reply: self.q.put(self.value) def ack(self, force=False): - if self.state not in {"start", "taken"}: - raise exceptions.ControlException( - "Reply is {}, but expected it to be start or taken.".format(self.state) - ) self.send(self.obj, force) def kill(self, force=False): self.send(exceptions.Kill, force) def send(self, msg, force=False): + if self.state not in {"start", "taken"}: + raise exceptions.ControlException( + "Reply is {}, but expected it to be start or taken.".format(self.state) + ) if self.has_message and not force: raise exceptions.ControlException("There is already a reply message.") self.value = msg diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py index 111566b8..944c032d 100644 --- a/mitmproxy/flow.py +++ b/mitmproxy/flow.py @@ -1,13 +1,12 @@ import time +import typing # noqa import uuid -from mitmproxy import controller # noqa -from mitmproxy import stateobject from mitmproxy import connections +from mitmproxy import controller, exceptions # noqa +from mitmproxy import stateobject from mitmproxy import version -import typing # noqa - class Error(stateobject.StateObject): @@ -145,7 +144,11 @@ class Flow(stateobject.StateObject): @property def killable(self): - return self.reply and self.reply.state == "taken" + return ( + self.reply and + self.reply.state in {"start", "taken"} and + self.reply.value != exceptions.Kill + ) def kill(self): """ @@ -153,13 +156,7 @@ class Flow(stateobject.StateObject): """ self.error = Error("Connection killed") self.intercepted = False - - # reply.state should be "taken" here, or .take() will raise an - # exception. - if self.reply.state != "taken": - self.reply.take() self.reply.kill(force=True) - self.reply.commit() self.live = False def intercept(self): diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py index 1bd5284d..92f99518 100644 --- a/mitmproxy/proxy/protocol/websocket.py +++ b/mitmproxy/proxy/protocol/websocket.py @@ -109,7 +109,7 @@ class WebSocketLayer(base.Layer): self.flow.messages.append(websocket_message) self.channel.ask("websocket_message", self.flow) - if not self.flow.stream: + if not self.flow.stream and not websocket_message.killed: def get_chunk(payload): if len(payload) == length: # message has the same length, we can reuse the same sizes @@ -129,14 +129,9 @@ class WebSocketLayer(base.Layer): self.connections[other_conn].send_data(chunk, final) other_conn.send(self.connections[other_conn].bytes_to_send()) - else: - self.connections[other_conn].send_data(event.data, event.message_finished) - other_conn.send(self.connections[other_conn].bytes_to_send()) - - elif self.flow.stream: + if self.flow.stream: self.connections[other_conn].send_data(event.data, event.message_finished) other_conn.send(self.connections[other_conn].bytes_to_send()) - return True def _handle_ping_received(self, event, source_conn, other_conn, is_server): diff --git a/mitmproxy/utils/human.py b/mitmproxy/utils/human.py index e2e3142a..b21ac0b8 100644 --- a/mitmproxy/utils/human.py +++ b/mitmproxy/utils/human.py @@ -80,6 +80,8 @@ def format_address(address: tuple) -> str: """ try: host = ipaddress.ip_address(address[0]) + if host.is_unspecified: + return "*:{}".format(address[1]) if isinstance(host, ipaddress.IPv4Address): return "{}:{}".format(str(host), address[1]) # If IPv6 is mapped to IPv4 diff --git a/mitmproxy/websocket.py b/mitmproxy/websocket.py index 8efd4117..a37edb54 100644 --- a/mitmproxy/websocket.py +++ b/mitmproxy/websocket.py @@ -10,23 +10,33 @@ from mitmproxy.utils import strutils, human class WebSocketMessage(serializable.Serializable): + """ + A WebSocket message sent from one endpoint to the other. + """ + def __init__( - self, type: int, from_client: bool, content: bytes, timestamp: Optional[int]=None + self, type: int, from_client: bool, content: bytes, timestamp: Optional[int]=None, killed: bool=False ) -> None: self.type = wsproto.frame_protocol.Opcode(type) # type: ignore + """indicates either TEXT or BINARY (from wsproto.frame_protocol.Opcode).""" self.from_client = from_client + """True if this messages was sent by the client.""" self.content = content + """A byte-string representing the content of this message.""" self.timestamp = timestamp or int(time.time()) # type: int + """Timestamp of when this message was received or created.""" + self.killed = killed + """True if this messages was killed and should not be sent to the other endpoint.""" @classmethod def from_state(cls, state): return cls(*state) def get_state(self): - return int(self.type), self.from_client, self.content, self.timestamp + return int(self.type), self.from_client, self.content, self.timestamp, self.killed def set_state(self, state): - self.type, self.from_client, self.content, self.timestamp = state + self.type, self.from_client, self.content, self.timestamp, self.killed = state self.type = wsproto.frame_protocol.Opcode(self.type) # replace enum with bare int def __repr__(self): @@ -35,20 +45,37 @@ class WebSocketMessage(serializable.Serializable): else: return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content)) + def kill(self): + """ + Kill this message. + + It will not be sent to the other endpoint. This has no effect in streaming mode. + """ + self.killed = True + class WebSocketFlow(flow.Flow): """ - A WebsocketFlow is a simplified representation of a Websocket session. + A WebsocketFlow is a simplified representation of a Websocket connection. """ def __init__(self, client_conn, server_conn, handshake_flow, live=None): super().__init__("websocket", client_conn, server_conn, live) + self.messages = [] # type: List[WebSocketMessage] + """A list containing all WebSocketMessage's.""" self.close_sender = 'client' + """'client' if the client initiated connection closing.""" self.close_code = wsproto.frame_protocol.CloseReason.NORMAL_CLOSURE + """WebSocket close code.""" self.close_message = '(message missing)' + """WebSocket close message.""" self.close_reason = 'unknown status code' + """WebSocket close reason.""" self.stream = False + """True of this connection is streaming directly to the other endpoint.""" + self.handshake_flow = handshake_flow + """The HTTP flow containing the initial WebSocket handshake.""" if handshake_flow: self.client_key = websockets.get_client_key(handshake_flow.request.headers) @@ -65,8 +92,6 @@ class WebSocketFlow(flow.Flow): self.server_protocol = '' self.server_extensions = '' - self.handshake_flow = handshake_flow - _stateobject_attributes = flow.Flow._stateobject_attributes.copy() # mypy doesn't support update with kwargs _stateobject_attributes.update(dict( diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py index a7acdc4d..d9389faf 100644 --- a/test/mitmproxy/proxy/protocol/test_websocket.py +++ b/test/mitmproxy/proxy/protocol/test_websocket.py @@ -221,6 +221,25 @@ class TestSimple(_WebSocketTest): assert frame.payload == b'foo' +class TestKillFlow(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar'))) + wfile.flush() + + def test_kill(self): + class KillFlow: + def websocket_message(self, f): + f.kill() + + self.master.addons.add(KillFlow()) + self.setup_connection() + + with pytest.raises(exceptions.TcpDisconnect): + websockets.Frame.from_file(self.client.rfile) + + class TestSimpleTLS(_WebSocketTest): ssl = True diff --git a/test/mitmproxy/test_websocket.py b/test/mitmproxy/test_websocket.py index 7c53a4b0..fcacec36 100644 --- a/test/mitmproxy/test_websocket.py +++ b/test/mitmproxy/test_websocket.py @@ -3,6 +3,7 @@ import pytest from mitmproxy.io import tnetstring from mitmproxy import flowfilter +from mitmproxy.exceptions import Kill, ControlException from mitmproxy.test import tflow @@ -42,6 +43,20 @@ class TestWebSocketFlow: assert f.error.get_state() == f2.error.get_state() assert f.error is not f2.error + def test_kill(self): + f = tflow.twebsocketflow() + with pytest.raises(ControlException): + f.intercept() + f.resume() + f.kill() + + f = tflow.twebsocketflow() + f.intercept() + assert f.killable + f.kill() + assert not f.killable + assert f.reply.value == Kill + def test_match(self): f = tflow.twebsocketflow() assert not flowfilter.match("~b nonexistent", f) @@ -71,3 +86,9 @@ class TestWebSocketFlow: d = tflow.twebsocketflow().handshake_flow.get_state() tnetstring.dump(d, b) assert b.getvalue() + + def test_message_kill(self): + f = tflow.twebsocketflow() + assert not f.messages[-1].killed + f.messages[-1].kill() + assert f.messages[-1].killed diff --git a/test/mitmproxy/tools/console/test_defaultkeys.py b/test/mitmproxy/tools/console/test_defaultkeys.py new file mode 100644 index 00000000..1f17c888 --- /dev/null +++ b/test/mitmproxy/tools/console/test_defaultkeys.py @@ -0,0 +1,23 @@ +from mitmproxy.test.tflow import tflow +from mitmproxy.tools.console import defaultkeys +from mitmproxy.tools.console import keymap +from mitmproxy.tools.console import master +from mitmproxy import command + + +def test_commands_exist(): + km = keymap.Keymap(None) + defaultkeys.map(km) + assert km.bindings + m = master.ConsoleMaster(None) + m.load_flow(tflow()) + + for binding in km.bindings: + cmd, *args = command.lexer(binding.command) + assert cmd in m.commands.commands + + cmd_obj = m.commands.commands[cmd] + try: + cmd_obj.prepare_args(args) + except Exception as e: + raise ValueError("Invalid command: {}".format(binding.command)) from e diff --git a/test/mitmproxy/utils/test_human.py b/test/mitmproxy/utils/test_human.py index e8ffaad4..947cfa4a 100644 --- a/test/mitmproxy/utils/test_human.py +++ b/test/mitmproxy/utils/test_human.py @@ -54,3 +54,5 @@ def test_format_address(): assert human.format_address(("::ffff:127.0.0.1", "54010", "0", "0")) == "127.0.0.1:54010" assert human.format_address(("127.0.0.1", "54010")) == "127.0.0.1:54010" assert human.format_address(("example.com", "54010")) == "example.com:54010" + assert human.format_address(("::", "8080")) == "*:8080" + assert human.format_address(("0.0.0.0", "8080")) == "*:8080" |