diff options
30 files changed, 275 insertions, 1649 deletions
diff --git a/.appveyor.yml b/.appveyor.yml index 160cdf73..3ef985be 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -34,9 +34,9 @@ install: test_script: - ps: "tox -- --verbose --cov-report=term" - ps: | - $Env:VERSION = $(python mitmproxy/version.py) + $Env:VERSION = $(python -m mitmproxy.version) $Env:SKIP_MITMPROXY = "python -c `"print('skip mitmproxy')`"" - tox -e wheel + tox -e rtool -- wheel tox -e rtool -- bdist - ps: | @@ -46,7 +46,7 @@ test_script: ) { echo "Decrypt license..." tox -e rtool -- decrypt release\installbuilder\license.xml.enc release\installbuilder\license.xml - $ibVersion = "17.9.0" + $ibVersion = "17.12.0" $ibSetup = "C:\projects\mitmproxy\release\installbuilder-installer.exe" $ibCli = "C:\Program Files (x86)\BitRock InstallBuilder Enterprise $ibVersion\bin\builder-cli.exe" if (!(Test-Path $ibSetup)) { diff --git a/docs/features/passthrough.rst b/docs/features/passthrough.rst index dbaf3506..91fcb9b6 100644 --- a/docs/features/passthrough.rst +++ b/docs/features/passthrough.rst @@ -38,7 +38,7 @@ There are two important quirks to consider: - **In transparent mode, the ignore pattern is matched against the IP and ClientHello SNI host.** While we usually infer the hostname from the Host header if the ``--host`` argument is passed to mitmproxy, we do not have access to this information before the SSL handshake. If the client uses SNI however, then we treat the SNI host as an ignore target. -- **In regular mode, explicit HTTP requests are never ignored.** [#explicithttp]_ The ignore pattern is +- **In regular and upstream proxy mode, explicit HTTP requests are never ignored.** [#explicithttp]_ The ignore pattern is applied on CONNECT requests, which initiate HTTPS or clear-text WebSocket connections. Tutorial diff --git a/mitmproxy/__init__.py b/mitmproxy/__init__.py index 9697de87..e69de29b 100644 --- a/mitmproxy/__init__.py +++ b/mitmproxy/__init__.py @@ -1,3 +0,0 @@ -# https://github.com/mitmproxy/mitmproxy/issues/1809 -# import script here so that pyinstaller registers it. -from . import script # noqa diff --git a/mitmproxy/addons/clientplayback.py b/mitmproxy/addons/clientplayback.py index bed06e82..3fd96669 100644 --- a/mitmproxy/addons/clientplayback.py +++ b/mitmproxy/addons/clientplayback.py @@ -27,6 +27,7 @@ class ClientPlayback: Stop client replay. """ self.flows = [] + ctx.log.alert("Client replay stopped.") ctx.master.addons.trigger("update", []) @command.command("replay.client") @@ -35,6 +36,7 @@ class ClientPlayback: Replay requests from flows. """ self.flows = list(flows) + ctx.log.alert("Replaying %s flows." % len(self.flows)) ctx.master.addons.trigger("update", []) @command.command("replay.client.file") @@ -43,7 +45,9 @@ class ClientPlayback: flows = io.read_flows_from_paths([path]) except exceptions.FlowReadException as e: raise exceptions.CommandError(str(e)) + ctx.log.alert("Replaying %s flows." % len(self.flows)) self.flows = flows + ctx.master.addons.trigger("update", []) def configure(self, updated): if not self.configured and ctx.options.client_replay: diff --git a/mitmproxy/command.py b/mitmproxy/command.py index e1e56d3a..7bb2bf8e 100644 --- a/mitmproxy/command.py +++ b/mitmproxy/command.py @@ -76,11 +76,7 @@ class Command: ret = " -> " + ret return "%s %s%s" % (self.path, params, ret) - def call(self, args: typing.Sequence[str]) -> typing.Any: - """ - Call the command with a list of arguments. At this point, all - arguments are strings. - """ + def prepare_args(self, args: typing.Sequence[str]) -> typing.List[typing.Any]: verify_arg_signature(self.func, list(args), {}) remainder = [] # type: typing.Sequence[str] @@ -92,6 +88,14 @@ class Command: for arg, paramtype in zip(args, self.paramtypes): pargs.append(parsearg(self.manager, arg, paramtype)) pargs.extend(remainder) + return pargs + + def call(self, args: typing.Sequence[str]) -> typing.Any: + """ + Call the command with a list of arguments. At this point, all + arguments are strings. + """ + pargs = self.prepare_args(args) with self.manager.master.handlecontext(): ret = self.func(*pargs) @@ -121,7 +125,7 @@ ParseResult = typing.NamedTuple( class CommandManager(mitmproxy.types._CommandBase): def __init__(self, master): self.master = master - self.commands = {} + self.commands = {} # type: typing.Dict[str, Command] def collect_commands(self, addon): for i in dir(addon): diff --git a/mitmproxy/contrib/wsproto/__init__.py b/mitmproxy/contrib/wsproto/__init__.py deleted file mode 100644 index d0592bc5..00000000 --- a/mitmproxy/contrib/wsproto/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from . import compat -from . import connection -from . import events -from . import extensions -from . import frame_protocol - -__all__ = [ - 'compat', - 'connection', - 'events', - 'extensions', - 'frame_protocol', -] diff --git a/mitmproxy/contrib/wsproto/compat.py b/mitmproxy/contrib/wsproto/compat.py deleted file mode 100644 index 1911f83c..00000000 --- a/mitmproxy/contrib/wsproto/compat.py +++ /dev/null @@ -1,20 +0,0 @@ -# flake8: noqa - -import sys - - -PY2 = sys.version_info.major == 2 -PY3 = sys.version_info.major == 3 - - -if PY3: - unicode = str - - def Utf8Validator(): - return None -else: - unicode = unicode - try: - from wsaccel.utf8validator import Utf8Validator - except ImportError: - from .utf8validator import Utf8Validator diff --git a/mitmproxy/contrib/wsproto/connection.py b/mitmproxy/contrib/wsproto/connection.py deleted file mode 100644 index f994cd3a..00000000 --- a/mitmproxy/contrib/wsproto/connection.py +++ /dev/null @@ -1,477 +0,0 @@ -# -*- coding: utf-8 -*- -""" -wsproto/connection -~~~~~~~~~~~~~~ - -An implementation of a WebSocket connection. -""" - -import os -import base64 -import hashlib -from collections import deque - -from enum import Enum - -import h11 - -from .events import ( - ConnectionRequested, ConnectionEstablished, ConnectionClosed, - ConnectionFailed, TextReceived, BytesReceived, PingReceived, PongReceived -) -from .frame_protocol import FrameProtocol, ParseFailed, CloseReason, Opcode - - -# RFC6455, Section 1.3 - Opening Handshake -ACCEPT_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - - -class ConnectionState(Enum): - """ - RFC 6455, Section 4 - Opening Handshake - """ - CONNECTING = 0 - OPEN = 1 - CLOSING = 2 - CLOSED = 3 - - -class ConnectionType(Enum): - CLIENT = 1 - SERVER = 2 - - -CLIENT = ConnectionType.CLIENT -SERVER = ConnectionType.SERVER - - -# Some convenience utilities for working with HTTP headers -def _normed_header_dict(h11_headers): - # This mangles Set-Cookie headers. But it happens that we don't care about - # any of those, so it's OK. For every other HTTP header, if there are - # multiple instances then you're allowed to join them together with - # commas. - name_to_values = {} - for name, value in h11_headers: - name_to_values.setdefault(name, []).append(value) - name_to_normed_value = {} - for name, values in name_to_values.items(): - name_to_normed_value[name] = b", ".join(values) - return name_to_normed_value - - -# We use this for parsing the proposed protocol list, and for parsing the -# proposed and accepted extension lists. For the proposed protocol list it's -# fine, because the ABNF is just 1#token. But for the extension lists, it's -# wrong, because those can contain quoted strings, which can in turn contain -# commas. XX FIXME -def _split_comma_header(value): - return [piece.decode('ascii').strip() for piece in value.split(b',')] - - -class WSConnection(object): - """ - A low-level WebSocket connection object. - - This wraps two other protocol objects, an HTTP/1.1 protocol object used - to do the initial HTTP upgrade handshake and a WebSocket frame protocol - object used to exchange messages and other control frames. - - :param conn_type: Whether this object is on the client- or server-side of - a connection. To initialise as a client pass ``CLIENT`` otherwise - pass ``SERVER``. - :type conn_type: ``ConnectionType`` - - :param host: The hostname to pass to the server when acting as a client. - :type host: ``str`` - - :param resource: The resource (aka path) to pass to the server when acting - as a client. - :type resource: ``str`` - - :param extensions: A list of extensions to use on this connection. - Extensions should be instances of a subclass of - :class:`Extension <wsproto.extensions.Extension>`. - - :param subprotocols: A list of subprotocols to request when acting as a - client, ordered by preference. This has no impact on the connection - itself. - :type subprotocol: ``list`` of ``str`` - """ - - def __init__(self, conn_type, host=None, resource=None, extensions=None, - subprotocols=None): - self.client = conn_type is ConnectionType.CLIENT - - self.host = host - self.resource = resource - - self.subprotocols = subprotocols or [] - self.extensions = extensions or [] - - self.version = b'13' - - self._state = ConnectionState.CONNECTING - self._close_reason = None - - self._nonce = None - self._outgoing = b'' - self._events = deque() - self._proto = None - - if self.client: - self._upgrade_connection = h11.Connection(h11.CLIENT) - else: - self._upgrade_connection = h11.Connection(h11.SERVER) - - if self.client: - if self.host is None: - raise ValueError( - "Host must not be None for a client-side connection.") - if self.resource is None: - raise ValueError( - "Resource must not be None for a client-side connection.") - self.initiate_connection() - - def initiate_connection(self): - self._generate_nonce() - - headers = { - b"Host": self.host.encode('ascii'), - b"Upgrade": b'WebSocket', - b"Connection": b'Upgrade', - b"Sec-WebSocket-Key": self._nonce, - b"Sec-WebSocket-Version": self.version, - } - - if self.subprotocols: - headers[b"Sec-WebSocket-Protocol"] = ", ".join(self.subprotocols) - - if self.extensions: - offers = {e.name: e.offer(self) for e in self.extensions} - extensions = [] - for name, params in offers.items(): - if params is True: - extensions.append(name.encode('ascii')) - elif params: - # py34 annoyance: doesn't support bytestring formatting - extensions.append(('%s; %s' % (name, params)) - .encode("ascii")) - if extensions: - headers[b'Sec-WebSocket-Extensions'] = b', '.join(extensions) - - upgrade = h11.Request(method=b'GET', target=self.resource, - headers=headers.items()) - self._outgoing += self._upgrade_connection.send(upgrade) - - def send_data(self, payload, final=True): - """ - Send a message or part of a message to the remote peer. - - If ``final`` is ``False`` it indicates that this is part of a longer - message. If ``final`` is ``True`` it indicates that this is either a - self-contained message or the last part of a longer message. - - If ``payload`` is of type ``bytes`` then the message is flagged as - being binary If it is of type ``str`` encoded as UTF-8 and sent as - text. - - :param payload: The message body to send. - :type payload: ``bytes`` or ``str`` - - :param final: Whether there are more parts to this message to be sent. - :type final: ``bool`` - """ - - self._outgoing += self._proto.send_data(payload, final) - - def close(self, code=CloseReason.NORMAL_CLOSURE, reason=None): - self._outgoing += self._proto.close(code, reason) - self._state = ConnectionState.CLOSING - - @property - def closed(self): - return self._state is ConnectionState.CLOSED - - def bytes_to_send(self, amount=None): - """ - Return any data that is to be sent to the remote peer. - - :param amount: (optional) The maximum number of bytes to be provided. - If ``None`` or not provided it will return all available bytes. - :type amount: ``int`` - """ - - if amount is None: - data = self._outgoing - self._outgoing = b'' - else: - data = self._outgoing[:amount] - self._outgoing = self._outgoing[amount:] - - return data - - def receive_bytes(self, data): - """ - Pass some received bytes to the connection for processing. - - :param data: The data received from the remote peer. - :type data: ``bytes`` - """ - - if data is None and self._state is ConnectionState.OPEN: - # "If _The WebSocket Connection is Closed_ and no Close control - # frame was received by the endpoint (such as could occur if the - # underlying transport connection is lost), _The WebSocket - # Connection Close Code_ is considered to be 1006." - self._events.append(ConnectionClosed(CloseReason.ABNORMAL_CLOSURE)) - self._state = ConnectionState.CLOSED - return - elif data is None: - self._state = ConnectionState.CLOSED - return - - if self._state is ConnectionState.CONNECTING: - event, data = self._process_upgrade(data) - if event is not None: - self._events.append(event) - - if self._state is ConnectionState.OPEN: - self._proto.receive_bytes(data) - - def _process_upgrade(self, data): - self._upgrade_connection.receive_data(data) - while True: - try: - event = self._upgrade_connection.next_event() - except h11.RemoteProtocolError: - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Bad HTTP message"), b'' - if event is h11.NEED_DATA: - break - elif self.client and isinstance(event, (h11.InformationalResponse, - h11.Response)): - data = self._upgrade_connection.trailing_data[0] - return self._establish_client_connection(event), data - elif not self.client and isinstance(event, h11.Request): - return self._process_connection_request(event), None - else: - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Bad HTTP message"), b'' - - self._incoming = b'' - return None, None - - def events(self): - """ - Return a generator that provides any events that have been generated - by protocol activity. - - :returns: generator - """ - - while self._events: - yield self._events.popleft() - - if self._proto is None: - return - - try: - for frame in self._proto.received_frames(): - if frame.opcode is Opcode.PING: - assert frame.frame_finished and frame.message_finished - self._outgoing += self._proto.pong(frame.payload) - yield PingReceived(frame.payload) - - elif frame.opcode is Opcode.PONG: - assert frame.frame_finished and frame.message_finished - yield PongReceived(frame.payload) - - elif frame.opcode is Opcode.CLOSE: - code, reason = frame.payload - self.close(code, reason) - yield ConnectionClosed(code, reason) - - elif frame.opcode is Opcode.TEXT: - yield TextReceived(frame.payload, - frame.frame_finished, - frame.message_finished) - - elif frame.opcode is Opcode.BINARY: - yield BytesReceived(frame.payload, - frame.frame_finished, - frame.message_finished) - except ParseFailed as exc: - # XX FIXME: apparently autobahn intentionally deviates from the - # spec in that on protocol errors it just closes the connection - # rather than trying to send a CLOSE frame. Investigate whether we - # should do the same. - self.close(code=exc.code, reason=str(exc)) - yield ConnectionClosed(exc.code, reason=str(exc)) - - def _generate_nonce(self): - # os.urandom may be overkill for this use case, but I don't think this - # is a bottleneck, and better safe than sorry... - self._nonce = base64.b64encode(os.urandom(16)) - - def _generate_accept_token(self, token): - accept_token = token + ACCEPT_GUID - accept_token = hashlib.sha1(accept_token).digest() - return base64.b64encode(accept_token) - - def _establish_client_connection(self, event): - if event.status_code != 101: - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Bad status code from server") - headers = _normed_header_dict(event.headers) - if headers[b'connection'].lower() != b'upgrade': - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Missing Connection: Upgrade header") - if headers[b'upgrade'].lower() != b'websocket': - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Missing Upgrade: WebSocket header") - - accept_token = self._generate_accept_token(self._nonce) - if headers[b'sec-websocket-accept'] != accept_token: - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Bad accept token") - - subprotocol = headers.get(b'sec-websocket-protocol', None) - if subprotocol is not None: - subprotocol = subprotocol.decode('ascii') - if subprotocol not in self.subprotocols: - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "unrecognized subprotocol {!r}" - .format(subprotocol)) - - extensions = headers.get(b'sec-websocket-extensions', None) - if extensions: - accepts = _split_comma_header(extensions) - - for accept in accepts: - name = accept.split(';', 1)[0].strip() - for extension in self.extensions: - if extension.name == name: - extension.finalize(self, accept) - break - else: - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "unrecognized extension {!r}" - .format(name)) - - self._proto = FrameProtocol(self.client, self.extensions) - self._state = ConnectionState.OPEN - return ConnectionEstablished(subprotocol, extensions) - - def _process_connection_request(self, event): - if event.method != b'GET': - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Request method must be GET") - headers = _normed_header_dict(event.headers) - if headers[b'connection'].lower() != b'upgrade': - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Missing Connection: Upgrade header") - if headers[b'upgrade'].lower() != b'websocket': - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Missing Upgrade: WebSocket header") - - if b'sec-websocket-version' not in headers: - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Missing Sec-WebSocket-Version header") - # XX FIXME: need to check Sec-Websocket-Version, and respond with a - # 400 if it's not what we expect - - if b'sec-websocket-protocol' in headers: - proposed_subprotocols = _split_comma_header( - headers[b'sec-websocket-protocol']) - else: - proposed_subprotocols = [] - - if b'sec-websocket-key' not in headers: - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Missing Sec-WebSocket-Key header") - - return ConnectionRequested(proposed_subprotocols, event) - - def _extension_accept(self, extensions_header): - accepts = {} - offers = _split_comma_header(extensions_header) - - for offer in offers: - name = offer.split(';', 1)[0].strip() - for extension in self.extensions: - if extension.name == name: - accept = extension.accept(self, offer) - if accept is True: - accepts[extension.name] = True - elif accept is not False and accept is not None: - accepts[extension.name] = accept.encode('ascii') - - if accepts: - extensions = [] - for name, params in accepts.items(): - if params is True: - extensions.append(name.encode('ascii')) - else: - # py34 annoyance: doesn't support bytestring formatting - params = params.decode("ascii") - extensions.append(('%s; %s' % (name, params)) - .encode("ascii")) - return b', '.join(extensions) - - return None - - def accept(self, event, subprotocol=None): - request = event.h11request - request_headers = _normed_header_dict(request.headers) - - nonce = request_headers[b'sec-websocket-key'] - accept_token = self._generate_accept_token(nonce) - - headers = { - b"Upgrade": b'WebSocket', - b"Connection": b'Upgrade', - b"Sec-WebSocket-Accept": accept_token, - } - - if subprotocol is not None: - if subprotocol not in event.proposed_subprotocols: - raise ValueError( - "unexpected subprotocol {!r}".format(subprotocol)) - headers[b'Sec-WebSocket-Protocol'] = subprotocol - - extensions = request_headers.get(b'sec-websocket-extensions', None) - if extensions: - accepts = self._extension_accept(extensions) - if accepts: - headers[b"Sec-WebSocket-Extensions"] = accepts - - response = h11.InformationalResponse(status_code=101, - headers=headers.items()) - self._outgoing += self._upgrade_connection.send(response) - self._proto = FrameProtocol(self.client, self.extensions) - self._state = ConnectionState.OPEN - - def ping(self, payload=None): - """ - Send a PING message to the peer. - - :param payload: an optional payload to send with the message - """ - - payload = bytes(payload or b'') - self._outgoing += self._proto.ping(payload) - - def pong(self, payload=None): - """ - Send a PONG message to the peer. - - This method can be used to send an unsolicted PONG to the peer. - It is not needed otherwise since every received PING causes a - corresponding PONG to be sent automatically. - - :param payload: an optional payload to send with the message - """ - - payload = bytes(payload or b'') - self._outgoing += self._proto.pong(payload) diff --git a/mitmproxy/contrib/wsproto/events.py b/mitmproxy/contrib/wsproto/events.py deleted file mode 100644 index 73ce27aa..00000000 --- a/mitmproxy/contrib/wsproto/events.py +++ /dev/null @@ -1,81 +0,0 @@ -# -*- coding: utf-8 -*- -""" -wsproto/events -~~~~~~~~~~ - -Events that result from processing data on a WebSocket connection. -""" - - -class ConnectionRequested(object): - def __init__(self, proposed_subprotocols, h11request): - self.proposed_subprotocols = proposed_subprotocols - self.h11request = h11request - - def __repr__(self): - path = self.h11request.target - - headers = dict(self.h11request.headers) - host = headers[b'host'] - version = headers[b'sec-websocket-version'] - subprotocol = headers.get(b'sec-websocket-protocol', None) - extensions = [] - - fmt = '<%s host=%s path=%s version=%s subprotocol=%r extensions=%r>' - return fmt % (self.__class__.__name__, host, path, version, - subprotocol, extensions) - - -class ConnectionEstablished(object): - def __init__(self, subprotocol=None, extensions=None): - self.subprotocol = subprotocol - self.extensions = extensions - if self.extensions is None: - self.extensions = [] - - def __repr__(self): - return '<ConnectionEstablished subprotocol=%r extensions=%r>' % \ - (self.subprotocol, self.extensions) - - -class ConnectionClosed(object): - def __init__(self, code, reason=None): - self.code = code - self.reason = reason - - def __repr__(self): - return '<%s code=%r reason="%s">' % (self.__class__.__name__, - self.code, self.reason) - - -class ConnectionFailed(ConnectionClosed): - pass - - -class DataReceived(object): - def __init__(self, data, frame_finished, message_finished): - self.data = data - # This has no semantic content, but is provided just in case some - # weird edge case user wants to be able to reconstruct the - # fragmentation pattern of the original stream. You don't want it: - self.frame_finished = frame_finished - # This is the field that you almost certainly want: - self.message_finished = message_finished - - -class TextReceived(DataReceived): - pass - - -class BytesReceived(DataReceived): - pass - - -class PingReceived(object): - def __init__(self, payload): - self.payload = payload - - -class PongReceived(object): - def __init__(self, payload): - self.payload = payload diff --git a/mitmproxy/contrib/wsproto/extensions.py b/mitmproxy/contrib/wsproto/extensions.py deleted file mode 100644 index 0e0d2018..00000000 --- a/mitmproxy/contrib/wsproto/extensions.py +++ /dev/null @@ -1,259 +0,0 @@ -# type: ignore - -# -*- coding: utf-8 -*- -""" -wsproto/extensions -~~~~~~~~~~~~~~ - -WebSocket extensions. -""" - -import zlib - -from .frame_protocol import CloseReason, Opcode, RsvBits - - -class Extension(object): - name = None - - def enabled(self): - return False - - def offer(self, connection): - pass - - def accept(self, connection, offer): - pass - - def finalize(self, connection, offer): - pass - - def frame_inbound_header(self, proto, opcode, rsv, payload_length): - return RsvBits(False, False, False) - - def frame_inbound_payload_data(self, proto, data): - return data - - def frame_inbound_complete(self, proto, fin): - pass - - def frame_outbound(self, proto, opcode, rsv, data, fin): - return (rsv, data) - - -class PerMessageDeflate(Extension): - name = 'permessage-deflate' - - DEFAULT_CLIENT_MAX_WINDOW_BITS = 15 - DEFAULT_SERVER_MAX_WINDOW_BITS = 15 - - def __init__(self, client_no_context_takeover=False, - client_max_window_bits=None, server_no_context_takeover=False, - server_max_window_bits=None): - self.client_no_context_takeover = client_no_context_takeover - if client_max_window_bits is None: - client_max_window_bits = self.DEFAULT_CLIENT_MAX_WINDOW_BITS - self.client_max_window_bits = client_max_window_bits - self.server_no_context_takeover = server_no_context_takeover - if server_max_window_bits is None: - server_max_window_bits = self.DEFAULT_SERVER_MAX_WINDOW_BITS - self.server_max_window_bits = server_max_window_bits - - self._compressor = None - self._decompressor = None - # This refers to the current frame - self._inbound_is_compressible = None - # This refers to the ongoing message (which might span multiple - # frames). Only the first frame in a fragmented message is flagged for - # compression, so this carries that bit forward. - self._inbound_compressed = None - - self._enabled = False - - def _compressible_opcode(self, opcode): - return opcode in (Opcode.TEXT, Opcode.BINARY, Opcode.CONTINUATION) - - def enabled(self): - return self._enabled - - def offer(self, connection): - parameters = [ - 'client_max_window_bits=%d' % self.client_max_window_bits, - 'server_max_window_bits=%d' % self.server_max_window_bits, - ] - - if self.client_no_context_takeover: - parameters.append('client_no_context_takeover') - if self.server_no_context_takeover: - parameters.append('server_no_context_takeover') - - return '; '.join(parameters) - - def finalize(self, connection, offer): - bits = [b.strip() for b in offer.split(';')] - for bit in bits[1:]: - if bit.startswith('client_no_context_takeover'): - self.client_no_context_takeover = True - elif bit.startswith('server_no_context_takeover'): - self.server_no_context_takeover = True - elif bit.startswith('client_max_window_bits'): - self.client_max_window_bits = int(bit.split('=', 1)[1].strip()) - elif bit.startswith('server_max_window_bits'): - self.server_max_window_bits = int(bit.split('=', 1)[1].strip()) - - self._enabled = True - - def _parse_params(self, params): - client_max_window_bits = None - server_max_window_bits = None - - bits = [b.strip() for b in params.split(';')] - for bit in bits[1:]: - if bit.startswith('client_no_context_takeover'): - self.client_no_context_takeover = True - elif bit.startswith('server_no_context_takeover'): - self.server_no_context_takeover = True - elif bit.startswith('client_max_window_bits'): - if '=' in bit: - client_max_window_bits = int(bit.split('=', 1)[1].strip()) - else: - client_max_window_bits = self.client_max_window_bits - elif bit.startswith('server_max_window_bits'): - if '=' in bit: - server_max_window_bits = int(bit.split('=', 1)[1].strip()) - else: - server_max_window_bits = self.server_max_window_bits - - return client_max_window_bits, server_max_window_bits - - def accept(self, connection, offer): - client_max_window_bits, server_max_window_bits = \ - self._parse_params(offer) - - self._enabled = True - - parameters = [] - - if self.client_no_context_takeover: - parameters.append('client_no_context_takeover') - if client_max_window_bits is not None: - parameters.append('client_max_window_bits=%d' % - client_max_window_bits) - self.client_max_window_bits = client_max_window_bits - if self.server_no_context_takeover: - parameters.append('server_no_context_takeover') - if server_max_window_bits is not None: - parameters.append('server_max_window_bits=%d' % - server_max_window_bits) - self.server_max_window_bits = server_max_window_bits - - return '; '.join(parameters) - - def frame_inbound_header(self, proto, opcode, rsv, payload_length): - if rsv.rsv1 and opcode.iscontrol(): - return CloseReason.PROTOCOL_ERROR - elif rsv.rsv1 and opcode is Opcode.CONTINUATION: - return CloseReason.PROTOCOL_ERROR - - self._inbound_is_compressible = self._compressible_opcode(opcode) - - if self._inbound_compressed is None: - self._inbound_compressed = rsv.rsv1 - if self._inbound_compressed: - assert self._inbound_is_compressible - if proto.client: - bits = self.server_max_window_bits - else: - bits = self.client_max_window_bits - if self._decompressor is None: - self._decompressor = zlib.decompressobj(-int(bits)) - - return RsvBits(True, False, False) - - def frame_inbound_payload_data(self, proto, data): - if not self._inbound_compressed or not self._inbound_is_compressible: - return data - - try: - return self._decompressor.decompress(bytes(data)) - except zlib.error: - return CloseReason.INVALID_FRAME_PAYLOAD_DATA - - def frame_inbound_complete(self, proto, fin): - if not fin: - return - elif not self._inbound_is_compressible: - return - elif not self._inbound_compressed: - return - - try: - data = self._decompressor.decompress(b'\x00\x00\xff\xff') - data += self._decompressor.flush() - except zlib.error: - return CloseReason.INVALID_FRAME_PAYLOAD_DATA - - if proto.client: - no_context_takeover = self.server_no_context_takeover - else: - no_context_takeover = self.client_no_context_takeover - - if no_context_takeover: - self._decompressor = None - - self._inbound_compressed = None - - return data - - def frame_outbound(self, proto, opcode, rsv, data, fin): - if not self._compressible_opcode(opcode): - return (rsv, data) - - if opcode is not Opcode.CONTINUATION: - rsv = RsvBits(True, *rsv[1:]) - - if self._compressor is None: - assert opcode is not Opcode.CONTINUATION - if proto.client: - bits = self.client_max_window_bits - else: - bits = self.server_max_window_bits - self._compressor = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, - zlib.DEFLATED, -int(bits)) - - data = self._compressor.compress(bytes(data)) - - if fin: - data += self._compressor.flush(zlib.Z_SYNC_FLUSH) - data = data[:-4] - - if proto.client: - no_context_takeover = self.client_no_context_takeover - else: - no_context_takeover = self.server_no_context_takeover - - if no_context_takeover: - self._compressor = None - - return (rsv, data) - - def __repr__(self): - descr = ['client_max_window_bits=%d' % self.client_max_window_bits] - if self.client_no_context_takeover: - descr.append('client_no_context_takeover') - descr.append('server_max_window_bits=%d' % self.server_max_window_bits) - if self.server_no_context_takeover: - descr.append('server_no_context_takeover') - - descr = '; '.join(descr) - - return '<%s %s>' % (self.__class__.__name__, descr) - - -#: SUPPORTED_EXTENSIONS maps all supported extension names to their class. -#: This can be used to iterate all supported extensions of wsproto, instantiate -#: new extensions based on their name, or check if a given extension is -#: supported or not. -SUPPORTED_EXTENSIONS = { - PerMessageDeflate.name: PerMessageDeflate -} diff --git a/mitmproxy/contrib/wsproto/frame_protocol.py b/mitmproxy/contrib/wsproto/frame_protocol.py deleted file mode 100644 index 30f146c6..00000000 --- a/mitmproxy/contrib/wsproto/frame_protocol.py +++ /dev/null @@ -1,581 +0,0 @@ -# type: ignore - -# -*- coding: utf-8 -*- -""" -wsproto/frame_protocol -~~~~~~~~~~~~~~ - -WebSocket frame protocol implementation. -""" - -import os -import itertools -import struct -from codecs import getincrementaldecoder -from collections import namedtuple - -from enum import Enum, IntEnum - -from .compat import unicode, Utf8Validator - -try: - from wsaccel.xormask import XorMaskerSimple -except ImportError: - class XorMaskerSimple: - def __init__(self, masking_key): - self._maskbytes = itertools.cycle(bytearray(masking_key)) - - def process(self, data): - maskbytes = self._maskbytes - return bytearray(b ^ next(maskbytes) for b in bytearray(data)) - - -class XorMaskerNull: - def process(self, data): - return data - - -# RFC6455, Section 5.2 - Base Framing Protocol - -# Payload length constants -PAYLOAD_LENGTH_TWO_BYTE = 126 -PAYLOAD_LENGTH_EIGHT_BYTE = 127 -MAX_PAYLOAD_NORMAL = 125 -MAX_PAYLOAD_TWO_BYTE = 2 ** 16 - 1 -MAX_PAYLOAD_EIGHT_BYTE = 2 ** 64 - 1 -MAX_FRAME_PAYLOAD = MAX_PAYLOAD_EIGHT_BYTE - -# MASK and PAYLOAD LEN are packed into a byte -MASK_MASK = 0x80 -PAYLOAD_LEN_MASK = 0x7f - -# FIN, RSV[123] and OPCODE are packed into a single byte -FIN_MASK = 0x80 -RSV1_MASK = 0x40 -RSV2_MASK = 0x20 -RSV3_MASK = 0x10 -OPCODE_MASK = 0x0f - - -class Opcode(IntEnum): - """ - RFC 6455, Section 5.2 - Base Framing Protocol - """ - CONTINUATION = 0x0 - TEXT = 0x1 - BINARY = 0x2 - CLOSE = 0x8 - PING = 0x9 - PONG = 0xA - - def iscontrol(self): - return bool(self & 0x08) - - -class CloseReason(IntEnum): - """ - RFC 6455, Section 7.4.1 - Defined Status Codes - """ - NORMAL_CLOSURE = 1000 - GOING_AWAY = 1001 - PROTOCOL_ERROR = 1002 - UNSUPPORTED_DATA = 1003 - NO_STATUS_RCVD = 1005 - ABNORMAL_CLOSURE = 1006 - INVALID_FRAME_PAYLOAD_DATA = 1007 - POLICY_VIOLATION = 1008 - MESSAGE_TOO_BIG = 1009 - MANDATORY_EXT = 1010 - INTERNAL_ERROR = 1011 - SERVICE_RESTART = 1012 - TRY_AGAIN_LATER = 1013 - TLS_HANDSHAKE_FAILED = 1015 - - -# RFC 6455, Section 7.4.1 - Defined Status Codes -LOCAL_ONLY_CLOSE_REASONS = ( - CloseReason.NO_STATUS_RCVD, - CloseReason.ABNORMAL_CLOSURE, - CloseReason.TLS_HANDSHAKE_FAILED, -) - - -# RFC 6455, Section 7.4.2 - Status Code Ranges -MIN_CLOSE_REASON = 1000 -MIN_PROTOCOL_CLOSE_REASON = 1000 -MAX_PROTOCOL_CLOSE_REASON = 2999 -MIN_LIBRARY_CLOSE_REASON = 3000 -MAX_LIBRARY_CLOSE_REASON = 3999 -MIN_PRIVATE_CLOSE_REASON = 4000 -MAX_PRIVATE_CLOSE_REASON = 4999 -MAX_CLOSE_REASON = 4999 - - -NULL_MASK = struct.pack("!I", 0) - - -class ParseFailed(Exception): - def __init__(self, msg, code=CloseReason.PROTOCOL_ERROR): - super(ParseFailed, self).__init__(msg) - self.code = code - - -Header = namedtuple("Header", "fin rsv opcode payload_len masking_key".split()) - - -Frame = namedtuple("Frame", - "opcode payload frame_finished message_finished".split()) - - -RsvBits = namedtuple("RsvBits", "rsv1 rsv2 rsv3".split()) - - -def _truncate_utf8(data, nbytes): - if len(data) <= nbytes: - return data - - # Truncate - data = data[:nbytes] - # But we might have cut a codepoint in half, in which case we want to - # discard the partial character so the data is at least - # well-formed. This is a little inefficient since it processes the - # whole message twice when in theory we could just peek at the last - # few characters, but since this is only used for close messages (max - # length = 125 bytes) it really doesn't matter. - data = data.decode("utf-8", errors="ignore").encode("utf-8") - return data - - -class Buffer(object): - def __init__(self, initial_bytes=None): - self.buffer = bytearray() - self.bytes_used = 0 - if initial_bytes: - self.feed(initial_bytes) - - def feed(self, new_bytes): - self.buffer += new_bytes - - def consume_at_most(self, nbytes): - if not nbytes: - return bytearray() - - data = self.buffer[self.bytes_used:self.bytes_used + nbytes] - self.bytes_used += len(data) - return data - - def consume_exactly(self, nbytes): - if len(self.buffer) - self.bytes_used < nbytes: - return None - - return self.consume_at_most(nbytes) - - def commit(self): - # In CPython 3.4+, del[:n] is amortized O(n), *not* quadratic - del self.buffer[:self.bytes_used] - self.bytes_used = 0 - - def rollback(self): - self.bytes_used = 0 - - def __len__(self): - return len(self.buffer) - - -class MessageDecoder(object): - def __init__(self): - self.opcode = None - self.validator = None - self.decoder = None - - def process_frame(self, frame): - assert not frame.opcode.iscontrol() - - if self.opcode is None: - if frame.opcode is Opcode.CONTINUATION: - raise ParseFailed("unexpected CONTINUATION") - self.opcode = frame.opcode - elif frame.opcode is not Opcode.CONTINUATION: - raise ParseFailed("expected CONTINUATION, got %r" % frame.opcode) - - if frame.opcode is Opcode.TEXT: - self.validator = Utf8Validator() - self.decoder = getincrementaldecoder("utf-8")() - - finished = frame.frame_finished and frame.message_finished - - if self.decoder is not None: - data = self.decode_payload(frame.payload, finished) - else: - data = frame.payload - - frame = Frame(self.opcode, data, frame.frame_finished, finished) - - if finished: - self.opcode = None - self.decoder = None - - return frame - - def decode_payload(self, data, finished): - if self.validator is not None: - results = self.validator.validate(bytes(data)) - if not results[0] or (finished and not results[1]): - raise ParseFailed(u'encountered invalid UTF-8 while processing' - ' text message at payload octet index %d' % - results[3], - CloseReason.INVALID_FRAME_PAYLOAD_DATA) - - try: - return self.decoder.decode(data, finished) - except UnicodeDecodeError as exc: - raise ParseFailed(str(exc), CloseReason.INVALID_FRAME_PAYLOAD_DATA) - - -class FrameDecoder(object): - def __init__(self, client, extensions=None): - self.client = client - self.extensions = extensions or [] - - self.buffer = Buffer() - - self.header = None - self.effective_opcode = None - self.masker = None - self.payload_required = 0 - self.payload_consumed = 0 - - def receive_bytes(self, data): - self.buffer.feed(data) - - def process_buffer(self): - if not self.header: - if not self.parse_header(): - return None - - if len(self.buffer) < self.payload_required: - return None - - payload_remaining = self.header.payload_len - self.payload_consumed - payload = self.buffer.consume_at_most(payload_remaining) - if not payload and self.header.payload_len > 0: - return None - self.buffer.commit() - - self.payload_consumed += len(payload) - finished = self.payload_consumed == self.header.payload_len - - payload = self.masker.process(payload) - - for extension in self.extensions: - payload = extension.frame_inbound_payload_data(self, payload) - if isinstance(payload, CloseReason): - raise ParseFailed("error in extension", payload) - - if finished: - final = bytearray() - for extension in self.extensions: - result = extension.frame_inbound_complete(self, - self.header.fin) - if isinstance(result, CloseReason): - raise ParseFailed("error in extension", result) - if result is not None: - final += result - payload += final - - frame = Frame(self.effective_opcode, payload, finished, - self.header.fin) - - if finished: - self.header = None - self.effective_opcode = None - self.masker = None - else: - self.effective_opcode = Opcode.CONTINUATION - - return frame - - def parse_header(self): - data = self.buffer.consume_exactly(2) - if data is None: - self.buffer.rollback() - return False - - fin = bool(data[0] & FIN_MASK) - rsv = RsvBits(bool(data[0] & RSV1_MASK), - bool(data[0] & RSV2_MASK), - bool(data[0] & RSV3_MASK)) - opcode = data[0] & OPCODE_MASK - try: - opcode = Opcode(opcode) - except ValueError: - raise ParseFailed("Invalid opcode {:#x}".format(opcode)) - - if opcode.iscontrol() and not fin: - raise ParseFailed("Invalid attempt to fragment control frame") - - has_mask = bool(data[1] & MASK_MASK) - payload_len = data[1] & PAYLOAD_LEN_MASK - payload_len = self.parse_extended_payload_length(opcode, payload_len) - if payload_len is None: - self.buffer.rollback() - return False - - self.extension_processing(opcode, rsv, payload_len) - - if has_mask and self.client: - raise ParseFailed("client received unexpected masked frame") - if not has_mask and not self.client: - raise ParseFailed("server received unexpected unmasked frame") - if has_mask: - masking_key = self.buffer.consume_exactly(4) - if masking_key is None: - self.buffer.rollback() - return False - self.masker = XorMaskerSimple(masking_key) - else: - self.masker = XorMaskerNull() - - self.buffer.commit() - self.header = Header(fin, rsv, opcode, payload_len, None) - self.effective_opcode = self.header.opcode - if self.header.opcode.iscontrol(): - self.payload_required = payload_len - else: - self.payload_required = 0 - self.payload_consumed = 0 - return True - - def parse_extended_payload_length(self, opcode, payload_len): - if opcode.iscontrol() and payload_len > MAX_PAYLOAD_NORMAL: - raise ParseFailed("Control frame with payload len > 125") - if payload_len == PAYLOAD_LENGTH_TWO_BYTE: - data = self.buffer.consume_exactly(2) - if data is None: - return None - (payload_len,) = struct.unpack("!H", data) - if payload_len <= MAX_PAYLOAD_NORMAL: - raise ParseFailed( - "Payload length used 2 bytes when 1 would have sufficed") - elif payload_len == PAYLOAD_LENGTH_EIGHT_BYTE: - data = self.buffer.consume_exactly(8) - if data is None: - return None - (payload_len,) = struct.unpack("!Q", data) - if payload_len <= MAX_PAYLOAD_TWO_BYTE: - raise ParseFailed( - "Payload length used 8 bytes when 2 would have sufficed") - if payload_len >> 63: - # I'm not sure why this is illegal, but that's what the RFC - # says, so... - raise ParseFailed("8-byte payload length with non-zero MSB") - - return payload_len - - def extension_processing(self, opcode, rsv, payload_len): - rsv_used = [False, False, False] - for extension in self.extensions: - result = extension.frame_inbound_header(self, opcode, rsv, - payload_len) - if isinstance(result, CloseReason): - raise ParseFailed("error in extension", result) - for bit, used in enumerate(result): - if used: - rsv_used[bit] = True - for expected, found in zip(rsv_used, rsv): - if found and not expected: - raise ParseFailed("Reserved bit set unexpectedly") - - -class FrameProtocol(object): - class State(Enum): - HEADER = 1 - PAYLOAD = 2 - FRAME_COMPLETE = 3 - FAILED = 4 - - def __init__(self, client, extensions): - self.client = client - self.extensions = [ext for ext in extensions if ext.enabled()] - - # Global state - self._frame_decoder = FrameDecoder(self.client, self.extensions) - self._message_decoder = MessageDecoder() - self._parse_more = self.parse_more_gen() - - self._outbound_opcode = None - - def _process_close(self, frame): - data = frame.payload - - if not data: - # "If this Close control frame contains no status code, _The - # WebSocket Connection Close Code_ is considered to be 1005" - data = (CloseReason.NO_STATUS_RCVD, "") - elif len(data) == 1: - raise ParseFailed("CLOSE with 1 byte payload") - else: - (code,) = struct.unpack("!H", data[:2]) - if code < MIN_CLOSE_REASON or code > MAX_CLOSE_REASON: - raise ParseFailed("CLOSE with invalid code") - try: - code = CloseReason(code) - except ValueError: - pass - if code in LOCAL_ONLY_CLOSE_REASONS: - raise ParseFailed( - "remote CLOSE with local-only reason") - if not isinstance(code, CloseReason) and \ - code <= MAX_PROTOCOL_CLOSE_REASON: - raise ParseFailed( - "CLOSE with unknown reserved code") - validator = Utf8Validator() - if validator is not None: - results = validator.validate(bytes(data[2:])) - if not (results[0] and results[1]): - raise ParseFailed(u'encountered invalid UTF-8 while' - ' processing close message at payload' - ' octet index %d' % - results[3], - CloseReason.INVALID_FRAME_PAYLOAD_DATA) - try: - reason = data[2:].decode("utf-8") - except UnicodeDecodeError as exc: - raise ParseFailed( - "Error decoding CLOSE reason: " + str(exc), - CloseReason.INVALID_FRAME_PAYLOAD_DATA) - data = (code, reason) - - return Frame(frame.opcode, data, frame.frame_finished, - frame.message_finished) - - def parse_more_gen(self): - # Consume as much as we can from self._buffer, yielding events, and - # then yield None when we need more data. Or raise ParseFailed. - - # XX FIXME this should probably be refactored so that we never see - # disabled extensions in the first place... - self.extensions = [ext for ext in self.extensions if ext.enabled()] - closed = False - - while not closed: - frame = self._frame_decoder.process_buffer() - - if frame is not None: - if not frame.opcode.iscontrol(): - frame = self._message_decoder.process_frame(frame) - elif frame.opcode == Opcode.CLOSE: - frame = self._process_close(frame) - closed = True - - yield frame - - def receive_bytes(self, data): - self._frame_decoder.receive_bytes(data) - - def received_frames(self): - for event in self._parse_more: - if event is None: - break - else: - yield event - - def close(self, code=None, reason=None): - payload = bytearray() - if code is None and reason is not None: - raise TypeError("cannot specify a reason without a code") - if code in LOCAL_ONLY_CLOSE_REASONS: - code = CloseReason.NORMAL_CLOSURE - if code is not None: - payload += bytearray(struct.pack('!H', code)) - if reason is not None: - payload += _truncate_utf8(reason.encode('utf-8'), - MAX_PAYLOAD_NORMAL - 2) - - return self._serialize_frame(Opcode.CLOSE, payload) - - def ping(self, payload=b''): - return self._serialize_frame(Opcode.PING, payload) - - def pong(self, payload=b''): - return self._serialize_frame(Opcode.PONG, payload) - - def send_data(self, payload=b'', fin=True): - if isinstance(payload, (bytes, bytearray, memoryview)): - opcode = Opcode.BINARY - elif isinstance(payload, unicode): - opcode = Opcode.TEXT - payload = payload.encode('utf-8') - else: - raise ValueError('Must provide bytes or text') - - if self._outbound_opcode is None: - self._outbound_opcode = opcode - elif self._outbound_opcode is not opcode: - raise TypeError('Data type mismatch inside message') - else: - opcode = Opcode.CONTINUATION - - if fin: - self._outbound_opcode = None - - return self._serialize_frame(opcode, payload, fin) - - def _make_fin_rsv_opcode(self, fin, rsv, opcode): - fin = int(fin) << 7 - rsv = (int(rsv.rsv1) << 6) + (int(rsv.rsv2) << 5) + \ - (int(rsv.rsv3) << 4) - opcode = int(opcode) - - return fin | rsv | opcode - - def _serialize_frame(self, opcode, payload=b'', fin=True): - rsv = RsvBits(False, False, False) - for extension in reversed(self.extensions): - rsv, payload = extension.frame_outbound(self, opcode, rsv, payload, - fin) - - fin_rsv_opcode = self._make_fin_rsv_opcode(fin, rsv, opcode) - - payload_length = len(payload) - quad_payload = False - if payload_length <= MAX_PAYLOAD_NORMAL: - first_payload = payload_length - second_payload = None - elif payload_length <= MAX_PAYLOAD_TWO_BYTE: - first_payload = PAYLOAD_LENGTH_TWO_BYTE - second_payload = payload_length - else: - first_payload = PAYLOAD_LENGTH_EIGHT_BYTE - second_payload = payload_length - quad_payload = True - - if self.client: - first_payload |= 1 << 7 - - header = bytearray([fin_rsv_opcode, first_payload]) - if second_payload is not None: - if opcode.iscontrol(): - raise ValueError("payload too long for control frame") - if quad_payload: - header += bytearray(struct.pack('!Q', second_payload)) - else: - header += bytearray(struct.pack('!H', second_payload)) - - if self.client: - # "The masking key is a 32-bit value chosen at random by the - # client. When preparing a masked frame, the client MUST pick a - # fresh masking key from the set of allowed 32-bit values. The - # masking key needs to be unpredictable; thus, the masking key - # MUST be derived from a strong source of entropy, and the masking - # key for a given frame MUST NOT make it simple for a server/proxy - # to predict the masking key for a subsequent frame. The - # unpredictability of the masking key is essential to prevent - # authors of malicious applications from selecting the bytes that - # appear on the wire." - # -- https://tools.ietf.org/html/rfc6455#section-5.3 - masking_key = os.urandom(4) - masker = XorMaskerSimple(masking_key) - return header + masking_key + masker.process(payload) - - return header + payload diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py index 92f99518..2d8458a5 100644 --- a/mitmproxy/proxy/protocol/websocket.py +++ b/mitmproxy/proxy/protocol/websocket.py @@ -2,10 +2,10 @@ import socket from OpenSSL import SSL -from mitmproxy.contrib import wsproto -from mitmproxy.contrib.wsproto import events -from mitmproxy.contrib.wsproto.connection import ConnectionType, WSConnection -from mitmproxy.contrib.wsproto.extensions import PerMessageDeflate +import wsproto +from wsproto import events +from wsproto.connection import ConnectionType, WSConnection +from wsproto.extensions import PerMessageDeflate from mitmproxy import exceptions from mitmproxy import flow diff --git a/mitmproxy/tools/console/consoleaddons.py b/mitmproxy/tools/console/consoleaddons.py index 20d54bc6..298770c1 100644 --- a/mitmproxy/tools/console/consoleaddons.py +++ b/mitmproxy/tools/console/consoleaddons.py @@ -272,7 +272,7 @@ class ConsoleAddon: @command.command("console.command") def console_command(self, *partial: str) -> None: """ - Prompt the user to edit a command with a (possilby empty) starting value. + Prompt the user to edit a command with a (possibly empty) starting value. """ signals.status_prompt_command.send(partial=" ".join(partial)) # type: ignore diff --git a/mitmproxy/tools/console/grideditor/col.py b/mitmproxy/tools/console/grideditor/col.py deleted file mode 100644 index 3331f3e7..00000000 --- a/mitmproxy/tools/console/grideditor/col.py +++ /dev/null @@ -1,67 +0,0 @@ -import typing - -import urwid - -from mitmproxy.tools.console import signals -from mitmproxy.tools.console.grideditor import base -from mitmproxy.utils import strutils - -strbytes = typing.Union[str, bytes] - - -class Column(base.Column): - def Display(self, data): - return Display(data) - - def Edit(self, data): - return Edit(data) - - def blank(self): - return "" - - def keypress(self, key, editor): - if key in ["m_select"]: - editor.walker.start_edit() - else: - return key - - -class Display(base.Cell): - def __init__(self, data: strbytes) -> None: - self.data = data - if isinstance(data, bytes): - escaped = strutils.bytes_to_escaped_str(data) - else: - escaped = data.encode() - w = urwid.Text(escaped, wrap="any") - super().__init__(w) - - def get_data(self) -> strbytes: - return self.data - - -class Edit(base.Cell): - def __init__(self, data: strbytes) -> None: - if isinstance(data, bytes): - escaped = strutils.bytes_to_escaped_str(data) - else: - escaped = data.encode() - self.type = type(data) # type: typing.Type - w = urwid.Edit(edit_text=escaped, wrap="any", multiline=True) - w = urwid.AttrWrap(w, "editfield") - super().__init__(w) - - def get_data(self) -> strbytes: - txt = self._w.get_text()[0].strip() - try: - if self.type == bytes: - return strutils.escaped_str_to_bytes(txt) - else: - return txt.decode() - except ValueError: - signals.status_message.send( - self, - message="Invalid Python-style string encoding.", - expire=1000 - ) - raise diff --git a/mitmproxy/tools/console/grideditor/col_text.py b/mitmproxy/tools/console/grideditor/col_text.py index f0ac06f8..32518670 100644 --- a/mitmproxy/tools/console/grideditor/col_text.py +++ b/mitmproxy/tools/console/grideditor/col_text.py @@ -21,7 +21,7 @@ class Column(col_bytes.Column): return TEdit(data, self.encoding_args) def blank(self): - return u"" + return "" # This is the same for both edit and display. diff --git a/mitmproxy/tools/console/grideditor/col_viewany.py b/mitmproxy/tools/console/grideditor/col_viewany.py new file mode 100644 index 00000000..f5d35eee --- /dev/null +++ b/mitmproxy/tools/console/grideditor/col_viewany.py @@ -0,0 +1,33 @@ +""" +A display-only column that displays any data type. +""" + +import typing + +import urwid +from mitmproxy.tools.console.grideditor import base +from mitmproxy.utils import strutils + + +class Column(base.Column): + def Display(self, data): + return Display(data) + + Edit = Display + + def blank(self): + return "" + + +class Display(base.Cell): + def __init__(self, data: typing.Any) -> None: + self.data = data + if isinstance(data, bytes): + data = strutils.bytes_to_escaped_str(data) + if not isinstance(data, str): + data = repr(data) + w = urwid.Text(data, wrap="any") + super().__init__(w) + + def get_data(self) -> typing.Any: + return self.data diff --git a/mitmproxy/tools/console/grideditor/editors.py b/mitmproxy/tools/console/grideditor/editors.py index b5d16737..fbe48a1a 100644 --- a/mitmproxy/tools/console/grideditor/editors.py +++ b/mitmproxy/tools/console/grideditor/editors.py @@ -1,13 +1,14 @@ +import typing from mitmproxy import exceptions +from mitmproxy.net.http import Headers from mitmproxy.tools.console import layoutwidget +from mitmproxy.tools.console import signals from mitmproxy.tools.console.grideditor import base -from mitmproxy.tools.console.grideditor import col -from mitmproxy.tools.console.grideditor import col_text from mitmproxy.tools.console.grideditor import col_bytes from mitmproxy.tools.console.grideditor import col_subgrid -from mitmproxy.tools.console import signals -from mitmproxy.net.http import Headers +from mitmproxy.tools.console.grideditor import col_text +from mitmproxy.tools.console.grideditor import col_viewany class QueryEditor(base.FocusEditor): @@ -67,7 +68,6 @@ class RequestFormEditor(base.FocusEditor): class PathEditor(base.FocusEditor): # TODO: Next row on enter? - title = "Edit Path Components" columns = [ col_text.Column("Component"), @@ -175,11 +175,22 @@ class OptionsEditor(base.GridEditor, layoutwidget.LayoutWidget): class DataViewer(base.GridEditor, layoutwidget.LayoutWidget): title = None # type: str - def __init__(self, master, vals): + def __init__( + self, + master, + vals: typing.Union[ + typing.List[typing.List[typing.Any]], + typing.List[typing.Any], + str, + ]) -> None: if vals: + # Whatever vals is, make it a list of rows containing lists of column values. + if isinstance(vals, str): + vals = [vals] if not isinstance(vals[0], list): vals = [[i] for i in vals] - self.columns = [col.Column("")] * len(vals[0]) + + self.columns = [col_viewany.Column("")] * len(vals[0]) super().__init__(master, vals, self.callback) def callback(self, vals): diff --git a/mitmproxy/utils/debug.py b/mitmproxy/utils/debug.py index de01b12c..e8eca906 100644 --- a/mitmproxy/utils/debug.py +++ b/mitmproxy/utils/debug.py @@ -1,43 +1,24 @@ import gc import os +import platform +import re +import signal import sys import threading -import signal -import platform import traceback -import subprocess - -from mitmproxy import version from OpenSSL import SSL +from mitmproxy import version -def dump_system_info(): - mitmproxy_version = version.VERSION - here = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) - try: - git_describe = subprocess.check_output( - ['git', 'describe', '--tags', '--long'], - stderr=subprocess.STDOUT, - cwd=here, - ) - except: - pass - else: - last_tag, tag_dist, commit = git_describe.decode().strip().rsplit("-", 2) - - commit = commit.lstrip("g") # remove the 'g' prefix added by recent git versions - tag_dist = int(tag_dist) - - if tag_dist > 0: - tag_dist = "dev{:04}".format(tag_dist) - else: - tag_dist = "" - mitmproxy_version += "{tag_dist} ({commit})".format( - tag_dist=tag_dist, - commit=commit, - ) +def dump_system_info(): + mitmproxy_version = version.get_version(True, True) + mitmproxy_version = re.sub( + r"-0x([0-9a-f]+)", + r" (commit \1)", + mitmproxy_version + ) # PyInstaller builds indicator, if using precompiled binary if getattr(sys, 'frozen', False): diff --git a/mitmproxy/version.py b/mitmproxy/version.py index 3cae2a04..3073c3d3 100644 --- a/mitmproxy/version.py +++ b/mitmproxy/version.py @@ -1,5 +1,9 @@ -IVERSION = (3, 0, 0) -VERSION = ".".join(str(i) for i in IVERSION) +import os +import subprocess + +# The actual version string. For precompiled binaries, this will be changed to include the build +# tag, e.g. "3.0.0.dev0042-0xcafeabc" +VERSION = "3.0.0" PATHOD = "pathod " + VERSION MITMPROXY = "mitmproxy " + VERSION @@ -7,5 +11,54 @@ MITMPROXY = "mitmproxy " + VERSION # for each change in the file format. FLOW_FORMAT_VERSION = 5 + +def get_version(dev: bool = False, build: bool = False, refresh: bool = False) -> str: + """ + Return a detailed version string, sourced either from a hardcoded VERSION constant + or obtained dynamically using git. + + Args: + dev: If True, non-tagged releases will include a ".devXXXX" suffix, where XXXX is the number + of commits since the last tagged release. + build: If True, non-tagged releases will include a "-0xXXXXXXX" suffix, where XXXXXXX are + the first seven digits of the commit hash. + refresh: If True, always try to use git instead of a potentially hardcoded constant. + """ + + mitmproxy_version = VERSION + + if "dev" in VERSION and not refresh: + pass # There is a hardcoded build tag, so we just use what's there. + elif dev or build: + here = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + try: + git_describe = subprocess.check_output( + ['git', 'describe', '--tags', '--long'], + stderr=subprocess.STDOUT, + cwd=here, + ) + last_tag, tag_dist, commit = git_describe.decode().strip().rsplit("-", 2) + commit = commit.lstrip("g")[:7] + tag_dist = int(tag_dist) + except Exception: + pass + else: + # Remove current suffix + mitmproxy_version = mitmproxy_version.split(".dev")[0] + + # Add suffix for non-tagged releases + if tag_dist > 0: + mitmproxy_version += ".dev{tag_dist:04}".format(tag_dist=tag_dist) + # The wheel build tag (we use the commit) must start with a digit, so we include "0x" + mitmproxy_version += "-0x{commit}".format(commit=commit) + + if not dev: + mitmproxy_version = mitmproxy_version.split(".dev")[0] + elif not build: + mitmproxy_version = mitmproxy_version.split("-0x")[0] + + return mitmproxy_version + + if __name__ == "__main__": print(VERSION) diff --git a/mitmproxy/websocket.py b/mitmproxy/websocket.py index a37edb54..66257852 100644 --- a/mitmproxy/websocket.py +++ b/mitmproxy/websocket.py @@ -1,7 +1,8 @@ import time from typing import List, Optional -from mitmproxy.contrib import wsproto +from wsproto.frame_protocol import CloseReason +from wsproto.frame_protocol import Opcode from mitmproxy import flow from mitmproxy.net import websockets @@ -17,7 +18,7 @@ class WebSocketMessage(serializable.Serializable): def __init__( self, type: int, from_client: bool, content: bytes, timestamp: Optional[int]=None, killed: bool=False ) -> None: - self.type = wsproto.frame_protocol.Opcode(type) # type: ignore + self.type = Opcode(type) # type: ignore """indicates either TEXT or BINARY (from wsproto.frame_protocol.Opcode).""" self.from_client = from_client """True if this messages was sent by the client.""" @@ -37,10 +38,10 @@ class WebSocketMessage(serializable.Serializable): def set_state(self, state): self.type, self.from_client, self.content, self.timestamp, self.killed = state - self.type = wsproto.frame_protocol.Opcode(self.type) # replace enum with bare int + self.type = Opcode(self.type) # replace enum with bare int def __repr__(self): - if self.type == wsproto.frame_protocol.Opcode.TEXT: + if self.type == Opcode.TEXT: return "text message: {}".format(repr(self.content)) else: return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content)) @@ -66,7 +67,7 @@ class WebSocketFlow(flow.Flow): """A list containing all WebSocketMessage's.""" self.close_sender = 'client' """'client' if the client initiated connection closing.""" - self.close_code = wsproto.frame_protocol.CloseReason.NORMAL_CLOSURE + self.close_code = CloseReason.NORMAL_CLOSURE """WebSocket close code.""" self.close_message = '(message missing)' """WebSocket close message.""" diff --git a/release/hooks/hook-mitmproxy.py b/release/hooks/hook-mitmproxy.py new file mode 100644 index 00000000..21932507 --- /dev/null +++ b/release/hooks/hook-mitmproxy.py @@ -0,0 +1 @@ +hiddenimports = ["mitmproxy.script"] diff --git a/release/rtool.py b/release/rtool.py index 271392ba..4a07885c 100755 --- a/release/rtool.py +++ b/release/rtool.py @@ -4,6 +4,7 @@ import contextlib import fnmatch import os import platform +import re import runpy import shlex import shutil @@ -79,26 +80,21 @@ def git(args: str) -> str: return subprocess.check_output(["git"] + shlex.split(args)).decode() -def get_version() -> str: - return runpy.run_path(VERSION_FILE)["VERSION"] +def get_version(dev: bool = False, build: bool = False) -> str: + x = runpy.run_path(VERSION_FILE) + return x["get_version"](dev, build, True) -def get_snapshot_version() -> str: - last_tag, tag_dist, commit = git("describe --tags --long").strip().rsplit("-", 2) - tag_dist = int(tag_dist) - if tag_dist == 0: - return get_version() - else: - # remove the 'g' prefix added by recent git versions - if commit.startswith('g'): - commit = commit[1:] - - # The wheel build tag (we use the commit) must start with a digit, so we include "0x" - return "{version}dev{tag_dist:04}-0x{commit}".format( - version=get_version(), # this should already be the next version - tag_dist=tag_dist, - commit=commit - ) +def set_version(dev: bool) -> None: + """ + Update version information in mitmproxy's version.py to either include hardcoded information or not. + """ + version = get_version(dev, dev) + with open(VERSION_FILE, "r") as f: + content = f.read() + content = re.sub(r'^VERSION = ".+?"', 'VERSION = "{}"'.format(version), content, flags=re.M) + with open(VERSION_FILE, "w") as f: + f.write(content) def archive_name(bdist: str) -> str: @@ -116,7 +112,7 @@ def archive_name(bdist: str) -> str: def wheel_name() -> str: return "mitmproxy-{version}-py3-none-any.whl".format( - version=get_version(), + version=get_version(True), ) @@ -179,6 +175,23 @@ def contributors(): f.write(contributors_data.encode()) +@cli.command("wheel") +def make_wheel(): + """ + Build a Python wheel + """ + set_version(True) + try: + subprocess.check_call([ + "tox", "-e", "wheel", + ], env={ + **os.environ, + "VERSION": get_version(True), + }) + finally: + set_version(False) + + @cli.command("bdist") def make_bdist(): """ @@ -206,24 +219,30 @@ def make_bdist(): excludes.append("mitmproxy.tools.web") if tool != "mitmproxy_main": excludes.append("mitmproxy.tools.console") - subprocess.check_call( - [ - "pyinstaller", - "--clean", - "--workpath", PYINSTALLER_TEMP, - "--distpath", PYINSTALLER_DIST, - "--additional-hooks-dir", PYINSTALLER_HOOKS, - "--onefile", - "--console", - "--icon", "icon.ico", - # This is PyInstaller, so setting a - # different log level obviously breaks it :-) - # "--log-level", "WARN", - ] - + [x for e in excludes for x in ["--exclude-module", e]] - + PYINSTALLER_ARGS - + [tool] - ) + + # Overwrite mitmproxy/version.py to include commit info + set_version(True) + try: + subprocess.check_call( + [ + "pyinstaller", + "--clean", + "--workpath", PYINSTALLER_TEMP, + "--distpath", PYINSTALLER_DIST, + "--additional-hooks-dir", PYINSTALLER_HOOKS, + "--onefile", + "--console", + "--icon", "icon.ico", + # This is PyInstaller, so setting a + # different log level obviously breaks it :-) + # "--log-level", "WARN", + ] + + [x for e in excludes for x in ["--exclude-module", e]] + + PYINSTALLER_ARGS + + [tool] + ) + finally: + set_version(False) # Delete the spec file - we're good without. os.remove("{}.spec".format(tool)) @@ -299,7 +318,11 @@ def upload_snapshot(host, port, user, private_key, private_key_password, wheel, for f in files: local_path = join(DIST_DIR, f) - remote_filename = f.replace(get_version(), get_snapshot_version()) + remote_filename = re.sub( + r"{version}(\.dev\d+(-0x[0-9a-f]+)?)?".format(version=get_version()), + get_version(True, True), + f + ) symlink_path = "../{}".format(f.replace(get_version(), "latest")) # Upload new version diff --git a/release/setup.py b/release/setup.py deleted file mode 100644 index 0c4e6605..00000000 --- a/release/setup.py +++ /dev/null @@ -1,18 +0,0 @@ -from setuptools import setup - -setup( - name='mitmproxy-rtool', - version="1.0", - py_modules=["rtool"], - install_requires=[ - "click>=6.2, <7.0", - "twine>=1.6.5, <1.10", - "pysftp==0.2.8", - "cryptography>=2.0.0, <2.1", - ], - entry_points={ - "console_scripts": [ - "rtool=rtool:cli", - ], - }, -) @@ -1,7 +1,7 @@ import os -import runpy from codecs import open +import re from setuptools import setup, find_packages # Based on https://github.com/pypa/sampleproject/blob/master/setup.py @@ -12,7 +12,8 @@ here = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(here, 'README.rst'), encoding='utf-8') as f: long_description = f.read() -VERSION = runpy.run_path(os.path.join(here, "mitmproxy", "version.py"))["VERSION"] +with open(os.path.join(here, "mitmproxy", "version.py")) as f: + VERSION = re.search(r'VERSION = "(.+?)(?:-0x|")', f.read()).group(1) setup( name="mitmproxy", @@ -80,6 +81,7 @@ setup( "sortedcontainers>=1.5.4, <1.6", "tornado>=4.3, <4.6", "urwid>=1.3.1, <1.4", + "wsproto>=0.11.0,<0.12.0", ], extras_require={ ':sys_platform == "win32"': [ diff --git a/test/mitmproxy/test_version.py b/test/mitmproxy/test_version.py index f87b0851..f8d646dc 100644 --- a/test/mitmproxy/test_version.py +++ b/test/mitmproxy/test_version.py @@ -1,4 +1,6 @@ import runpy +import subprocess +from unittest import mock from mitmproxy import version @@ -8,3 +10,24 @@ def test_version(capsys): stdout, stderr = capsys.readouterr() assert len(stdout) > 0 assert stdout.strip() == version.VERSION + + +def test_get_version_hardcoded(): + version.VERSION = "3.0.0.dev123-0xcafebabe" + assert version.get_version() == "3.0.0" + assert version.get_version(True) == "3.0.0.dev123" + assert version.get_version(True, True) == "3.0.0.dev123-0xcafebabe" + + +def test_get_version(): + version.VERSION = "3.0.0" + + with mock.patch('subprocess.check_output') as m: + m.return_value = b"tag-0-cafecafe" + assert version.get_version(True, True) == "3.0.0" + + m.return_value = b"tag-2-cafecafe" + assert version.get_version(True, True) == "3.0.0.dev0002-0xcafecaf" + + m.side_effect = subprocess.CalledProcessError(-1, 'git describe --tags --long') + assert version.get_version(True, True) == "3.0.0" diff --git a/test/mitmproxy/tools/console/test_defaultkeys.py b/test/mitmproxy/tools/console/test_defaultkeys.py new file mode 100644 index 00000000..1f17c888 --- /dev/null +++ b/test/mitmproxy/tools/console/test_defaultkeys.py @@ -0,0 +1,23 @@ +from mitmproxy.test.tflow import tflow +from mitmproxy.tools.console import defaultkeys +from mitmproxy.tools.console import keymap +from mitmproxy.tools.console import master +from mitmproxy import command + + +def test_commands_exist(): + km = keymap.Keymap(None) + defaultkeys.map(km) + assert km.bindings + m = master.ConsoleMaster(None) + m.load_flow(tflow()) + + for binding in km.bindings: + cmd, *args = command.lexer(binding.command) + assert cmd in m.commands.commands + + cmd_obj = m.commands.commands[cmd] + try: + cmd_obj.prepare_args(args) + except Exception as e: + raise ValueError("Invalid command: {}".format(binding.command)) from e diff --git a/test/mitmproxy/utils/test_debug.py b/test/mitmproxy/utils/test_debug.py index a8e1054d..0ca6ead0 100644 --- a/test/mitmproxy/utils/test_debug.py +++ b/test/mitmproxy/utils/test_debug.py @@ -1,5 +1,4 @@ import io -import subprocess import sys from unittest import mock import pytest @@ -14,18 +13,6 @@ def test_dump_system_info_precompiled(precompiled): assert ("binary" in debug.dump_system_info()) == precompiled -def test_dump_system_info_version(): - with mock.patch('subprocess.check_output') as m: - m.return_value = b"v2.0.0-0-cafecafe" - x = debug.dump_system_info() - assert 'dev' not in x - assert 'cafecafe' in x - - with mock.patch('subprocess.check_output') as m: - m.side_effect = subprocess.CalledProcessError(-1, 'git describe --tags --long') - assert 'dev' not in debug.dump_system_info() - - def test_dump_info(): cs = io.StringIO() debug.dump_info(None, None, file=cs, testing=True) @@ -25,7 +25,7 @@ commands = sphinx-build -W -b html -d {envtmpdir}/doctrees . {envtmpdir}/html commands = mitmdump --version flake8 --jobs 8 mitmproxy pathod examples test release - python3 test/filename_matching.py + python test/filename_matching.py rstcheck README.rst mypy --ignore-missing-imports ./mitmproxy mypy --ignore-missing-imports ./pathod @@ -35,7 +35,7 @@ commands = deps = -rrequirements.txt commands = - python3 test/individual_coverage.py + python test/individual_coverage.py [testenv:wheel] recreate = True @@ -51,14 +51,13 @@ commands = pathoc --version [testenv:rtool] +passenv = SKIP_MITMPROXY SNAPSHOT_HOST SNAPSHOT_PORT SNAPSHOT_USER SNAPSHOT_PASS RTOOL_KEY deps = -rrequirements.txt - -e./release - # The 3.2 release is broken - # the next commit after this updates the bootloaders, which then segfault! - # https://github.com/pyinstaller/pyinstaller/issues/2232 - git+https://github.com/pyinstaller/pyinstaller.git@483c819d6a256b58db6740696a901bd41c313f0c; sys_platform == 'win32' - git+https://github.com/mhils/pyinstaller.git@d094401e4196b1a6a03818b80164a5f555861cef; sys_platform != 'win32' + pyinstaller==3.3.1 + twine==1.9.1 + pysftp==0.2.8 commands = - rtool {posargs} + mitmdump --version + python ./release/rtool.py {posargs} diff --git a/web/src/js/filt/filt.js b/web/src/js/filt/filt.js index 26058649..19a41af2 100644 --- a/web/src/js/filt/filt.js +++ b/web/src/js/filt/filt.js @@ -1929,7 +1929,7 @@ module.exports = (function() { function body(regex){ regex = new RegExp(regex, "i"); function bodyFilter(flow){ - return True; + return true; } bodyFilter.desc = "body filters are not implemented yet, see https://github.com/mitmproxy/mitmweb/issues/10"; return bodyFilter; @@ -1937,7 +1937,7 @@ module.exports = (function() { function requestBody(regex){ regex = new RegExp(regex, "i"); function requestBodyFilter(flow){ - return True; + return true; } requestBodyFilter.desc = "body filters are not implemented yet, see https://github.com/mitmproxy/mitmweb/issues/10"; return requestBodyFilter; @@ -1945,7 +1945,7 @@ module.exports = (function() { function responseBody(regex){ regex = new RegExp(regex, "i"); function responseBodyFilter(flow){ - return True; + return true; } responseBodyFilter.desc = "body filters are not implemented yet, see https://github.com/mitmproxy/mitmweb/issues/10"; return responseBodyFilter; @@ -2104,4 +2104,4 @@ module.exports = (function() { SyntaxError: peg$SyntaxError, parse: peg$parse }; -})();
\ No newline at end of file +})(); diff --git a/web/src/js/filt/filt.peg b/web/src/js/filt/filt.peg index 12959474..e4b151ad 100644 --- a/web/src/js/filt/filt.peg +++ b/web/src/js/filt/filt.peg @@ -1,4 +1,4 @@ -// PEG.js filter rules - see http://pegjs.majda.cz/online +// PEG.js filter rules - see https://pegjs.org/ { var flowutils = require("../flow/utils.js"); @@ -72,7 +72,7 @@ function responseCode(code){ function body(regex){ regex = new RegExp(regex, "i"); function bodyFilter(flow){ - return True; + return true; } bodyFilter.desc = "body filters are not implemented yet, see https://github.com/mitmproxy/mitmweb/issues/10"; return bodyFilter; @@ -80,7 +80,7 @@ function body(regex){ function requestBody(regex){ regex = new RegExp(regex, "i"); function requestBodyFilter(flow){ - return True; + return true; } requestBodyFilter.desc = "body filters are not implemented yet, see https://github.com/mitmproxy/mitmweb/issues/10"; return requestBodyFilter; @@ -88,7 +88,7 @@ function requestBody(regex){ function responseBody(regex){ regex = new RegExp(regex, "i"); function responseBodyFilter(flow){ - return True; + return true; } responseBodyFilter.desc = "body filters are not implemented yet, see https://github.com/mitmproxy/mitmweb/issues/10"; return responseBodyFilter; |