author    Aldo Cortesi <aldo@corte.si>  2017-03-20 12:50:09 +1300
committer Aldo Cortesi <aldo@corte.si>  2017-03-20 12:50:09 +1300
commit    4ca78604af2a8ddb596e2f4e95090dabc8495bfe (patch)
tree      fff817d49cd5f4d8a3989f64be94b13cac17fd67 /mitmproxy/io
parent    3a8da31835db37d65637058935f144ece62c1bdd (diff)
Factor out an io module
Include tnetstring - we've made enough changes that this no longer belongs in contrib.
Diffstat (limited to 'mitmproxy/io')
-rw-r--r--  mitmproxy/io/__init__.py     7
-rw-r--r--  mitmproxy/io/compat.py     214
-rw-r--r--  mitmproxy/io/io.py          87
-rw-r--r--  mitmproxy/io/tnetstring.py 250
4 files changed, 558 insertions, 0 deletions
diff --git a/mitmproxy/io/__init__.py b/mitmproxy/io/__init__.py
new file mode 100644
index 00000000..a82f729f
--- /dev/null
+++ b/mitmproxy/io/__init__.py
@@ -0,0 +1,7 @@
+
+from .io import FlowWriter, FlowReader, FilteredFlowWriter, read_flows_from_paths
+
+
+__all__ = [
+ "FlowWriter", "FlowReader", "FilteredFlowWriter", "read_flows_from_paths"
+]
\ No newline at end of file
diff --git a/mitmproxy/io/compat.py b/mitmproxy/io/compat.py
new file mode 100644
index 00000000..9d95f602
--- /dev/null
+++ b/mitmproxy/io/compat.py
@@ -0,0 +1,214 @@
+"""
+This module handles the import of mitmproxy flows generated by old versions.
+"""
+import uuid
+from typing import Any, Dict, Union # noqa
+
+from mitmproxy import version
+from mitmproxy.utils import strutils
+
+
+def convert_011_012(data):
+ data[b"version"] = (0, 12)
+ return data
+
+
+def convert_012_013(data):
+ data[b"version"] = (0, 13)
+ return data
+
+
+def convert_013_014(data):
+ data[b"request"][b"first_line_format"] = data[b"request"].pop(b"form_in")
+ data[b"request"][b"http_version"] = b"HTTP/" + ".".join(
+ str(x) for x in data[b"request"].pop(b"httpversion")).encode()
+ data[b"response"][b"http_version"] = b"HTTP/" + ".".join(
+ str(x) for x in data[b"response"].pop(b"httpversion")).encode()
+ data[b"response"][b"status_code"] = data[b"response"].pop(b"code")
+ data[b"response"][b"body"] = data[b"response"].pop(b"content")
+ data[b"server_conn"].pop(b"state")
+ data[b"server_conn"][b"via"] = None
+ data[b"version"] = (0, 14)
+ return data
+
+
+def convert_014_015(data):
+ data[b"version"] = (0, 15)
+ return data
+
+
+def convert_015_016(data):
+ for m in (b"request", b"response"):
+ if b"body" in data[m]:
+ data[m][b"content"] = data[m].pop(b"body")
+ if b"msg" in data[b"response"]:
+ data[b"response"][b"reason"] = data[b"response"].pop(b"msg")
+ data[b"request"].pop(b"form_out", None)
+ data[b"version"] = (0, 16)
+ return data
+
+
+def convert_016_017(data):
+ data[b"server_conn"][b"peer_address"] = None
+ data[b"version"] = (0, 17)
+ return data
+
+
+def convert_017_018(data):
+ # convert_unicode needs to be called for every dual release and the first py3-only release
+ data = convert_unicode(data)
+
+ data["server_conn"]["ip_address"] = data["server_conn"].pop("peer_address")
+ data["marked"] = False
+ data["version"] = (0, 18)
+ return data
+
+
+def convert_018_019(data):
+ # convert_unicode needs to be called for every dual release and the first py3-only release
+ data = convert_unicode(data)
+
+ data["request"].pop("stickyauth", None)
+ data["request"].pop("stickycookie", None)
+ data["client_conn"]["sni"] = None
+ data["client_conn"]["alpn_proto_negotiated"] = None
+ data["client_conn"]["cipher_name"] = None
+ data["client_conn"]["tls_version"] = None
+ data["server_conn"]["alpn_proto_negotiated"] = None
+ data["mode"] = "regular"
+ data["metadata"] = dict()
+ data["version"] = (0, 19)
+ return data
+
+
+def convert_019_100(data):
+ # convert_unicode needs to be called for every dual release and the first py3-only release
+ data = convert_unicode(data)
+
+ data["version"] = (1, 0, 0)
+ return data
+
+
+def convert_100_200(data):
+ data["version"] = (2, 0, 0)
+ data["client_conn"]["address"] = data["client_conn"]["address"]["address"]
+ data["server_conn"]["address"] = data["server_conn"]["address"]["address"]
+ data["server_conn"]["source_address"] = data["server_conn"]["source_address"]["address"]
+ if data["server_conn"]["ip_address"]:
+ data["server_conn"]["ip_address"] = data["server_conn"]["ip_address"]["address"]
+ return data
+
+
+def convert_200_300(data):
+ data["version"] = (3, 0, 0)
+ data["client_conn"]["mitmcert"] = None
+ data["server_conn"]["tls_version"] = None
+ if data["server_conn"]["via"]:
+ data["server_conn"]["via"]["tls_version"] = None
+ return data
+
+
+def convert_300_4(data):
+ data["version"] = 4
+ return data
+
+
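+# Flow dumps older than format version 5 carry no connection IDs, so
+# convert_4_5 synthesizes stable ones: connections that share a timestamp
+# and address are assigned the same freshly generated UUID.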
+client_connections = {} # type: Dict[tuple, str]
+server_connections = {} # type: Dict[tuple, str]
+
+
+def convert_4_5(data):
+ data["version"] = 5
+ client_conn_key = (
+ data["client_conn"]["timestamp_start"],
+ *data["client_conn"]["address"]
+ )
+ server_conn_key = (
+ data["server_conn"]["timestamp_start"],
+ *data["server_conn"]["source_address"]
+ )
+ data["client_conn"]["id"] = client_connections.setdefault(client_conn_key, str(uuid.uuid4()))
+ data["server_conn"]["id"] = server_connections.setdefault(server_conn_key, str(uuid.uuid4()))
+ return data
+
+
+def _convert_dict_keys(o: Any) -> Any:
+ if isinstance(o, dict):
+ return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()}
+ else:
+ return o
+
+
+def _convert_dict_vals(o: dict, values_to_convert: dict) -> dict:
+ for k, v in values_to_convert.items():
+ if not o or k not in o:
+ continue
+ if v is True:
+ o[k] = strutils.always_str(o[k])
+ else:
+ _convert_dict_vals(o[k], v)
+ return o
+
+
+def convert_unicode(data: dict) -> dict:
+ """
+    Convert the keys (and a few known values) of a Python 2 dumpfile from
+    bytes to str, so that dumpfiles from either Python version load on Python 3.
+ """
+ data = _convert_dict_keys(data)
+ data = _convert_dict_vals(
+ data, {
+ "type": True,
+ "id": True,
+ "request": {
+ "first_line_format": True
+ },
+ "error": {
+ "msg": True
+ }
+ }
+ )
+ return data
+
+
+converters = {
+ (0, 11): convert_011_012,
+ (0, 12): convert_012_013,
+ (0, 13): convert_013_014,
+ (0, 14): convert_014_015,
+ (0, 15): convert_015_016,
+ (0, 16): convert_016_017,
+ (0, 17): convert_017_018,
+ (0, 18): convert_018_019,
+ (0, 19): convert_019_100,
+ (1, 0): convert_100_200,
+ (2, 0): convert_200_300,
+ (3, 0): convert_300_4,
+ 4: convert_4_5,
+}
+
+
+def migrate_flow(flow_data: Dict[Union[bytes, str], Any]) -> Dict[Union[bytes, str], Any]:
+ while True:
+ flow_version = flow_data.get(b"version", flow_data.get("version"))
+
+ # Historically, we used the mitmproxy minor version tuple as the flow format version.
+ if not isinstance(flow_version, int):
+ flow_version = tuple(flow_version)[:2]
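+            # e.g. a stored [1, 0, 2] normalizes to (1, 0); from flow format
+            # version 4 onwards the version is a plain int and is kept as-is.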
+
+ if flow_version == version.FLOW_FORMAT_VERSION:
+ break
+ elif flow_version in converters:
+ flow_data = converters[flow_version](flow_data)
+ else:
+ should_upgrade = (
+ isinstance(flow_version, int)
+ and flow_version > version.FLOW_FORMAT_VERSION
+ )
+ raise ValueError(
+ "{} cannot read files with flow format version {}{}.".format(
+ version.MITMPROXY,
+ flow_version,
+ ", please update mitmproxy" if should_upgrade else ""
+ )
+ )
+ return flow_data
diff --git a/mitmproxy/io/io.py b/mitmproxy/io/io.py
new file mode 100644
index 00000000..50e26f49
--- /dev/null
+++ b/mitmproxy/io/io.py
@@ -0,0 +1,87 @@
+import os
+from typing import Type, Iterable, Dict, Union, Any, cast # noqa
+
+from mitmproxy import exceptions
+from mitmproxy import flow
+from mitmproxy import flowfilter
+from mitmproxy import http
+from mitmproxy import tcp
+from mitmproxy import websocket
+
+from mitmproxy.io import compat
+from mitmproxy.io import tnetstring
+
+FLOW_TYPES = dict(
+ http=http.HTTPFlow,
+ websocket=websocket.WebSocketFlow,
+ tcp=tcp.TCPFlow,
+) # type: Dict[str, Type[flow.Flow]]
+
+
+class FlowWriter:
+ def __init__(self, fo):
+ self.fo = fo
+
+ def add(self, flow):
+ d = flow.get_state()
+ tnetstring.dump(d, self.fo)
+
+
+class FlowReader:
+ def __init__(self, fo):
+ self.fo = fo
+
+ def stream(self) -> Iterable[flow.Flow]:
+ """
+ Yields Flow objects from the dump.
+ """
+ try:
+ while True:
+ # FIXME: This cast hides a lack of dynamic type checking
+ loaded = cast(
+ Dict[Union[bytes, str], Any],
+ tnetstring.load(self.fo),
+ )
+ try:
+ mdata = compat.migrate_flow(loaded)
+ except ValueError as e:
+ raise exceptions.FlowReadException(str(e))
+ if mdata["type"] not in FLOW_TYPES:
+ raise exceptions.FlowReadException("Unknown flow type: {}".format(mdata["type"]))
+ yield FLOW_TYPES[mdata["type"]].from_state(mdata)
+ except ValueError as e:
+ if str(e) == "not a tnetstring: empty file":
+ return # Error is due to EOF
+ raise exceptions.FlowReadException("Invalid data format.")
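+
+
+# A round-trip sketch (the file name and some_flow are illustrative):
+#
+#     with open("flows.mitm", "wb") as fo:
+#         FlowWriter(fo).add(some_flow)
+#     with open("flows.mitm", "rb") as fi:
+#         flows = list(FlowReader(fi).stream())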
+
+
+class FilteredFlowWriter:
+ def __init__(self, fo, flt):
+ self.fo = fo
+ self.flt = flt
+
+ def add(self, f: flow.Flow):
+ if self.flt and not flowfilter.match(self.flt, f):
+ return
+ d = f.get_state()
+ tnetstring.dump(d, self.fo)
+
+
+def read_flows_from_paths(paths):
+ """
+ Given a list of filepaths, read all flows and return a list of them.
+ From a performance perspective, streaming would be advisable -
+ however, if there's an error with one of the files, we want it to be raised immediately.
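+
+    Example (the paths are illustrative):
+
+        flows = read_flows_from_paths(["~/one.mitm", "~/two.mitm"])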
+
+ Raises:
+ FlowReadException, if any error occurs.
+ """
+ try:
+ flows = []
+ for path in paths:
+ path = os.path.expanduser(path)
+ with open(path, "rb") as f:
+ flows.extend(FlowReader(f).stream())
+ except IOError as e:
+ raise exceptions.FlowReadException(e.strerror)
+ return flows
diff --git a/mitmproxy/io/tnetstring.py b/mitmproxy/io/tnetstring.py
new file mode 100644
index 00000000..24ce6ce8
--- /dev/null
+++ b/mitmproxy/io/tnetstring.py
@@ -0,0 +1,250 @@
+"""
+tnetstring: data serialization using typed netstrings
+======================================================
+
+This is a custom Python 3 implementation of tnetstrings.
+The main difference from other implementations is the support for a custom
+unicode datatype.
+
+An ordinary tnetstring is a blob of data prefixed with its length and postfixed
+with its type. Here are some examples:
+
+ >>> tnetstring.dumps("hello world")
+ 11:hello world,
+ >>> tnetstring.dumps(12345)
+ 5:12345#
+ >>> tnetstring.dumps([12345, True, 0])
+ 19:5:12345#4:true!1:0#]
+
+This module gives you the following functions:
+
+ :dump: dump an object as a tnetstring to a file
+ :dumps: dump an object as a tnetstring to a string
+ :load: load a tnetstring-encoded object from a file
+ :loads: load a tnetstring-encoded object from a string
+
+Note that since parsing a tnetstring requires reading all the data into memory
+at once, there's no efficiency gain from using the file-based versions of these
+functions. They're only here so you can use load() to read precisely one
+item from a file or socket without consuming any extra data.
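+
+For example, loads() parses exactly one value and pop() also returns any
+trailing data:
+
+    >>> tnetstring.loads(b'5:12345#')
+    12345
+    >>> tnetstring.pop(b'5:12345#4:true!')
+    (12345, b'4:true!')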
+
+The tnetstrings specification explicitly states that strings are binary blobs
+and forbids the use of unicode at the protocol level.
+**This implementation adds a non-standard unicode string type (';' tag),
+encoded as UTF-8**; standard strings (',' tag) are returned as plain bytes.
+
+:Copyright: (c) 2012-2013 by Ryan Kelly <ryan@rfk.id.au>.
+:Copyright: (c) 2014 by Carlo Pires <carlopires@gmail.com>.
+:Copyright: (c) 2016 by Maximilian Hils <tnetstring3@maximilianhils.com>.
+
+:License: MIT
+"""
+
+import collections
+from typing import BinaryIO, Union, Tuple
+
+TSerializable = Union[None, bool, int, float, bytes, list, tuple, dict]
+
+
+def dumps(value: TSerializable) -> bytes:
+ """
+ This function dumps a python object as a tnetstring.
+ """
+ # This uses a deque to collect output fragments in reverse order,
+ # then joins them together at the end. It's measurably faster
+ # than creating all the intermediate strings.
+ q = collections.deque()
+ _rdumpq(q, 0, value)
+ return b''.join(q)
+
+
+def dump(value: TSerializable, file_handle: BinaryIO) -> None:
+ """
+ This function dumps a python object as a tnetstring and
+ writes it to the given file.
+ """
+ file_handle.write(dumps(value))
+
+
+def _rdumpq(q: collections.deque, size: int, value: TSerializable) -> int:
+ """
+ Dump value as a tnetstring, to a deque instance, last chunks first.
+
+ This function generates the tnetstring representation of the given value,
+ pushing chunks of the output onto the given deque instance. It pushes
+ the last chunk first, then recursively generates more chunks.
+
+ When passed in the current size of the string in the queue, it will return
+ the new size of the string in the queue.
+
+ Operating last-chunk-first makes it easy to calculate the size written
+ for recursive structures without having to build their representation as
+ a string. This is measurably faster than generating the intermediate
+ strings, especially on deeply nested structures.
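+
+    For example, dumping [1] pushes ']' first, then '1:1#', then ':' and '4';
+    the queue then joins up to b'4:1:1#]'.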
+ """
+ write = q.appendleft
+ if value is None:
+ write(b'0:~')
+ return size + 3
+ elif value is True:
+ write(b'4:true!')
+ return size + 7
+ elif value is False:
+ write(b'5:false!')
+ return size + 8
+ elif isinstance(value, int):
+ data = str(value).encode()
+ ldata = len(data)
+ span = str(ldata).encode()
+ write(b'%s:%s#' % (span, data))
+ return size + 2 + len(span) + ldata
+ elif isinstance(value, float):
+ # Use repr() for float rather than str().
+ # It round-trips more accurately.
+ # Probably unnecessary in later python versions that
+ # use David Gay's ftoa routines.
+ data = repr(value).encode()
+ ldata = len(data)
+ span = str(ldata).encode()
+ write(b'%s:%s^' % (span, data))
+ return size + 2 + len(span) + ldata
+ elif isinstance(value, bytes):
+ data = value
+ ldata = len(data)
+ span = str(ldata).encode()
+ write(b',')
+ write(data)
+ write(b':')
+ write(span)
+ return size + 2 + len(span) + ldata
+ elif isinstance(value, str):
+ data = value.encode("utf8")
+ ldata = len(data)
+ span = str(ldata).encode()
+ write(b';')
+ write(data)
+ write(b':')
+ write(span)
+ return size + 2 + len(span) + ldata
+ elif isinstance(value, (list, tuple)):
+ write(b']')
+ init_size = size = size + 1
+ for item in reversed(value):
+ size = _rdumpq(q, size, item)
+ span = str(size - init_size).encode()
+ write(b':')
+ write(span)
+ return size + 1 + len(span)
+ elif isinstance(value, dict):
+ write(b'}')
+ init_size = size = size + 1
+ for (k, v) in value.items():
+ size = _rdumpq(q, size, v)
+ size = _rdumpq(q, size, k)
+ span = str(size - init_size).encode()
+ write(b':')
+ write(span)
+ return size + 1 + len(span)
+ else:
+ raise ValueError("unserializable object: {} ({})".format(value, type(value)))
+
+
+def loads(string: bytes) -> TSerializable:
+ """
+ This function parses a tnetstring into a python object.
+ """
+ return pop(string)[0]
+
+
+def load(file_handle: BinaryIO) -> TSerializable:
+ """load(file) -> object
+
+ This function reads a tnetstring from a file and parses it into a
+ python object. The file must support the read() method, and this
+ function promises not to read more data than necessary.
+ """
+ # Read the length prefix one char at a time.
+ # Note that the netstring spec explicitly forbids padding zeros.
+ c = file_handle.read(1)
+ if c == b"": # we want to detect this special case.
+ raise ValueError("not a tnetstring: empty file")
+ data_length = b""
+ while c.isdigit():
+ data_length += c
+ if len(data_length) > 9:
+ raise ValueError("not a tnetstring: absurdly large length prefix")
+ c = file_handle.read(1)
+ if c != b":":
+ raise ValueError("not a tnetstring: missing or invalid length prefix")
+
+ data = file_handle.read(int(data_length))
+ data_type = file_handle.read(1)[0]
+
+ return parse(data_type, data)
+
+
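+# Trailing type tags: ',' bytes, ';' unicode string, '#' integer, '^' float,
+# '!' boolean, '~' null (None), ']' list, '}' dict.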
+def parse(data_type: int, data: bytes) -> TSerializable:
+ if data_type == ord(b','):
+ return data
+ if data_type == ord(b';'):
+ return data.decode("utf8")
+ if data_type == ord(b'#'):
+ try:
+ return int(data)
+ except ValueError:
+ raise ValueError("not a tnetstring: invalid integer literal: {}".format(data))
+ if data_type == ord(b'^'):
+ try:
+ return float(data)
+ except ValueError:
+ raise ValueError("not a tnetstring: invalid float literal: {}".format(data))
+ if data_type == ord(b'!'):
+ if data == b'true':
+ return True
+ elif data == b'false':
+ return False
+ else:
+ raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data))
+ if data_type == ord(b'~'):
+ if data:
+ raise ValueError("not a tnetstring: invalid null literal")
+ return None
+ if data_type == ord(b']'):
+        result = []
+        while data:
+            item, data = pop(data)
+            result.append(item)
+        return result
+ if data_type == ord(b'}'):
+ d = {}
+ while data:
+ key, data = pop(data)
+ val, data = pop(data)
+ d[key] = val
+ return d
+ raise ValueError("unknown type tag: {}".format(data_type))
+
+
+def pop(data: bytes) -> Tuple[TSerializable, bytes]:
+ """
+ This function parses a tnetstring into a python object.
+ It returns a tuple giving the parsed object and a string
+ containing any unparsed data from the end of the string.
+ """
+ # Parse out data length, type and remaining string.
+ try:
+ length, data = data.split(b':', 1)
+ length = int(length)
+ except ValueError:
+ raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(data))
+ try:
+ data, data_type, remain = data[:length], data[length], data[length + 1:]
+ except IndexError:
+        # This fires when len(data) <= length, meaning we don't need
+ # to further validate that data is the right length.
+ raise ValueError("not a tnetstring: invalid length prefix: {}".format(length))
+ # Parse the data based on the type tag.
+ return parse(data_type, data), remain
+
+
+__all__ = ["dump", "dumps", "load", "loads", "pop"]