From 4ca78604af2a8ddb596e2f4e95090dabc8495bfe Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 20 Mar 2017 12:50:09 +1300 Subject: Factor out an io module Include tnetstring - we've made enough changes that this no longer belongs in contrib. --- mitmproxy/contrib/tnetstring.py | 250 ------------------------------ mitmproxy/io.py | 87 ----------- mitmproxy/io/__init__.py | 7 + mitmproxy/io/compat.py | 214 +++++++++++++++++++++++++ mitmproxy/io/io.py | 87 +++++++++++ mitmproxy/io/tnetstring.py | 250 ++++++++++++++++++++++++++++++ mitmproxy/io_compat.py | 214 ------------------------- test/mitmproxy/contrib/test_tnetstring.py | 137 ---------------- test/mitmproxy/io/test_compat.py | 28 ++++ test/mitmproxy/io/test_io.py | 1 + test/mitmproxy/io/test_tnetstring.py | 137 ++++++++++++++++ test/mitmproxy/test_flow.py | 2 +- test/mitmproxy/test_io.py | 1 - test/mitmproxy/test_io_compat.py | 28 ---- test/mitmproxy/test_websocket.py | 2 +- 15 files changed, 726 insertions(+), 719 deletions(-) delete mode 100644 mitmproxy/contrib/tnetstring.py delete mode 100644 mitmproxy/io.py create mode 100644 mitmproxy/io/__init__.py create mode 100644 mitmproxy/io/compat.py create mode 100644 mitmproxy/io/io.py create mode 100644 mitmproxy/io/tnetstring.py delete mode 100644 mitmproxy/io_compat.py delete mode 100644 test/mitmproxy/contrib/test_tnetstring.py create mode 100644 test/mitmproxy/io/test_compat.py create mode 100644 test/mitmproxy/io/test_io.py create mode 100644 test/mitmproxy/io/test_tnetstring.py delete mode 100644 test/mitmproxy/test_io.py delete mode 100644 test/mitmproxy/test_io_compat.py diff --git a/mitmproxy/contrib/tnetstring.py b/mitmproxy/contrib/tnetstring.py deleted file mode 100644 index 24ce6ce8..00000000 --- a/mitmproxy/contrib/tnetstring.py +++ /dev/null @@ -1,250 +0,0 @@ -""" -tnetstring: data serialization using typed netstrings -====================================================== - -This is a custom Python 3 implementation of tnetstrings. -Compared to other implementations, the main difference -is that this implementation supports a custom unicode datatype. - -An ordinary tnetstring is a blob of data prefixed with its length and postfixed -with its type. Here are some examples: - - >>> tnetstring.dumps("hello world") - 11:hello world, - >>> tnetstring.dumps(12345) - 5:12345# - >>> tnetstring.dumps([12345, True, 0]) - 19:5:12345#4:true!1:0#] - -This module gives you the following functions: - - :dump: dump an object as a tnetstring to a file - :dumps: dump an object as a tnetstring to a string - :load: load a tnetstring-encoded object from a file - :loads: load a tnetstring-encoded object from a string - -Note that since parsing a tnetstring requires reading all the data into memory -at once, there's no efficiency gain from using the file-based versions of these -functions. They're only here so you can use load() to read precisely one -item from a file or socket without consuming any extra data. - -The tnetstrings specification explicitly states that strings are binary blobs -and forbids the use of unicode at the protocol level. -**This implementation decodes dictionary keys as surrogate-escaped ASCII**, -all other strings are returned as plain bytes. - -:Copyright: (c) 2012-2013 by Ryan Kelly . -:Copyright: (c) 2014 by Carlo Pires . -:Copyright: (c) 2016 by Maximilian Hils . 
- -:License: MIT -""" - -import collections -from typing import io, Union, Tuple - -TSerializable = Union[None, bool, int, float, bytes, list, tuple, dict] - - -def dumps(value: TSerializable) -> bytes: - """ - This function dumps a python object as a tnetstring. - """ - # This uses a deque to collect output fragments in reverse order, - # then joins them together at the end. It's measurably faster - # than creating all the intermediate strings. - q = collections.deque() - _rdumpq(q, 0, value) - return b''.join(q) - - -def dump(value: TSerializable, file_handle: io.BinaryIO) -> None: - """ - This function dumps a python object as a tnetstring and - writes it to the given file. - """ - file_handle.write(dumps(value)) - - -def _rdumpq(q: collections.deque, size: int, value: TSerializable) -> int: - """ - Dump value as a tnetstring, to a deque instance, last chunks first. - - This function generates the tnetstring representation of the given value, - pushing chunks of the output onto the given deque instance. It pushes - the last chunk first, then recursively generates more chunks. - - When passed in the current size of the string in the queue, it will return - the new size of the string in the queue. - - Operating last-chunk-first makes it easy to calculate the size written - for recursive structures without having to build their representation as - a string. This is measurably faster than generating the intermediate - strings, especially on deeply nested structures. - """ - write = q.appendleft - if value is None: - write(b'0:~') - return size + 3 - elif value is True: - write(b'4:true!') - return size + 7 - elif value is False: - write(b'5:false!') - return size + 8 - elif isinstance(value, int): - data = str(value).encode() - ldata = len(data) - span = str(ldata).encode() - write(b'%s:%s#' % (span, data)) - return size + 2 + len(span) + ldata - elif isinstance(value, float): - # Use repr() for float rather than str(). - # It round-trips more accurately. - # Probably unnecessary in later python versions that - # use David Gay's ftoa routines. - data = repr(value).encode() - ldata = len(data) - span = str(ldata).encode() - write(b'%s:%s^' % (span, data)) - return size + 2 + len(span) + ldata - elif isinstance(value, bytes): - data = value - ldata = len(data) - span = str(ldata).encode() - write(b',') - write(data) - write(b':') - write(span) - return size + 2 + len(span) + ldata - elif isinstance(value, str): - data = value.encode("utf8") - ldata = len(data) - span = str(ldata).encode() - write(b';') - write(data) - write(b':') - write(span) - return size + 2 + len(span) + ldata - elif isinstance(value, (list, tuple)): - write(b']') - init_size = size = size + 1 - for item in reversed(value): - size = _rdumpq(q, size, item) - span = str(size - init_size).encode() - write(b':') - write(span) - return size + 1 + len(span) - elif isinstance(value, dict): - write(b'}') - init_size = size = size + 1 - for (k, v) in value.items(): - size = _rdumpq(q, size, v) - size = _rdumpq(q, size, k) - span = str(size - init_size).encode() - write(b':') - write(span) - return size + 1 + len(span) - else: - raise ValueError("unserializable object: {} ({})".format(value, type(value))) - - -def loads(string: bytes) -> TSerializable: - """ - This function parses a tnetstring into a python object. - """ - return pop(string)[0] - - -def load(file_handle: io.BinaryIO) -> TSerializable: - """load(file) -> object - - This function reads a tnetstring from a file and parses it into a - python object. 
The file must support the read() method, and this - function promises not to read more data than necessary. - """ - # Read the length prefix one char at a time. - # Note that the netstring spec explicitly forbids padding zeros. - c = file_handle.read(1) - if c == b"": # we want to detect this special case. - raise ValueError("not a tnetstring: empty file") - data_length = b"" - while c.isdigit(): - data_length += c - if len(data_length) > 9: - raise ValueError("not a tnetstring: absurdly large length prefix") - c = file_handle.read(1) - if c != b":": - raise ValueError("not a tnetstring: missing or invalid length prefix") - - data = file_handle.read(int(data_length)) - data_type = file_handle.read(1)[0] - - return parse(data_type, data) - - -def parse(data_type: int, data: bytes) -> TSerializable: - if data_type == ord(b','): - return data - if data_type == ord(b';'): - return data.decode("utf8") - if data_type == ord(b'#'): - try: - return int(data) - except ValueError: - raise ValueError("not a tnetstring: invalid integer literal: {}".format(data)) - if data_type == ord(b'^'): - try: - return float(data) - except ValueError: - raise ValueError("not a tnetstring: invalid float literal: {}".format(data)) - if data_type == ord(b'!'): - if data == b'true': - return True - elif data == b'false': - return False - else: - raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data)) - if data_type == ord(b'~'): - if data: - raise ValueError("not a tnetstring: invalid null literal") - return None - if data_type == ord(b']'): - l = [] - while data: - item, data = pop(data) - l.append(item) - return l - if data_type == ord(b'}'): - d = {} - while data: - key, data = pop(data) - val, data = pop(data) - d[key] = val - return d - raise ValueError("unknown type tag: {}".format(data_type)) - - -def pop(data: bytes) -> Tuple[TSerializable, bytes]: - """ - This function parses a tnetstring into a python object. - It returns a tuple giving the parsed object and a string - containing any unparsed data from the end of the string. - """ - # Parse out data length, type and remaining string. - try: - length, data = data.split(b':', 1) - length = int(length) - except ValueError: - raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(data)) - try: - data, data_type, remain = data[:length], data[length], data[length + 1:] - except IndexError: - # This fires if len(data) < dlen, meaning we don't need - # to further validate that data is the right length. - raise ValueError("not a tnetstring: invalid length prefix: {}".format(length)) - # Parse the data based on the type tag. 
- return parse(data_type, data), remain - - -__all__ = ["dump", "dumps", "load", "loads", "pop"] diff --git a/mitmproxy/io.py b/mitmproxy/io.py deleted file mode 100644 index 0f6c3f5c..00000000 --- a/mitmproxy/io.py +++ /dev/null @@ -1,87 +0,0 @@ -import os -from typing import Type, Iterable, Dict, Union, Any, cast # noqa - -from mitmproxy import exceptions -from mitmproxy import flow -from mitmproxy import flowfilter -from mitmproxy import http -from mitmproxy import tcp -from mitmproxy import websocket -from mitmproxy.contrib import tnetstring -from mitmproxy import io_compat - - -FLOW_TYPES = dict( - http=http.HTTPFlow, - websocket=websocket.WebSocketFlow, - tcp=tcp.TCPFlow, -) # type: Dict[str, Type[flow.Flow]] - - -class FlowWriter: - def __init__(self, fo): - self.fo = fo - - def add(self, flow): - d = flow.get_state() - tnetstring.dump(d, self.fo) - - -class FlowReader: - def __init__(self, fo): - self.fo = fo - - def stream(self) -> Iterable[flow.Flow]: - """ - Yields Flow objects from the dump. - """ - try: - while True: - # FIXME: This cast hides a lack of dynamic type checking - loaded = cast( - Dict[Union[bytes, str], Any], - tnetstring.load(self.fo), - ) - try: - mdata = io_compat.migrate_flow(loaded) - except ValueError as e: - raise exceptions.FlowReadException(str(e)) - if mdata["type"] not in FLOW_TYPES: - raise exceptions.FlowReadException("Unknown flow type: {}".format(mdata["type"])) - yield FLOW_TYPES[mdata["type"]].from_state(mdata) - except ValueError as e: - if str(e) == "not a tnetstring: empty file": - return # Error is due to EOF - raise exceptions.FlowReadException("Invalid data format.") - - -class FilteredFlowWriter: - def __init__(self, fo, flt): - self.fo = fo - self.flt = flt - - def add(self, f: flow.Flow): - if self.flt and not flowfilter.match(self.flt, f): - return - d = f.get_state() - tnetstring.dump(d, self.fo) - - -def read_flows_from_paths(paths): - """ - Given a list of filepaths, read all flows and return a list of them. - From a performance perspective, streaming would be advisable - - however, if there's an error with one of the files, we want it to be raised immediately. - - Raises: - FlowReadException, if any error occurs. - """ - try: - flows = [] - for path in paths: - path = os.path.expanduser(path) - with open(path, "rb") as f: - flows.extend(FlowReader(f).stream()) - except IOError as e: - raise exceptions.FlowReadException(e.strerror) - return flows diff --git a/mitmproxy/io/__init__.py b/mitmproxy/io/__init__.py new file mode 100644 index 00000000..a82f729f --- /dev/null +++ b/mitmproxy/io/__init__.py @@ -0,0 +1,7 @@ + +from .io import FlowWriter, FlowReader, FilteredFlowWriter, read_flows_from_paths + + +__all__ = [ + "FlowWriter", "FlowReader", "FilteredFlowWriter", "read_flows_from_paths" +] \ No newline at end of file diff --git a/mitmproxy/io/compat.py b/mitmproxy/io/compat.py new file mode 100644 index 00000000..9d95f602 --- /dev/null +++ b/mitmproxy/io/compat.py @@ -0,0 +1,214 @@ +""" +This module handles the import of mitmproxy flows generated by old versions. 
+""" +import uuid +from typing import Any, Dict, Mapping, Union # noqa + +from mitmproxy import version +from mitmproxy.utils import strutils + + +def convert_011_012(data): + data[b"version"] = (0, 12) + return data + + +def convert_012_013(data): + data[b"version"] = (0, 13) + return data + + +def convert_013_014(data): + data[b"request"][b"first_line_format"] = data[b"request"].pop(b"form_in") + data[b"request"][b"http_version"] = b"HTTP/" + ".".join( + str(x) for x in data[b"request"].pop(b"httpversion")).encode() + data[b"response"][b"http_version"] = b"HTTP/" + ".".join( + str(x) for x in data[b"response"].pop(b"httpversion")).encode() + data[b"response"][b"status_code"] = data[b"response"].pop(b"code") + data[b"response"][b"body"] = data[b"response"].pop(b"content") + data[b"server_conn"].pop(b"state") + data[b"server_conn"][b"via"] = None + data[b"version"] = (0, 14) + return data + + +def convert_014_015(data): + data[b"version"] = (0, 15) + return data + + +def convert_015_016(data): + for m in (b"request", b"response"): + if b"body" in data[m]: + data[m][b"content"] = data[m].pop(b"body") + if b"msg" in data[b"response"]: + data[b"response"][b"reason"] = data[b"response"].pop(b"msg") + data[b"request"].pop(b"form_out", None) + data[b"version"] = (0, 16) + return data + + +def convert_016_017(data): + data[b"server_conn"][b"peer_address"] = None + data[b"version"] = (0, 17) + return data + + +def convert_017_018(data): + # convert_unicode needs to be called for every dual release and the first py3-only release + data = convert_unicode(data) + + data["server_conn"]["ip_address"] = data["server_conn"].pop("peer_address") + data["marked"] = False + data["version"] = (0, 18) + return data + + +def convert_018_019(data): + # convert_unicode needs to be called for every dual release and the first py3-only release + data = convert_unicode(data) + + data["request"].pop("stickyauth", None) + data["request"].pop("stickycookie", None) + data["client_conn"]["sni"] = None + data["client_conn"]["alpn_proto_negotiated"] = None + data["client_conn"]["cipher_name"] = None + data["client_conn"]["tls_version"] = None + data["server_conn"]["alpn_proto_negotiated"] = None + data["mode"] = "regular" + data["metadata"] = dict() + data["version"] = (0, 19) + return data + + +def convert_019_100(data): + # convert_unicode needs to be called for every dual release and the first py3-only release + data = convert_unicode(data) + + data["version"] = (1, 0, 0) + return data + + +def convert_100_200(data): + data["version"] = (2, 0, 0) + data["client_conn"]["address"] = data["client_conn"]["address"]["address"] + data["server_conn"]["address"] = data["server_conn"]["address"]["address"] + data["server_conn"]["source_address"] = data["server_conn"]["source_address"]["address"] + if data["server_conn"]["ip_address"]: + data["server_conn"]["ip_address"] = data["server_conn"]["ip_address"]["address"] + return data + + +def convert_200_300(data): + data["version"] = (3, 0, 0) + data["client_conn"]["mitmcert"] = None + data["server_conn"]["tls_version"] = None + if data["server_conn"]["via"]: + data["server_conn"]["via"]["tls_version"] = None + return data + + +def convert_300_4(data): + data["version"] = 4 + return data + + +client_connections = {} # type: Mapping[str, str] +server_connections = {} # type: Mapping[str, str] + + +def convert_4_5(data): + data["version"] = 5 + client_conn_key = ( + data["client_conn"]["timestamp_start"], + *data["client_conn"]["address"] + ) + server_conn_key = ( + 
data["server_conn"]["timestamp_start"], + *data["server_conn"]["source_address"] + ) + data["client_conn"]["id"] = client_connections.setdefault(client_conn_key, str(uuid.uuid4())) + data["server_conn"]["id"] = server_connections.setdefault(server_conn_key, str(uuid.uuid4())) + return data + + +def _convert_dict_keys(o: Any) -> Any: + if isinstance(o, dict): + return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()} + else: + return o + + +def _convert_dict_vals(o: dict, values_to_convert: dict) -> dict: + for k, v in values_to_convert.items(): + if not o or k not in o: + continue + if v is True: + o[k] = strutils.always_str(o[k]) + else: + _convert_dict_vals(o[k], v) + return o + + +def convert_unicode(data: dict) -> dict: + """ + This method converts between Python 3 and Python 2 dumpfiles. + """ + data = _convert_dict_keys(data) + data = _convert_dict_vals( + data, { + "type": True, + "id": True, + "request": { + "first_line_format": True + }, + "error": { + "msg": True + } + } + ) + return data + + +converters = { + (0, 11): convert_011_012, + (0, 12): convert_012_013, + (0, 13): convert_013_014, + (0, 14): convert_014_015, + (0, 15): convert_015_016, + (0, 16): convert_016_017, + (0, 17): convert_017_018, + (0, 18): convert_018_019, + (0, 19): convert_019_100, + (1, 0): convert_100_200, + (2, 0): convert_200_300, + (3, 0): convert_300_4, + 4: convert_4_5, +} + + +def migrate_flow(flow_data: Dict[Union[bytes, str], Any]) -> Dict[Union[bytes, str], Any]: + while True: + flow_version = flow_data.get(b"version", flow_data.get("version")) + + # Historically, we used the mitmproxy minor version tuple as the flow format version. + if not isinstance(flow_version, int): + flow_version = tuple(flow_version)[:2] + + if flow_version == version.FLOW_FORMAT_VERSION: + break + elif flow_version in converters: + flow_data = converters[flow_version](flow_data) + else: + should_upgrade = ( + isinstance(flow_version, int) + and flow_version > version.FLOW_FORMAT_VERSION + ) + raise ValueError( + "{} cannot read files with flow format version {}{}.".format( + version.MITMPROXY, + flow_version, + ", please update mitmproxy" if should_upgrade else "" + ) + ) + return flow_data diff --git a/mitmproxy/io/io.py b/mitmproxy/io/io.py new file mode 100644 index 00000000..50e26f49 --- /dev/null +++ b/mitmproxy/io/io.py @@ -0,0 +1,87 @@ +import os +from typing import Type, Iterable, Dict, Union, Any, cast # noqa + +from mitmproxy import exceptions +from mitmproxy import flow +from mitmproxy import flowfilter +from mitmproxy import http +from mitmproxy import tcp +from mitmproxy import websocket + +from mitmproxy.io import compat +from mitmproxy.io import tnetstring + +FLOW_TYPES = dict( + http=http.HTTPFlow, + websocket=websocket.WebSocketFlow, + tcp=tcp.TCPFlow, +) # type: Dict[str, Type[flow.Flow]] + + +class FlowWriter: + def __init__(self, fo): + self.fo = fo + + def add(self, flow): + d = flow.get_state() + tnetstring.dump(d, self.fo) + + +class FlowReader: + def __init__(self, fo): + self.fo = fo + + def stream(self) -> Iterable[flow.Flow]: + """ + Yields Flow objects from the dump. 
+ """ + try: + while True: + # FIXME: This cast hides a lack of dynamic type checking + loaded = cast( + Dict[Union[bytes, str], Any], + tnetstring.load(self.fo), + ) + try: + mdata = compat.migrate_flow(loaded) + except ValueError as e: + raise exceptions.FlowReadException(str(e)) + if mdata["type"] not in FLOW_TYPES: + raise exceptions.FlowReadException("Unknown flow type: {}".format(mdata["type"])) + yield FLOW_TYPES[mdata["type"]].from_state(mdata) + except ValueError as e: + if str(e) == "not a tnetstring: empty file": + return # Error is due to EOF + raise exceptions.FlowReadException("Invalid data format.") + + +class FilteredFlowWriter: + def __init__(self, fo, flt): + self.fo = fo + self.flt = flt + + def add(self, f: flow.Flow): + if self.flt and not flowfilter.match(self.flt, f): + return + d = f.get_state() + tnetstring.dump(d, self.fo) + + +def read_flows_from_paths(paths): + """ + Given a list of filepaths, read all flows and return a list of them. + From a performance perspective, streaming would be advisable - + however, if there's an error with one of the files, we want it to be raised immediately. + + Raises: + FlowReadException, if any error occurs. + """ + try: + flows = [] + for path in paths: + path = os.path.expanduser(path) + with open(path, "rb") as f: + flows.extend(FlowReader(f).stream()) + except IOError as e: + raise exceptions.FlowReadException(e.strerror) + return flows diff --git a/mitmproxy/io/tnetstring.py b/mitmproxy/io/tnetstring.py new file mode 100644 index 00000000..24ce6ce8 --- /dev/null +++ b/mitmproxy/io/tnetstring.py @@ -0,0 +1,250 @@ +""" +tnetstring: data serialization using typed netstrings +====================================================== + +This is a custom Python 3 implementation of tnetstrings. +Compared to other implementations, the main difference +is that this implementation supports a custom unicode datatype. + +An ordinary tnetstring is a blob of data prefixed with its length and postfixed +with its type. Here are some examples: + + >>> tnetstring.dumps("hello world") + 11:hello world, + >>> tnetstring.dumps(12345) + 5:12345# + >>> tnetstring.dumps([12345, True, 0]) + 19:5:12345#4:true!1:0#] + +This module gives you the following functions: + + :dump: dump an object as a tnetstring to a file + :dumps: dump an object as a tnetstring to a string + :load: load a tnetstring-encoded object from a file + :loads: load a tnetstring-encoded object from a string + +Note that since parsing a tnetstring requires reading all the data into memory +at once, there's no efficiency gain from using the file-based versions of these +functions. They're only here so you can use load() to read precisely one +item from a file or socket without consuming any extra data. + +The tnetstrings specification explicitly states that strings are binary blobs +and forbids the use of unicode at the protocol level. +**This implementation decodes dictionary keys as surrogate-escaped ASCII**, +all other strings are returned as plain bytes. + +:Copyright: (c) 2012-2013 by Ryan Kelly . +:Copyright: (c) 2014 by Carlo Pires . +:Copyright: (c) 2016 by Maximilian Hils . + +:License: MIT +""" + +import collections +from typing import io, Union, Tuple + +TSerializable = Union[None, bool, int, float, bytes, list, tuple, dict] + + +def dumps(value: TSerializable) -> bytes: + """ + This function dumps a python object as a tnetstring. + """ + # This uses a deque to collect output fragments in reverse order, + # then joins them together at the end. 
It's measurably faster + # than creating all the intermediate strings. + q = collections.deque() + _rdumpq(q, 0, value) + return b''.join(q) + + +def dump(value: TSerializable, file_handle: io.BinaryIO) -> None: + """ + This function dumps a python object as a tnetstring and + writes it to the given file. + """ + file_handle.write(dumps(value)) + + +def _rdumpq(q: collections.deque, size: int, value: TSerializable) -> int: + """ + Dump value as a tnetstring, to a deque instance, last chunks first. + + This function generates the tnetstring representation of the given value, + pushing chunks of the output onto the given deque instance. It pushes + the last chunk first, then recursively generates more chunks. + + When passed in the current size of the string in the queue, it will return + the new size of the string in the queue. + + Operating last-chunk-first makes it easy to calculate the size written + for recursive structures without having to build their representation as + a string. This is measurably faster than generating the intermediate + strings, especially on deeply nested structures. + """ + write = q.appendleft + if value is None: + write(b'0:~') + return size + 3 + elif value is True: + write(b'4:true!') + return size + 7 + elif value is False: + write(b'5:false!') + return size + 8 + elif isinstance(value, int): + data = str(value).encode() + ldata = len(data) + span = str(ldata).encode() + write(b'%s:%s#' % (span, data)) + return size + 2 + len(span) + ldata + elif isinstance(value, float): + # Use repr() for float rather than str(). + # It round-trips more accurately. + # Probably unnecessary in later python versions that + # use David Gay's ftoa routines. + data = repr(value).encode() + ldata = len(data) + span = str(ldata).encode() + write(b'%s:%s^' % (span, data)) + return size + 2 + len(span) + ldata + elif isinstance(value, bytes): + data = value + ldata = len(data) + span = str(ldata).encode() + write(b',') + write(data) + write(b':') + write(span) + return size + 2 + len(span) + ldata + elif isinstance(value, str): + data = value.encode("utf8") + ldata = len(data) + span = str(ldata).encode() + write(b';') + write(data) + write(b':') + write(span) + return size + 2 + len(span) + ldata + elif isinstance(value, (list, tuple)): + write(b']') + init_size = size = size + 1 + for item in reversed(value): + size = _rdumpq(q, size, item) + span = str(size - init_size).encode() + write(b':') + write(span) + return size + 1 + len(span) + elif isinstance(value, dict): + write(b'}') + init_size = size = size + 1 + for (k, v) in value.items(): + size = _rdumpq(q, size, v) + size = _rdumpq(q, size, k) + span = str(size - init_size).encode() + write(b':') + write(span) + return size + 1 + len(span) + else: + raise ValueError("unserializable object: {} ({})".format(value, type(value))) + + +def loads(string: bytes) -> TSerializable: + """ + This function parses a tnetstring into a python object. + """ + return pop(string)[0] + + +def load(file_handle: io.BinaryIO) -> TSerializable: + """load(file) -> object + + This function reads a tnetstring from a file and parses it into a + python object. The file must support the read() method, and this + function promises not to read more data than necessary. + """ + # Read the length prefix one char at a time. + # Note that the netstring spec explicitly forbids padding zeros. + c = file_handle.read(1) + if c == b"": # we want to detect this special case. 
+ raise ValueError("not a tnetstring: empty file") + data_length = b"" + while c.isdigit(): + data_length += c + if len(data_length) > 9: + raise ValueError("not a tnetstring: absurdly large length prefix") + c = file_handle.read(1) + if c != b":": + raise ValueError("not a tnetstring: missing or invalid length prefix") + + data = file_handle.read(int(data_length)) + data_type = file_handle.read(1)[0] + + return parse(data_type, data) + + +def parse(data_type: int, data: bytes) -> TSerializable: + if data_type == ord(b','): + return data + if data_type == ord(b';'): + return data.decode("utf8") + if data_type == ord(b'#'): + try: + return int(data) + except ValueError: + raise ValueError("not a tnetstring: invalid integer literal: {}".format(data)) + if data_type == ord(b'^'): + try: + return float(data) + except ValueError: + raise ValueError("not a tnetstring: invalid float literal: {}".format(data)) + if data_type == ord(b'!'): + if data == b'true': + return True + elif data == b'false': + return False + else: + raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data)) + if data_type == ord(b'~'): + if data: + raise ValueError("not a tnetstring: invalid null literal") + return None + if data_type == ord(b']'): + l = [] + while data: + item, data = pop(data) + l.append(item) + return l + if data_type == ord(b'}'): + d = {} + while data: + key, data = pop(data) + val, data = pop(data) + d[key] = val + return d + raise ValueError("unknown type tag: {}".format(data_type)) + + +def pop(data: bytes) -> Tuple[TSerializable, bytes]: + """ + This function parses a tnetstring into a python object. + It returns a tuple giving the parsed object and a string + containing any unparsed data from the end of the string. + """ + # Parse out data length, type and remaining string. + try: + length, data = data.split(b':', 1) + length = int(length) + except ValueError: + raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(data)) + try: + data, data_type, remain = data[:length], data[length], data[length + 1:] + except IndexError: + # This fires if len(data) < dlen, meaning we don't need + # to further validate that data is the right length. + raise ValueError("not a tnetstring: invalid length prefix: {}".format(length)) + # Parse the data based on the type tag. + return parse(data_type, data), remain + + +__all__ = ["dump", "dumps", "load", "loads", "pop"] diff --git a/mitmproxy/io_compat.py b/mitmproxy/io_compat.py deleted file mode 100644 index 9d95f602..00000000 --- a/mitmproxy/io_compat.py +++ /dev/null @@ -1,214 +0,0 @@ -""" -This module handles the import of mitmproxy flows generated by old versions. 
-""" -import uuid -from typing import Any, Dict, Mapping, Union # noqa - -from mitmproxy import version -from mitmproxy.utils import strutils - - -def convert_011_012(data): - data[b"version"] = (0, 12) - return data - - -def convert_012_013(data): - data[b"version"] = (0, 13) - return data - - -def convert_013_014(data): - data[b"request"][b"first_line_format"] = data[b"request"].pop(b"form_in") - data[b"request"][b"http_version"] = b"HTTP/" + ".".join( - str(x) for x in data[b"request"].pop(b"httpversion")).encode() - data[b"response"][b"http_version"] = b"HTTP/" + ".".join( - str(x) for x in data[b"response"].pop(b"httpversion")).encode() - data[b"response"][b"status_code"] = data[b"response"].pop(b"code") - data[b"response"][b"body"] = data[b"response"].pop(b"content") - data[b"server_conn"].pop(b"state") - data[b"server_conn"][b"via"] = None - data[b"version"] = (0, 14) - return data - - -def convert_014_015(data): - data[b"version"] = (0, 15) - return data - - -def convert_015_016(data): - for m in (b"request", b"response"): - if b"body" in data[m]: - data[m][b"content"] = data[m].pop(b"body") - if b"msg" in data[b"response"]: - data[b"response"][b"reason"] = data[b"response"].pop(b"msg") - data[b"request"].pop(b"form_out", None) - data[b"version"] = (0, 16) - return data - - -def convert_016_017(data): - data[b"server_conn"][b"peer_address"] = None - data[b"version"] = (0, 17) - return data - - -def convert_017_018(data): - # convert_unicode needs to be called for every dual release and the first py3-only release - data = convert_unicode(data) - - data["server_conn"]["ip_address"] = data["server_conn"].pop("peer_address") - data["marked"] = False - data["version"] = (0, 18) - return data - - -def convert_018_019(data): - # convert_unicode needs to be called for every dual release and the first py3-only release - data = convert_unicode(data) - - data["request"].pop("stickyauth", None) - data["request"].pop("stickycookie", None) - data["client_conn"]["sni"] = None - data["client_conn"]["alpn_proto_negotiated"] = None - data["client_conn"]["cipher_name"] = None - data["client_conn"]["tls_version"] = None - data["server_conn"]["alpn_proto_negotiated"] = None - data["mode"] = "regular" - data["metadata"] = dict() - data["version"] = (0, 19) - return data - - -def convert_019_100(data): - # convert_unicode needs to be called for every dual release and the first py3-only release - data = convert_unicode(data) - - data["version"] = (1, 0, 0) - return data - - -def convert_100_200(data): - data["version"] = (2, 0, 0) - data["client_conn"]["address"] = data["client_conn"]["address"]["address"] - data["server_conn"]["address"] = data["server_conn"]["address"]["address"] - data["server_conn"]["source_address"] = data["server_conn"]["source_address"]["address"] - if data["server_conn"]["ip_address"]: - data["server_conn"]["ip_address"] = data["server_conn"]["ip_address"]["address"] - return data - - -def convert_200_300(data): - data["version"] = (3, 0, 0) - data["client_conn"]["mitmcert"] = None - data["server_conn"]["tls_version"] = None - if data["server_conn"]["via"]: - data["server_conn"]["via"]["tls_version"] = None - return data - - -def convert_300_4(data): - data["version"] = 4 - return data - - -client_connections = {} # type: Mapping[str, str] -server_connections = {} # type: Mapping[str, str] - - -def convert_4_5(data): - data["version"] = 5 - client_conn_key = ( - data["client_conn"]["timestamp_start"], - *data["client_conn"]["address"] - ) - server_conn_key = ( - 
data["server_conn"]["timestamp_start"], - *data["server_conn"]["source_address"] - ) - data["client_conn"]["id"] = client_connections.setdefault(client_conn_key, str(uuid.uuid4())) - data["server_conn"]["id"] = server_connections.setdefault(server_conn_key, str(uuid.uuid4())) - return data - - -def _convert_dict_keys(o: Any) -> Any: - if isinstance(o, dict): - return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()} - else: - return o - - -def _convert_dict_vals(o: dict, values_to_convert: dict) -> dict: - for k, v in values_to_convert.items(): - if not o or k not in o: - continue - if v is True: - o[k] = strutils.always_str(o[k]) - else: - _convert_dict_vals(o[k], v) - return o - - -def convert_unicode(data: dict) -> dict: - """ - This method converts between Python 3 and Python 2 dumpfiles. - """ - data = _convert_dict_keys(data) - data = _convert_dict_vals( - data, { - "type": True, - "id": True, - "request": { - "first_line_format": True - }, - "error": { - "msg": True - } - } - ) - return data - - -converters = { - (0, 11): convert_011_012, - (0, 12): convert_012_013, - (0, 13): convert_013_014, - (0, 14): convert_014_015, - (0, 15): convert_015_016, - (0, 16): convert_016_017, - (0, 17): convert_017_018, - (0, 18): convert_018_019, - (0, 19): convert_019_100, - (1, 0): convert_100_200, - (2, 0): convert_200_300, - (3, 0): convert_300_4, - 4: convert_4_5, -} - - -def migrate_flow(flow_data: Dict[Union[bytes, str], Any]) -> Dict[Union[bytes, str], Any]: - while True: - flow_version = flow_data.get(b"version", flow_data.get("version")) - - # Historically, we used the mitmproxy minor version tuple as the flow format version. - if not isinstance(flow_version, int): - flow_version = tuple(flow_version)[:2] - - if flow_version == version.FLOW_FORMAT_VERSION: - break - elif flow_version in converters: - flow_data = converters[flow_version](flow_data) - else: - should_upgrade = ( - isinstance(flow_version, int) - and flow_version > version.FLOW_FORMAT_VERSION - ) - raise ValueError( - "{} cannot read files with flow format version {}{}.".format( - version.MITMPROXY, - flow_version, - ", please update mitmproxy" if should_upgrade else "" - ) - ) - return flow_data diff --git a/test/mitmproxy/contrib/test_tnetstring.py b/test/mitmproxy/contrib/test_tnetstring.py deleted file mode 100644 index 05c4a7c9..00000000 --- a/test/mitmproxy/contrib/test_tnetstring.py +++ /dev/null @@ -1,137 +0,0 @@ -import unittest -import random -import math -import io -import struct - -from mitmproxy.contrib import tnetstring - -MAXINT = 2 ** (struct.Struct('i').size * 8 - 1) - 1 - -FORMAT_EXAMPLES = { - b'0:}': {}, - b'0:]': [], - b'51:5:hello,39:11:12345678901#4:this,4:true!0:~4:\x00\x00\x00\x00,]}': - {b'hello': [12345678901, b'this', True, None, b'\x00\x00\x00\x00']}, - b'5:12345#': 12345, - b'12:this is cool,': b'this is cool', - b'19:this is unicode \xe2\x98\x85;': u'this is unicode \u2605', - b'0:,': b'', - b'0:;': u'', - b'0:~': None, - b'4:true!': True, - b'5:false!': False, - b'10:\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00,': b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', - b'24:5:12345#5:67890#5:xxxxx,]': [12345, 67890, b'xxxxx'], - b'18:3:0.1^3:0.2^3:0.3^]': [0.1, 0.2, 0.3], - b'243:238:233:228:223:218:213:208:203:198:193:188:183:178:173:168:163:158:153:148:143:138:133:128:123:118:113:108:103:99:95:91:87:83:79:75:71:67:63:59:55:51:47:43:39:35:31:27:23:19:15:11:hello-there,]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]': 
[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[b'hello-there']]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]] # noqa -} - - -def get_random_object(random=random, depth=0): - """Generate a random serializable object.""" - # The probability of generating a scalar value increases as the depth increase. - # This ensures that we bottom out eventually. - if random.randint(depth, 10) <= 4: - what = random.randint(0, 1) - if what == 0: - n = random.randint(0, 10) - l = [] - for _ in range(n): - l.append(get_random_object(random, depth + 1)) - return l - if what == 1: - n = random.randint(0, 10) - d = {} - for _ in range(n): - n = random.randint(0, 100) - k = str([random.randint(32, 126) for _ in range(n)]) - d[k] = get_random_object(random, depth + 1) - return d - else: - what = random.randint(0, 4) - if what == 0: - return None - if what == 1: - return True - if what == 2: - return False - if what == 3: - if random.randint(0, 1) == 0: - return random.randint(0, MAXINT) - else: - return -1 * random.randint(0, MAXINT) - n = random.randint(0, 100) - return bytes([random.randint(32, 126) for _ in range(n)]) - - -class Test_Format(unittest.TestCase): - - def test_roundtrip_format_examples(self): - for data, expect in FORMAT_EXAMPLES.items(): - self.assertEqual(expect, tnetstring.loads(data)) - self.assertEqual( - expect, tnetstring.loads(tnetstring.dumps(expect))) - self.assertEqual((expect, b''), tnetstring.pop(data)) - - def test_roundtrip_format_random(self): - for _ in range(500): - v = get_random_object() - self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v))) - self.assertEqual((v, b""), tnetstring.pop(tnetstring.dumps(v))) - - def test_roundtrip_format_unicode(self): - for _ in range(500): - v = get_random_object() - self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v))) - self.assertEqual((v, b''), tnetstring.pop(tnetstring.dumps(v))) - - def test_roundtrip_big_integer(self): - i1 = math.factorial(30000) - s = tnetstring.dumps(i1) - i2 = tnetstring.loads(s) - self.assertEqual(i1, i2) - - -class Test_FileLoading(unittest.TestCase): - - def test_roundtrip_file_examples(self): - for data, expect in FORMAT_EXAMPLES.items(): - s = io.BytesIO() - s.write(data) - s.write(b'OK') - s.seek(0) - self.assertEqual(expect, tnetstring.load(s)) - self.assertEqual(b'OK', s.read()) - s = io.BytesIO() - tnetstring.dump(expect, s) - s.write(b'OK') - s.seek(0) - self.assertEqual(expect, tnetstring.load(s)) - self.assertEqual(b'OK', s.read()) - - def test_roundtrip_file_random(self): - for _ in range(500): - v = get_random_object() - s = io.BytesIO() - tnetstring.dump(v, s) - s.write(b'OK') - s.seek(0) - self.assertEqual(v, tnetstring.load(s)) - self.assertEqual(b'OK', s.read()) - - def test_error_on_absurd_lengths(self): - s = io.BytesIO() - s.write(b'1000000000:pwned!,') - s.seek(0) - with self.assertRaises(ValueError): - tnetstring.load(s) - self.assertEqual(s.read(1), b':') - - -def suite(): - loader = unittest.TestLoader() - suite = unittest.TestSuite() - suite.addTest(loader.loadTestsFromTestCase(Test_Format)) - suite.addTest(loader.loadTestsFromTestCase(Test_FileLoading)) - return suite diff --git a/test/mitmproxy/io/test_compat.py b/test/mitmproxy/io/test_compat.py new file mode 100644 index 00000000..288de4fc --- /dev/null +++ b/test/mitmproxy/io/test_compat.py @@ -0,0 +1,28 @@ +import pytest + +from mitmproxy import io +from mitmproxy import exceptions +from mitmproxy.test import tutils + + +def test_load(): + with open(tutils.test_data.path("mitmproxy/data/dumpfile-011"), "rb") as f: + 
flow_reader = io.FlowReader(f) + flows = list(flow_reader.stream()) + assert len(flows) == 1 + assert flows[0].request.url == "https://example.com/" + + +def test_load_018(): + with open(tutils.test_data.path("mitmproxy/data/dumpfile-018"), "rb") as f: + flow_reader = io.FlowReader(f) + flows = list(flow_reader.stream()) + assert len(flows) == 1 + assert flows[0].request.url == "https://www.example.com/" + + +def test_cannot_convert(): + with open(tutils.test_data.path("mitmproxy/data/dumpfile-010"), "rb") as f: + flow_reader = io.FlowReader(f) + with pytest.raises(exceptions.FlowReadException): + list(flow_reader.stream()) diff --git a/test/mitmproxy/io/test_io.py b/test/mitmproxy/io/test_io.py new file mode 100644 index 00000000..777ab4dd --- /dev/null +++ b/test/mitmproxy/io/test_io.py @@ -0,0 +1 @@ +# TODO: write tests diff --git a/test/mitmproxy/io/test_tnetstring.py b/test/mitmproxy/io/test_tnetstring.py new file mode 100644 index 00000000..f7141de0 --- /dev/null +++ b/test/mitmproxy/io/test_tnetstring.py @@ -0,0 +1,137 @@ +import unittest +import random +import math +import io +import struct + +from mitmproxy.io import tnetstring + +MAXINT = 2 ** (struct.Struct('i').size * 8 - 1) - 1 + +FORMAT_EXAMPLES = { + b'0:}': {}, + b'0:]': [], + b'51:5:hello,39:11:12345678901#4:this,4:true!0:~4:\x00\x00\x00\x00,]}': + {b'hello': [12345678901, b'this', True, None, b'\x00\x00\x00\x00']}, + b'5:12345#': 12345, + b'12:this is cool,': b'this is cool', + b'19:this is unicode \xe2\x98\x85;': u'this is unicode \u2605', + b'0:,': b'', + b'0:;': u'', + b'0:~': None, + b'4:true!': True, + b'5:false!': False, + b'10:\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00,': b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', + b'24:5:12345#5:67890#5:xxxxx,]': [12345, 67890, b'xxxxx'], + b'18:3:0.1^3:0.2^3:0.3^]': [0.1, 0.2, 0.3], + b'243:238:233:228:223:218:213:208:203:198:193:188:183:178:173:168:163:158:153:148:143:138:133:128:123:118:113:108:103:99:95:91:87:83:79:75:71:67:63:59:55:51:47:43:39:35:31:27:23:19:15:11:hello-there,]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]': [[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[b'hello-there']]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]] # noqa +} + + +def get_random_object(random=random, depth=0): + """Generate a random serializable object.""" + # The probability of generating a scalar value increases as the depth increase. + # This ensures that we bottom out eventually. 
+ if random.randint(depth, 10) <= 4: + what = random.randint(0, 1) + if what == 0: + n = random.randint(0, 10) + l = [] + for _ in range(n): + l.append(get_random_object(random, depth + 1)) + return l + if what == 1: + n = random.randint(0, 10) + d = {} + for _ in range(n): + n = random.randint(0, 100) + k = str([random.randint(32, 126) for _ in range(n)]) + d[k] = get_random_object(random, depth + 1) + return d + else: + what = random.randint(0, 4) + if what == 0: + return None + if what == 1: + return True + if what == 2: + return False + if what == 3: + if random.randint(0, 1) == 0: + return random.randint(0, MAXINT) + else: + return -1 * random.randint(0, MAXINT) + n = random.randint(0, 100) + return bytes([random.randint(32, 126) for _ in range(n)]) + + +class Test_Format(unittest.TestCase): + + def test_roundtrip_format_examples(self): + for data, expect in FORMAT_EXAMPLES.items(): + self.assertEqual(expect, tnetstring.loads(data)) + self.assertEqual( + expect, tnetstring.loads(tnetstring.dumps(expect))) + self.assertEqual((expect, b''), tnetstring.pop(data)) + + def test_roundtrip_format_random(self): + for _ in range(500): + v = get_random_object() + self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v))) + self.assertEqual((v, b""), tnetstring.pop(tnetstring.dumps(v))) + + def test_roundtrip_format_unicode(self): + for _ in range(500): + v = get_random_object() + self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v))) + self.assertEqual((v, b''), tnetstring.pop(tnetstring.dumps(v))) + + def test_roundtrip_big_integer(self): + i1 = math.factorial(30000) + s = tnetstring.dumps(i1) + i2 = tnetstring.loads(s) + self.assertEqual(i1, i2) + + +class Test_FileLoading(unittest.TestCase): + + def test_roundtrip_file_examples(self): + for data, expect in FORMAT_EXAMPLES.items(): + s = io.BytesIO() + s.write(data) + s.write(b'OK') + s.seek(0) + self.assertEqual(expect, tnetstring.load(s)) + self.assertEqual(b'OK', s.read()) + s = io.BytesIO() + tnetstring.dump(expect, s) + s.write(b'OK') + s.seek(0) + self.assertEqual(expect, tnetstring.load(s)) + self.assertEqual(b'OK', s.read()) + + def test_roundtrip_file_random(self): + for _ in range(500): + v = get_random_object() + s = io.BytesIO() + tnetstring.dump(v, s) + s.write(b'OK') + s.seek(0) + self.assertEqual(v, tnetstring.load(s)) + self.assertEqual(b'OK', s.read()) + + def test_error_on_absurd_lengths(self): + s = io.BytesIO() + s.write(b'1000000000:pwned!,') + s.seek(0) + with self.assertRaises(ValueError): + tnetstring.load(s) + self.assertEqual(s.read(1), b':') + + +def suite(): + loader = unittest.TestLoader() + suite = unittest.TestSuite() + suite.addTest(loader.loadTestsFromTestCase(Test_Format)) + suite.addTest(loader.loadTestsFromTestCase(Test_FileLoading)) + return suite diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index 630fc7e4..78f893c0 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -6,7 +6,7 @@ import mitmproxy.io from mitmproxy import flowfilter from mitmproxy import options from mitmproxy.proxy import config -from mitmproxy.contrib import tnetstring +from mitmproxy.io import tnetstring from mitmproxy.exceptions import FlowReadException from mitmproxy import flow from mitmproxy import http diff --git a/test/mitmproxy/test_io.py b/test/mitmproxy/test_io.py deleted file mode 100644 index 777ab4dd..00000000 --- a/test/mitmproxy/test_io.py +++ /dev/null @@ -1 +0,0 @@ -# TODO: write tests diff --git a/test/mitmproxy/test_io_compat.py b/test/mitmproxy/test_io_compat.py 
deleted file mode 100644
index 288de4fc..00000000
--- a/test/mitmproxy/test_io_compat.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import pytest
-
-from mitmproxy import io
-from mitmproxy import exceptions
-from mitmproxy.test import tutils
-
-
-def test_load():
-    with open(tutils.test_data.path("mitmproxy/data/dumpfile-011"), "rb") as f:
-        flow_reader = io.FlowReader(f)
-        flows = list(flow_reader.stream())
-        assert len(flows) == 1
-        assert flows[0].request.url == "https://example.com/"
-
-
-def test_load_018():
-    with open(tutils.test_data.path("mitmproxy/data/dumpfile-018"), "rb") as f:
-        flow_reader = io.FlowReader(f)
-        flows = list(flow_reader.stream())
-        assert len(flows) == 1
-        assert flows[0].request.url == "https://www.example.com/"
-
-
-def test_cannot_convert():
-    with open(tutils.test_data.path("mitmproxy/data/dumpfile-010"), "rb") as f:
-        flow_reader = io.FlowReader(f)
-        with pytest.raises(exceptions.FlowReadException):
-            list(flow_reader.stream())
diff --git a/test/mitmproxy/test_websocket.py b/test/mitmproxy/test_websocket.py
index 62f69e2d..7c53a4b0 100644
--- a/test/mitmproxy/test_websocket.py
+++ b/test/mitmproxy/test_websocket.py
@@ -1,7 +1,7 @@
 import io
 import pytest
 
-from mitmproxy.contrib import tnetstring
+from mitmproxy.io import tnetstring
 from mitmproxy import flowfilter
 from mitmproxy.test import tflow
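A minimal round-trip sketch against the relocated serializer, using only functions defined in mitmproxy/io/tnetstring.py above (dumps, loads, pop, load); the expected byte string is derived from the length-prefix/type-tag rules in the module docstring, not copied from the test suite.

    import io

    from mitmproxy.io import tnetstring

    # Values are length-prefixed and type-tagged: '#' int, '!' bool,
    # ',' bytes, ']' list.
    blob = tnetstring.dumps([12345, True, b"xxxxx"])
    assert blob == b"23:5:12345#4:true!5:xxxxx,]"

    # loads() parses one complete value; pop() also returns any trailing bytes.
    assert tnetstring.loads(blob) == [12345, True, b"xxxxx"]
    assert tnetstring.pop(blob + b"extra") == ([12345, True, b"xxxxx"], b"extra")

    # load() consumes exactly one value from a file-like object,
    # leaving the rest unread.
    buf = io.BytesIO(blob + b"rest")
    assert tnetstring.load(buf) == [12345, True, b"xxxxx"]
    assert buf.read() == b"rest"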
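The converter table in mitmproxy/io/compat.py chains single-step upgrades until the data reaches the current flow format version. A stand-alone sketch of that pattern follows; the version numbers and converters here are toy stand-ins, not mitmproxy's real ones.

    CURRENT_VERSION = 3

    def v1_to_v2(d):
        # A real converter would also rename/add fields here.
        d["version"] = 2
        return d

    def v2_to_v3(d):
        d["version"] = 3
        return d

    converters = {1: v1_to_v2, 2: v2_to_v3}

    def migrate(d):
        # Apply one-step converters until the target version is reached.
        while d["version"] != CURRENT_VERSION:
            try:
                d = converters[d["version"]](d)
            except KeyError:
                raise ValueError("cannot migrate version {}".format(d["version"]))
        return d

    assert migrate({"version": 1})["version"] == 3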
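And a usage sketch for the new package entry point: FlowWriter, FlowReader, FilteredFlowWriter and read_flows_from_paths are re-exported from mitmproxy/io/__init__.py above, so callers keep importing them from mitmproxy.io. The dump file path below is hypothetical.

    from mitmproxy import exceptions
    from mitmproxy import io

    try:
        # "flows.dump" is a placeholder path; FlowReader.stream() yields
        # HTTPFlow/WebSocketFlow/TCPFlow objects, per FLOW_TYPES above.
        with open("flows.dump", "rb") as fp:
            for flow in io.FlowReader(fp).stream():
                print(flow)
    except exceptions.FlowReadException as e:
        print("flow file corrupted: {}".format(e))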