From 3a8da31835db37d65637058935f144ece62c1bdd Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 20 Mar 2017 12:19:22 +1300 Subject: mypy all of the codebase bar tnetstring In some places, this involved removing type declarations where our types were terminally confused. The grideditor specifically needs a cleanup and restructure. --- mitmproxy/addons/view.py | 21 ++++++++++++--------- mitmproxy/flowfilter.py | 25 +++++++++++++------------ mitmproxy/io.py | 18 +++++++++++------- mitmproxy/io_compat.py | 8 ++++---- mitmproxy/tools/console/flowlist.py | 5 ++++- mitmproxy/tools/console/flowview.py | 11 ++++++++--- mitmproxy/tools/console/grideditor/base.py | 18 ++++++++++-------- mitmproxy/tools/console/grideditor/col_bytes.py | 7 ++++--- mitmproxy/tools/console/grideditor/col_text.py | 7 +++---- mitmproxy/tools/console/grideditor/editors.py | 2 +- mitmproxy/tools/console/master.py | 1 - mitmproxy/tools/console/palettes.py | 3 ++- mitmproxy/tools/console/statusbar.py | 5 ++++- mitmproxy/tools/web/app.py | 1 + 14 files changed, 77 insertions(+), 55 deletions(-) diff --git a/mitmproxy/addons/view.py b/mitmproxy/addons/view.py index 1b8a30e4..7e9d66a1 100644 --- a/mitmproxy/addons/view.py +++ b/mitmproxy/addons/view.py @@ -18,6 +18,7 @@ import sortedcontainers import mitmproxy.flow from mitmproxy import flowfilter from mitmproxy import exceptions +from mitmproxy import http # noqa # The underlying sorted list implementation expects the sort key to be stable # for the lifetime of the object. However, if we sort by size, for instance, @@ -34,7 +35,7 @@ class _OrderKey: def __init__(self, view): self.view = view - def generate(self, f: mitmproxy.flow.Flow) -> typing.Any: # pragma: no cover + def generate(self, f: http.HTTPFlow) -> typing.Any: # pragma: no cover pass def refresh(self, f): @@ -64,22 +65,22 @@ class _OrderKey: class OrderRequestStart(_OrderKey): - def generate(self, f: mitmproxy.flow.Flow) -> datetime.datetime: + def generate(self, f: http.HTTPFlow) -> datetime.datetime: return f.request.timestamp_start or 0 class OrderRequestMethod(_OrderKey): - def generate(self, f: mitmproxy.flow.Flow) -> str: + def generate(self, f: http.HTTPFlow) -> str: return f.request.method class OrderRequestURL(_OrderKey): - def generate(self, f: mitmproxy.flow.Flow) -> str: + def generate(self, f: http.HTTPFlow) -> str: return f.request.url class OrderKeySize(_OrderKey): - def generate(self, f: mitmproxy.flow.Flow) -> int: + def generate(self, f: http.HTTPFlow) -> int: s = 0 if f.request.raw_content: s += len(f.request.raw_content) @@ -118,7 +119,9 @@ class View(collections.Sequence): self.order_reversed = False self.focus_follow = False - self._view = sortedcontainers.SortedListWithKey(key = self.order_key) + self._view = sortedcontainers.SortedListWithKey( + key = self.order_key + ) # The sig_view* signals broadcast events that affect the view. That is, # an update to a flow in the store but not in the view does not trigger @@ -165,7 +168,7 @@ class View(collections.Sequence): def __len__(self): return len(self._view) - def __getitem__(self, offset) -> mitmproxy.flow.Flow: + def __getitem__(self, offset) -> typing.Any: return self._view[self._rev(offset)] # Reflect some methods to the efficient underlying implementation @@ -177,7 +180,7 @@ class View(collections.Sequence): def index(self, f: mitmproxy.flow.Flow, start: int = 0, stop: typing.Optional[int] = None) -> int: return self._rev(self._view.index(f, start, stop)) - def __contains__(self, f: mitmproxy.flow.Flow) -> bool: + def __contains__(self, f: typing.Any) -> bool: return self._view.__contains__(f) def _order_key_name(self): @@ -402,7 +405,7 @@ class Focus: class Settings(collections.Mapping): def __init__(self, view: View) -> None: self.view = view - self._values = {} # type: typing.MutableMapping[str, mitmproxy.flow.Flow] + self._values = {} # type: typing.MutableMapping[str, typing.Dict] view.sig_store_remove.connect(self._sig_store_remove) view.sig_store_refresh.connect(self._sig_store_refresh) diff --git a/mitmproxy/flowfilter.py b/mitmproxy/flowfilter.py index 2c7fc52f..83c98bad 100644 --- a/mitmproxy/flowfilter.py +++ b/mitmproxy/flowfilter.py @@ -44,7 +44,7 @@ from mitmproxy import flow from mitmproxy.utils import strutils import pyparsing as pp -from typing import Callable +from typing import Callable, Sequence, Type # noqa def only(*types): @@ -69,6 +69,8 @@ class _Token: class _Action(_Token): + code = None # type: str + help = None # type: str @classmethod def make(klass, s, loc, toks): @@ -162,15 +164,14 @@ def _check_content_type(rex, message): class FAsset(_Action): code = "a" help = "Match asset in response: CSS, Javascript, Flash, images." - ASSET_TYPES = [ + ASSET_TYPES = [re.compile(x) for x in [ b"text/javascript", b"application/x-javascript", b"application/javascript", b"text/css", b"image/.*", b"application/x-shockwave-flash" - ] - ASSET_TYPES = [re.compile(x) for x in ASSET_TYPES] + ]] @only(http.HTTPFlow) def __call__(self, f): @@ -436,7 +437,7 @@ filter_unary = [ FResp, FTCP, FWebSocket, -] +] # type: Sequence[Type[_Action]] filter_rex = [ FBod, FBodRequest, @@ -452,7 +453,7 @@ filter_rex = [ FMethod, FSrc, FUrl, -] +] # type: Sequence[Type[_Rex]] filter_int = [ FCode ] @@ -538,17 +539,17 @@ def match(flt, flow): help = [] -for i in filter_unary: +for a in filter_unary: help.append( - ("~%s" % i.code, i.help) + ("~%s" % a.code, a.help) ) -for i in filter_rex: +for b in filter_rex: help.append( - ("~%s regex" % i.code, i.help) + ("~%s regex" % b.code, b.help) ) -for i in filter_int: +for c in filter_int: help.append( - ("~%s int" % i.code, i.help) + ("~%s int" % c.code, c.help) ) help.sort() help.extend( diff --git a/mitmproxy/io.py b/mitmproxy/io.py index 780955a4..0f6c3f5c 100644 --- a/mitmproxy/io.py +++ b/mitmproxy/io.py @@ -1,5 +1,5 @@ import os -from typing import Iterable +from typing import Type, Iterable, Dict, Union, Any, cast # noqa from mitmproxy import exceptions from mitmproxy import flow @@ -15,7 +15,7 @@ FLOW_TYPES = dict( http=http.HTTPFlow, websocket=websocket.WebSocketFlow, tcp=tcp.TCPFlow, -) +) # type: Dict[str, Type[flow.Flow]] class FlowWriter: @@ -37,14 +37,18 @@ class FlowReader: """ try: while True: - data = tnetstring.load(self.fo) + # FIXME: This cast hides a lack of dynamic type checking + loaded = cast( + Dict[Union[bytes, str], Any], + tnetstring.load(self.fo), + ) try: - data = io_compat.migrate_flow(data) + mdata = io_compat.migrate_flow(loaded) except ValueError as e: raise exceptions.FlowReadException(str(e)) - if data["type"] not in FLOW_TYPES: - raise exceptions.FlowReadException("Unknown flow type: {}".format(data["type"])) - yield FLOW_TYPES[data["type"]].from_state(data) + if mdata["type"] not in FLOW_TYPES: + raise exceptions.FlowReadException("Unknown flow type: {}".format(mdata["type"])) + yield FLOW_TYPES[mdata["type"]].from_state(mdata) except ValueError as e: if str(e) == "not a tnetstring: empty file": return # Error is due to EOF diff --git a/mitmproxy/io_compat.py b/mitmproxy/io_compat.py index 7d839ffd..9d95f602 100644 --- a/mitmproxy/io_compat.py +++ b/mitmproxy/io_compat.py @@ -2,7 +2,7 @@ This module handles the import of mitmproxy flows generated by old versions. """ import uuid -from typing import Any, Dict +from typing import Any, Dict, Mapping, Union # noqa from mitmproxy import version from mitmproxy.utils import strutils @@ -113,8 +113,8 @@ def convert_300_4(data): return data -client_connections = {} -server_connections = {} +client_connections = {} # type: Mapping[str, str] +server_connections = {} # type: Mapping[str, str] def convert_4_5(data): @@ -187,7 +187,7 @@ converters = { } -def migrate_flow(flow_data: Dict[str, Any]) -> Dict[str, Any]: +def migrate_flow(flow_data: Dict[Union[bytes, str], Any]) -> Dict[Union[bytes, str], Any]: while True: flow_version = flow_data.get(b"version", flow_data.get("version")) diff --git a/mitmproxy/tools/console/flowlist.py b/mitmproxy/tools/console/flowlist.py index 04052ec8..45377b2c 100644 --- a/mitmproxy/tools/console/flowlist.py +++ b/mitmproxy/tools/console/flowlist.py @@ -5,6 +5,7 @@ from mitmproxy.tools.console import common from mitmproxy.tools.console import signals from mitmproxy.addons import view from mitmproxy import export +import mitmproxy.tools.console.master # noqa def _mkhelp(): @@ -305,7 +306,9 @@ class FlowListWalker(urwid.ListWalker): class FlowListBox(urwid.ListBox): - def __init__(self, master: "mitmproxy.tools.console.master.ConsoleMaster"): + def __init__( + self, master: "mitmproxy.tools.console.master.ConsoleMaster" + ) -> None: self.master = master # type: "mitmproxy.tools.console.master.ConsoleMaster" super().__init__(FlowListWalker(master)) diff --git a/mitmproxy/tools/console/flowview.py b/mitmproxy/tools/console/flowview.py index ba41c947..33c8f2ac 100644 --- a/mitmproxy/tools/console/flowview.py +++ b/mitmproxy/tools/console/flowview.py @@ -19,6 +19,7 @@ from mitmproxy.tools.console import overlay from mitmproxy.tools.console import searchable from mitmproxy.tools.console import signals from mitmproxy.tools.console import tabs +import mitmproxy.tools.console.master # noqa class SearchError(Exception): @@ -103,7 +104,11 @@ footer = [ class FlowViewHeader(urwid.WidgetWrap): - def __init__(self, master: "mitmproxy.console.master.ConsoleMaster", f: http.HTTPFlow): + def __init__( + self, + master: "mitmproxy.tools.console.master.ConsoleMaster", + f: http.HTTPFlow + ) -> None: self.master = master self.flow = f self._w = common.format_flow( @@ -651,8 +656,8 @@ class FlowView(tabs.Tabs): ) elif key == "z": self.flow.backup() - e = conn.headers.get("content-encoding", "identity") - if e != "identity": + enc = conn.headers.get("content-encoding", "identity") + if enc != "identity": try: conn.decode() except ValueError: diff --git a/mitmproxy/tools/console/grideditor/base.py b/mitmproxy/tools/console/grideditor/base.py index d2ba47c3..151479a4 100644 --- a/mitmproxy/tools/console/grideditor/base.py +++ b/mitmproxy/tools/console/grideditor/base.py @@ -7,10 +7,12 @@ from typing import Iterable from typing import Optional from typing import Sequence from typing import Tuple +from typing import Set # noqa import urwid from mitmproxy.tools.console import common from mitmproxy.tools.console import signals +import mitmproxy.tools.console.master # noqa FOOTER = [ ('heading_key', "enter"), ":edit ", @@ -34,7 +36,7 @@ class Cell(urwid.WidgetWrap): class Column(metaclass=abc.ABCMeta): - subeditor = None + subeditor = None # type: urwid.Edit def __init__(self, heading): self.heading = heading @@ -62,13 +64,13 @@ class GridRow(urwid.WidgetWrap): editing: bool, editor: "GridEditor", values: Tuple[Iterable[bytes], Container[int]] - ): + ) -> None: self.focused = focused self.editor = editor self.edit_col = None # type: Optional[Cell] errors = values[1] - self.fields = [] + self.fields = [] # type: Sequence[Any] for i, v in enumerate(values[0]): if focused == i and editing: self.edit_col = self.editor.columns[i].Edit(v) @@ -116,8 +118,8 @@ class GridWalker(urwid.ListWalker): self, lst: Iterable[list], editor: "GridEditor" - ): - self.lst = [(i, set()) for i in lst] + ) -> None: + self.lst = [(i, set()) for i in lst] # type: Sequence[Tuple[Any, Set]] self.editor = editor self.focus = 0 self.focus_col = 0 @@ -256,12 +258,12 @@ class GridEditor(urwid.WidgetWrap): def __init__( self, - master: "mitmproxy.console.master.ConsoleMaster", + master: "mitmproxy.tools.console.master.ConsoleMaster", value: Any, callback: Callable[..., None], *cb_args, **cb_kwargs - ): + ) -> None: value = self.data_in(copy.deepcopy(value)) self.master = master self.value = value @@ -380,7 +382,7 @@ class GridEditor(urwid.WidgetWrap): """ Return None, or a string error message. """ - return False + return None def handle_key(self, key): return False diff --git a/mitmproxy/tools/console/grideditor/col_bytes.py b/mitmproxy/tools/console/grideditor/col_bytes.py index f580e947..e4a53453 100644 --- a/mitmproxy/tools/console/grideditor/col_bytes.py +++ b/mitmproxy/tools/console/grideditor/col_bytes.py @@ -9,7 +9,7 @@ from mitmproxy.utils import strutils def read_file(filename: str, callback: Callable[..., None], escaped: bool) -> Optional[str]: if not filename: - return + return None filename = os.path.expanduser(filename) try: @@ -26,6 +26,7 @@ def read_file(filename: str, callback: Callable[..., None], escaped: bool) -> Op # TODO: Refactor the status_prompt_path signal so that we # can raise exceptions here and return the content instead. callback(d) + return None class Column(base.Column): @@ -68,7 +69,7 @@ class Column(base.Column): class Display(base.Cell): - def __init__(self, data: bytes): + def __init__(self, data: bytes) -> None: self.data = data escaped = strutils.bytes_to_escaped_str(data) w = urwid.Text(escaped, wrap="any") @@ -79,7 +80,7 @@ class Display(base.Cell): class Edit(base.Cell): - def __init__(self, data: bytes): + def __init__(self, data: bytes) -> None: data = strutils.bytes_to_escaped_str(data) w = urwid.Edit(edit_text=data, wrap="any", multiline=True) w = urwid.AttrWrap(w, "editfield") diff --git a/mitmproxy/tools/console/grideditor/col_text.py b/mitmproxy/tools/console/grideditor/col_text.py index 430ad037..f0ac06f8 100644 --- a/mitmproxy/tools/console/grideditor/col_text.py +++ b/mitmproxy/tools/console/grideditor/col_text.py @@ -26,12 +26,11 @@ class Column(col_bytes.Column): # This is the same for both edit and display. class EncodingMixin: - def __init__(self, data: str, encoding_args) -> "TDisplay": + def __init__(self, data, encoding_args): self.encoding_args = encoding_args - data = data.encode(*self.encoding_args) - super().__init__(data) + super().__init__(data.encode(*self.encoding_args)) - def get_data(self) -> str: + def get_data(self): data = super().get_data() try: return data.decode(*self.encoding_args) diff --git a/mitmproxy/tools/console/grideditor/editors.py b/mitmproxy/tools/console/grideditor/editors.py index 313495e4..e069fe2f 100644 --- a/mitmproxy/tools/console/grideditor/editors.py +++ b/mitmproxy/tools/console/grideditor/editors.py @@ -248,7 +248,7 @@ class SetCookieEditor(base.GridEditor): class OptionsEditor(base.GridEditor): - title = None + title = None # type: str columns = [ col_text.Column("") ] diff --git a/mitmproxy/tools/console/master.py b/mitmproxy/tools/console/master.py index c0d8e05c..c1d584ac 100644 --- a/mitmproxy/tools/console/master.py +++ b/mitmproxy/tools/console/master.py @@ -70,7 +70,6 @@ class UnsupportedLog: class ConsoleMaster(master.Master): - palette = [] def __init__(self, options, server): super().__init__(options, server) diff --git a/mitmproxy/tools/console/palettes.py b/mitmproxy/tools/console/palettes.py index 7b15f98f..e349f702 100644 --- a/mitmproxy/tools/console/palettes.py +++ b/mitmproxy/tools/console/palettes.py @@ -1,3 +1,4 @@ +import typing # Low-color themes should ONLY use the standard foreground and background # colours listed here: # @@ -32,7 +33,7 @@ class Palette: # Grid Editor 'focusfield', 'focusfield_error', 'field_error', 'editfield', ] - high = None + high = None # type: typing.Mapping[str, typing.Sequence[str]] def palette(self, transparent): l = [] diff --git a/mitmproxy/tools/console/statusbar.py b/mitmproxy/tools/console/statusbar.py index c7132864..d3a3e1f2 100644 --- a/mitmproxy/tools/console/statusbar.py +++ b/mitmproxy/tools/console/statusbar.py @@ -5,6 +5,7 @@ import urwid from mitmproxy.tools.console import common from mitmproxy.tools.console import pathedit from mitmproxy.tools.console import signals +import mitmproxy.tools.console.master # noqa class PromptPath: @@ -135,7 +136,9 @@ class ActionBar(urwid.WidgetWrap): class StatusBar(urwid.WidgetWrap): - def __init__(self, master: "mitmproxy.console.master.ConsoleMaster", helptext): + def __init__( + self, master: "mitmproxy.tools.console.master.ConsoleMaster", helptext + ) -> None: self.master = master self.helptext = helptext self.ib = urwid.WidgetWrap(urwid.Text("")) diff --git a/mitmproxy/tools/web/app.py b/mitmproxy/tools/web/app.py index 002513b9..23d620e0 100644 --- a/mitmproxy/tools/web/app.py +++ b/mitmproxy/tools/web/app.py @@ -17,6 +17,7 @@ from mitmproxy import http from mitmproxy import io from mitmproxy import log from mitmproxy import version +import mitmproxy.tools.web.master # noqa def flow_to_json(flow: mitmproxy.flow.Flow) -> dict: -- cgit v1.2.3 From 4ca78604af2a8ddb596e2f4e95090dabc8495bfe Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 20 Mar 2017 12:50:09 +1300 Subject: Factor out an io module Include tnetstring - we've made enough changes that this no longer belongs in contrib. --- mitmproxy/contrib/tnetstring.py | 250 ------------------------------ mitmproxy/io.py | 87 ----------- mitmproxy/io/__init__.py | 7 + mitmproxy/io/compat.py | 214 +++++++++++++++++++++++++ mitmproxy/io/io.py | 87 +++++++++++ mitmproxy/io/tnetstring.py | 250 ++++++++++++++++++++++++++++++ mitmproxy/io_compat.py | 214 ------------------------- test/mitmproxy/contrib/test_tnetstring.py | 137 ---------------- test/mitmproxy/io/test_compat.py | 28 ++++ test/mitmproxy/io/test_io.py | 1 + test/mitmproxy/io/test_tnetstring.py | 137 ++++++++++++++++ test/mitmproxy/test_flow.py | 2 +- test/mitmproxy/test_io.py | 1 - test/mitmproxy/test_io_compat.py | 28 ---- test/mitmproxy/test_websocket.py | 2 +- 15 files changed, 726 insertions(+), 719 deletions(-) delete mode 100644 mitmproxy/contrib/tnetstring.py delete mode 100644 mitmproxy/io.py create mode 100644 mitmproxy/io/__init__.py create mode 100644 mitmproxy/io/compat.py create mode 100644 mitmproxy/io/io.py create mode 100644 mitmproxy/io/tnetstring.py delete mode 100644 mitmproxy/io_compat.py delete mode 100644 test/mitmproxy/contrib/test_tnetstring.py create mode 100644 test/mitmproxy/io/test_compat.py create mode 100644 test/mitmproxy/io/test_io.py create mode 100644 test/mitmproxy/io/test_tnetstring.py delete mode 100644 test/mitmproxy/test_io.py delete mode 100644 test/mitmproxy/test_io_compat.py diff --git a/mitmproxy/contrib/tnetstring.py b/mitmproxy/contrib/tnetstring.py deleted file mode 100644 index 24ce6ce8..00000000 --- a/mitmproxy/contrib/tnetstring.py +++ /dev/null @@ -1,250 +0,0 @@ -""" -tnetstring: data serialization using typed netstrings -====================================================== - -This is a custom Python 3 implementation of tnetstrings. -Compared to other implementations, the main difference -is that this implementation supports a custom unicode datatype. - -An ordinary tnetstring is a blob of data prefixed with its length and postfixed -with its type. Here are some examples: - - >>> tnetstring.dumps("hello world") - 11:hello world, - >>> tnetstring.dumps(12345) - 5:12345# - >>> tnetstring.dumps([12345, True, 0]) - 19:5:12345#4:true!1:0#] - -This module gives you the following functions: - - :dump: dump an object as a tnetstring to a file - :dumps: dump an object as a tnetstring to a string - :load: load a tnetstring-encoded object from a file - :loads: load a tnetstring-encoded object from a string - -Note that since parsing a tnetstring requires reading all the data into memory -at once, there's no efficiency gain from using the file-based versions of these -functions. They're only here so you can use load() to read precisely one -item from a file or socket without consuming any extra data. - -The tnetstrings specification explicitly states that strings are binary blobs -and forbids the use of unicode at the protocol level. -**This implementation decodes dictionary keys as surrogate-escaped ASCII**, -all other strings are returned as plain bytes. - -:Copyright: (c) 2012-2013 by Ryan Kelly . -:Copyright: (c) 2014 by Carlo Pires . -:Copyright: (c) 2016 by Maximilian Hils . - -:License: MIT -""" - -import collections -from typing import io, Union, Tuple - -TSerializable = Union[None, bool, int, float, bytes, list, tuple, dict] - - -def dumps(value: TSerializable) -> bytes: - """ - This function dumps a python object as a tnetstring. - """ - # This uses a deque to collect output fragments in reverse order, - # then joins them together at the end. It's measurably faster - # than creating all the intermediate strings. - q = collections.deque() - _rdumpq(q, 0, value) - return b''.join(q) - - -def dump(value: TSerializable, file_handle: io.BinaryIO) -> None: - """ - This function dumps a python object as a tnetstring and - writes it to the given file. - """ - file_handle.write(dumps(value)) - - -def _rdumpq(q: collections.deque, size: int, value: TSerializable) -> int: - """ - Dump value as a tnetstring, to a deque instance, last chunks first. - - This function generates the tnetstring representation of the given value, - pushing chunks of the output onto the given deque instance. It pushes - the last chunk first, then recursively generates more chunks. - - When passed in the current size of the string in the queue, it will return - the new size of the string in the queue. - - Operating last-chunk-first makes it easy to calculate the size written - for recursive structures without having to build their representation as - a string. This is measurably faster than generating the intermediate - strings, especially on deeply nested structures. - """ - write = q.appendleft - if value is None: - write(b'0:~') - return size + 3 - elif value is True: - write(b'4:true!') - return size + 7 - elif value is False: - write(b'5:false!') - return size + 8 - elif isinstance(value, int): - data = str(value).encode() - ldata = len(data) - span = str(ldata).encode() - write(b'%s:%s#' % (span, data)) - return size + 2 + len(span) + ldata - elif isinstance(value, float): - # Use repr() for float rather than str(). - # It round-trips more accurately. - # Probably unnecessary in later python versions that - # use David Gay's ftoa routines. - data = repr(value).encode() - ldata = len(data) - span = str(ldata).encode() - write(b'%s:%s^' % (span, data)) - return size + 2 + len(span) + ldata - elif isinstance(value, bytes): - data = value - ldata = len(data) - span = str(ldata).encode() - write(b',') - write(data) - write(b':') - write(span) - return size + 2 + len(span) + ldata - elif isinstance(value, str): - data = value.encode("utf8") - ldata = len(data) - span = str(ldata).encode() - write(b';') - write(data) - write(b':') - write(span) - return size + 2 + len(span) + ldata - elif isinstance(value, (list, tuple)): - write(b']') - init_size = size = size + 1 - for item in reversed(value): - size = _rdumpq(q, size, item) - span = str(size - init_size).encode() - write(b':') - write(span) - return size + 1 + len(span) - elif isinstance(value, dict): - write(b'}') - init_size = size = size + 1 - for (k, v) in value.items(): - size = _rdumpq(q, size, v) - size = _rdumpq(q, size, k) - span = str(size - init_size).encode() - write(b':') - write(span) - return size + 1 + len(span) - else: - raise ValueError("unserializable object: {} ({})".format(value, type(value))) - - -def loads(string: bytes) -> TSerializable: - """ - This function parses a tnetstring into a python object. - """ - return pop(string)[0] - - -def load(file_handle: io.BinaryIO) -> TSerializable: - """load(file) -> object - - This function reads a tnetstring from a file and parses it into a - python object. The file must support the read() method, and this - function promises not to read more data than necessary. - """ - # Read the length prefix one char at a time. - # Note that the netstring spec explicitly forbids padding zeros. - c = file_handle.read(1) - if c == b"": # we want to detect this special case. - raise ValueError("not a tnetstring: empty file") - data_length = b"" - while c.isdigit(): - data_length += c - if len(data_length) > 9: - raise ValueError("not a tnetstring: absurdly large length prefix") - c = file_handle.read(1) - if c != b":": - raise ValueError("not a tnetstring: missing or invalid length prefix") - - data = file_handle.read(int(data_length)) - data_type = file_handle.read(1)[0] - - return parse(data_type, data) - - -def parse(data_type: int, data: bytes) -> TSerializable: - if data_type == ord(b','): - return data - if data_type == ord(b';'): - return data.decode("utf8") - if data_type == ord(b'#'): - try: - return int(data) - except ValueError: - raise ValueError("not a tnetstring: invalid integer literal: {}".format(data)) - if data_type == ord(b'^'): - try: - return float(data) - except ValueError: - raise ValueError("not a tnetstring: invalid float literal: {}".format(data)) - if data_type == ord(b'!'): - if data == b'true': - return True - elif data == b'false': - return False - else: - raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data)) - if data_type == ord(b'~'): - if data: - raise ValueError("not a tnetstring: invalid null literal") - return None - if data_type == ord(b']'): - l = [] - while data: - item, data = pop(data) - l.append(item) - return l - if data_type == ord(b'}'): - d = {} - while data: - key, data = pop(data) - val, data = pop(data) - d[key] = val - return d - raise ValueError("unknown type tag: {}".format(data_type)) - - -def pop(data: bytes) -> Tuple[TSerializable, bytes]: - """ - This function parses a tnetstring into a python object. - It returns a tuple giving the parsed object and a string - containing any unparsed data from the end of the string. - """ - # Parse out data length, type and remaining string. - try: - length, data = data.split(b':', 1) - length = int(length) - except ValueError: - raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(data)) - try: - data, data_type, remain = data[:length], data[length], data[length + 1:] - except IndexError: - # This fires if len(data) < dlen, meaning we don't need - # to further validate that data is the right length. - raise ValueError("not a tnetstring: invalid length prefix: {}".format(length)) - # Parse the data based on the type tag. - return parse(data_type, data), remain - - -__all__ = ["dump", "dumps", "load", "loads", "pop"] diff --git a/mitmproxy/io.py b/mitmproxy/io.py deleted file mode 100644 index 0f6c3f5c..00000000 --- a/mitmproxy/io.py +++ /dev/null @@ -1,87 +0,0 @@ -import os -from typing import Type, Iterable, Dict, Union, Any, cast # noqa - -from mitmproxy import exceptions -from mitmproxy import flow -from mitmproxy import flowfilter -from mitmproxy import http -from mitmproxy import tcp -from mitmproxy import websocket -from mitmproxy.contrib import tnetstring -from mitmproxy import io_compat - - -FLOW_TYPES = dict( - http=http.HTTPFlow, - websocket=websocket.WebSocketFlow, - tcp=tcp.TCPFlow, -) # type: Dict[str, Type[flow.Flow]] - - -class FlowWriter: - def __init__(self, fo): - self.fo = fo - - def add(self, flow): - d = flow.get_state() - tnetstring.dump(d, self.fo) - - -class FlowReader: - def __init__(self, fo): - self.fo = fo - - def stream(self) -> Iterable[flow.Flow]: - """ - Yields Flow objects from the dump. - """ - try: - while True: - # FIXME: This cast hides a lack of dynamic type checking - loaded = cast( - Dict[Union[bytes, str], Any], - tnetstring.load(self.fo), - ) - try: - mdata = io_compat.migrate_flow(loaded) - except ValueError as e: - raise exceptions.FlowReadException(str(e)) - if mdata["type"] not in FLOW_TYPES: - raise exceptions.FlowReadException("Unknown flow type: {}".format(mdata["type"])) - yield FLOW_TYPES[mdata["type"]].from_state(mdata) - except ValueError as e: - if str(e) == "not a tnetstring: empty file": - return # Error is due to EOF - raise exceptions.FlowReadException("Invalid data format.") - - -class FilteredFlowWriter: - def __init__(self, fo, flt): - self.fo = fo - self.flt = flt - - def add(self, f: flow.Flow): - if self.flt and not flowfilter.match(self.flt, f): - return - d = f.get_state() - tnetstring.dump(d, self.fo) - - -def read_flows_from_paths(paths): - """ - Given a list of filepaths, read all flows and return a list of them. - From a performance perspective, streaming would be advisable - - however, if there's an error with one of the files, we want it to be raised immediately. - - Raises: - FlowReadException, if any error occurs. - """ - try: - flows = [] - for path in paths: - path = os.path.expanduser(path) - with open(path, "rb") as f: - flows.extend(FlowReader(f).stream()) - except IOError as e: - raise exceptions.FlowReadException(e.strerror) - return flows diff --git a/mitmproxy/io/__init__.py b/mitmproxy/io/__init__.py new file mode 100644 index 00000000..a82f729f --- /dev/null +++ b/mitmproxy/io/__init__.py @@ -0,0 +1,7 @@ + +from .io import FlowWriter, FlowReader, FilteredFlowWriter, read_flows_from_paths + + +__all__ = [ + "FlowWriter", "FlowReader", "FilteredFlowWriter", "read_flows_from_paths" +] \ No newline at end of file diff --git a/mitmproxy/io/compat.py b/mitmproxy/io/compat.py new file mode 100644 index 00000000..9d95f602 --- /dev/null +++ b/mitmproxy/io/compat.py @@ -0,0 +1,214 @@ +""" +This module handles the import of mitmproxy flows generated by old versions. +""" +import uuid +from typing import Any, Dict, Mapping, Union # noqa + +from mitmproxy import version +from mitmproxy.utils import strutils + + +def convert_011_012(data): + data[b"version"] = (0, 12) + return data + + +def convert_012_013(data): + data[b"version"] = (0, 13) + return data + + +def convert_013_014(data): + data[b"request"][b"first_line_format"] = data[b"request"].pop(b"form_in") + data[b"request"][b"http_version"] = b"HTTP/" + ".".join( + str(x) for x in data[b"request"].pop(b"httpversion")).encode() + data[b"response"][b"http_version"] = b"HTTP/" + ".".join( + str(x) for x in data[b"response"].pop(b"httpversion")).encode() + data[b"response"][b"status_code"] = data[b"response"].pop(b"code") + data[b"response"][b"body"] = data[b"response"].pop(b"content") + data[b"server_conn"].pop(b"state") + data[b"server_conn"][b"via"] = None + data[b"version"] = (0, 14) + return data + + +def convert_014_015(data): + data[b"version"] = (0, 15) + return data + + +def convert_015_016(data): + for m in (b"request", b"response"): + if b"body" in data[m]: + data[m][b"content"] = data[m].pop(b"body") + if b"msg" in data[b"response"]: + data[b"response"][b"reason"] = data[b"response"].pop(b"msg") + data[b"request"].pop(b"form_out", None) + data[b"version"] = (0, 16) + return data + + +def convert_016_017(data): + data[b"server_conn"][b"peer_address"] = None + data[b"version"] = (0, 17) + return data + + +def convert_017_018(data): + # convert_unicode needs to be called for every dual release and the first py3-only release + data = convert_unicode(data) + + data["server_conn"]["ip_address"] = data["server_conn"].pop("peer_address") + data["marked"] = False + data["version"] = (0, 18) + return data + + +def convert_018_019(data): + # convert_unicode needs to be called for every dual release and the first py3-only release + data = convert_unicode(data) + + data["request"].pop("stickyauth", None) + data["request"].pop("stickycookie", None) + data["client_conn"]["sni"] = None + data["client_conn"]["alpn_proto_negotiated"] = None + data["client_conn"]["cipher_name"] = None + data["client_conn"]["tls_version"] = None + data["server_conn"]["alpn_proto_negotiated"] = None + data["mode"] = "regular" + data["metadata"] = dict() + data["version"] = (0, 19) + return data + + +def convert_019_100(data): + # convert_unicode needs to be called for every dual release and the first py3-only release + data = convert_unicode(data) + + data["version"] = (1, 0, 0) + return data + + +def convert_100_200(data): + data["version"] = (2, 0, 0) + data["client_conn"]["address"] = data["client_conn"]["address"]["address"] + data["server_conn"]["address"] = data["server_conn"]["address"]["address"] + data["server_conn"]["source_address"] = data["server_conn"]["source_address"]["address"] + if data["server_conn"]["ip_address"]: + data["server_conn"]["ip_address"] = data["server_conn"]["ip_address"]["address"] + return data + + +def convert_200_300(data): + data["version"] = (3, 0, 0) + data["client_conn"]["mitmcert"] = None + data["server_conn"]["tls_version"] = None + if data["server_conn"]["via"]: + data["server_conn"]["via"]["tls_version"] = None + return data + + +def convert_300_4(data): + data["version"] = 4 + return data + + +client_connections = {} # type: Mapping[str, str] +server_connections = {} # type: Mapping[str, str] + + +def convert_4_5(data): + data["version"] = 5 + client_conn_key = ( + data["client_conn"]["timestamp_start"], + *data["client_conn"]["address"] + ) + server_conn_key = ( + data["server_conn"]["timestamp_start"], + *data["server_conn"]["source_address"] + ) + data["client_conn"]["id"] = client_connections.setdefault(client_conn_key, str(uuid.uuid4())) + data["server_conn"]["id"] = server_connections.setdefault(server_conn_key, str(uuid.uuid4())) + return data + + +def _convert_dict_keys(o: Any) -> Any: + if isinstance(o, dict): + return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()} + else: + return o + + +def _convert_dict_vals(o: dict, values_to_convert: dict) -> dict: + for k, v in values_to_convert.items(): + if not o or k not in o: + continue + if v is True: + o[k] = strutils.always_str(o[k]) + else: + _convert_dict_vals(o[k], v) + return o + + +def convert_unicode(data: dict) -> dict: + """ + This method converts between Python 3 and Python 2 dumpfiles. + """ + data = _convert_dict_keys(data) + data = _convert_dict_vals( + data, { + "type": True, + "id": True, + "request": { + "first_line_format": True + }, + "error": { + "msg": True + } + } + ) + return data + + +converters = { + (0, 11): convert_011_012, + (0, 12): convert_012_013, + (0, 13): convert_013_014, + (0, 14): convert_014_015, + (0, 15): convert_015_016, + (0, 16): convert_016_017, + (0, 17): convert_017_018, + (0, 18): convert_018_019, + (0, 19): convert_019_100, + (1, 0): convert_100_200, + (2, 0): convert_200_300, + (3, 0): convert_300_4, + 4: convert_4_5, +} + + +def migrate_flow(flow_data: Dict[Union[bytes, str], Any]) -> Dict[Union[bytes, str], Any]: + while True: + flow_version = flow_data.get(b"version", flow_data.get("version")) + + # Historically, we used the mitmproxy minor version tuple as the flow format version. + if not isinstance(flow_version, int): + flow_version = tuple(flow_version)[:2] + + if flow_version == version.FLOW_FORMAT_VERSION: + break + elif flow_version in converters: + flow_data = converters[flow_version](flow_data) + else: + should_upgrade = ( + isinstance(flow_version, int) + and flow_version > version.FLOW_FORMAT_VERSION + ) + raise ValueError( + "{} cannot read files with flow format version {}{}.".format( + version.MITMPROXY, + flow_version, + ", please update mitmproxy" if should_upgrade else "" + ) + ) + return flow_data diff --git a/mitmproxy/io/io.py b/mitmproxy/io/io.py new file mode 100644 index 00000000..50e26f49 --- /dev/null +++ b/mitmproxy/io/io.py @@ -0,0 +1,87 @@ +import os +from typing import Type, Iterable, Dict, Union, Any, cast # noqa + +from mitmproxy import exceptions +from mitmproxy import flow +from mitmproxy import flowfilter +from mitmproxy import http +from mitmproxy import tcp +from mitmproxy import websocket + +from mitmproxy.io import compat +from mitmproxy.io import tnetstring + +FLOW_TYPES = dict( + http=http.HTTPFlow, + websocket=websocket.WebSocketFlow, + tcp=tcp.TCPFlow, +) # type: Dict[str, Type[flow.Flow]] + + +class FlowWriter: + def __init__(self, fo): + self.fo = fo + + def add(self, flow): + d = flow.get_state() + tnetstring.dump(d, self.fo) + + +class FlowReader: + def __init__(self, fo): + self.fo = fo + + def stream(self) -> Iterable[flow.Flow]: + """ + Yields Flow objects from the dump. + """ + try: + while True: + # FIXME: This cast hides a lack of dynamic type checking + loaded = cast( + Dict[Union[bytes, str], Any], + tnetstring.load(self.fo), + ) + try: + mdata = compat.migrate_flow(loaded) + except ValueError as e: + raise exceptions.FlowReadException(str(e)) + if mdata["type"] not in FLOW_TYPES: + raise exceptions.FlowReadException("Unknown flow type: {}".format(mdata["type"])) + yield FLOW_TYPES[mdata["type"]].from_state(mdata) + except ValueError as e: + if str(e) == "not a tnetstring: empty file": + return # Error is due to EOF + raise exceptions.FlowReadException("Invalid data format.") + + +class FilteredFlowWriter: + def __init__(self, fo, flt): + self.fo = fo + self.flt = flt + + def add(self, f: flow.Flow): + if self.flt and not flowfilter.match(self.flt, f): + return + d = f.get_state() + tnetstring.dump(d, self.fo) + + +def read_flows_from_paths(paths): + """ + Given a list of filepaths, read all flows and return a list of them. + From a performance perspective, streaming would be advisable - + however, if there's an error with one of the files, we want it to be raised immediately. + + Raises: + FlowReadException, if any error occurs. + """ + try: + flows = [] + for path in paths: + path = os.path.expanduser(path) + with open(path, "rb") as f: + flows.extend(FlowReader(f).stream()) + except IOError as e: + raise exceptions.FlowReadException(e.strerror) + return flows diff --git a/mitmproxy/io/tnetstring.py b/mitmproxy/io/tnetstring.py new file mode 100644 index 00000000..24ce6ce8 --- /dev/null +++ b/mitmproxy/io/tnetstring.py @@ -0,0 +1,250 @@ +""" +tnetstring: data serialization using typed netstrings +====================================================== + +This is a custom Python 3 implementation of tnetstrings. +Compared to other implementations, the main difference +is that this implementation supports a custom unicode datatype. + +An ordinary tnetstring is a blob of data prefixed with its length and postfixed +with its type. Here are some examples: + + >>> tnetstring.dumps("hello world") + 11:hello world, + >>> tnetstring.dumps(12345) + 5:12345# + >>> tnetstring.dumps([12345, True, 0]) + 19:5:12345#4:true!1:0#] + +This module gives you the following functions: + + :dump: dump an object as a tnetstring to a file + :dumps: dump an object as a tnetstring to a string + :load: load a tnetstring-encoded object from a file + :loads: load a tnetstring-encoded object from a string + +Note that since parsing a tnetstring requires reading all the data into memory +at once, there's no efficiency gain from using the file-based versions of these +functions. They're only here so you can use load() to read precisely one +item from a file or socket without consuming any extra data. + +The tnetstrings specification explicitly states that strings are binary blobs +and forbids the use of unicode at the protocol level. +**This implementation decodes dictionary keys as surrogate-escaped ASCII**, +all other strings are returned as plain bytes. + +:Copyright: (c) 2012-2013 by Ryan Kelly . +:Copyright: (c) 2014 by Carlo Pires . +:Copyright: (c) 2016 by Maximilian Hils . + +:License: MIT +""" + +import collections +from typing import io, Union, Tuple + +TSerializable = Union[None, bool, int, float, bytes, list, tuple, dict] + + +def dumps(value: TSerializable) -> bytes: + """ + This function dumps a python object as a tnetstring. + """ + # This uses a deque to collect output fragments in reverse order, + # then joins them together at the end. It's measurably faster + # than creating all the intermediate strings. + q = collections.deque() + _rdumpq(q, 0, value) + return b''.join(q) + + +def dump(value: TSerializable, file_handle: io.BinaryIO) -> None: + """ + This function dumps a python object as a tnetstring and + writes it to the given file. + """ + file_handle.write(dumps(value)) + + +def _rdumpq(q: collections.deque, size: int, value: TSerializable) -> int: + """ + Dump value as a tnetstring, to a deque instance, last chunks first. + + This function generates the tnetstring representation of the given value, + pushing chunks of the output onto the given deque instance. It pushes + the last chunk first, then recursively generates more chunks. + + When passed in the current size of the string in the queue, it will return + the new size of the string in the queue. + + Operating last-chunk-first makes it easy to calculate the size written + for recursive structures without having to build their representation as + a string. This is measurably faster than generating the intermediate + strings, especially on deeply nested structures. + """ + write = q.appendleft + if value is None: + write(b'0:~') + return size + 3 + elif value is True: + write(b'4:true!') + return size + 7 + elif value is False: + write(b'5:false!') + return size + 8 + elif isinstance(value, int): + data = str(value).encode() + ldata = len(data) + span = str(ldata).encode() + write(b'%s:%s#' % (span, data)) + return size + 2 + len(span) + ldata + elif isinstance(value, float): + # Use repr() for float rather than str(). + # It round-trips more accurately. + # Probably unnecessary in later python versions that + # use David Gay's ftoa routines. + data = repr(value).encode() + ldata = len(data) + span = str(ldata).encode() + write(b'%s:%s^' % (span, data)) + return size + 2 + len(span) + ldata + elif isinstance(value, bytes): + data = value + ldata = len(data) + span = str(ldata).encode() + write(b',') + write(data) + write(b':') + write(span) + return size + 2 + len(span) + ldata + elif isinstance(value, str): + data = value.encode("utf8") + ldata = len(data) + span = str(ldata).encode() + write(b';') + write(data) + write(b':') + write(span) + return size + 2 + len(span) + ldata + elif isinstance(value, (list, tuple)): + write(b']') + init_size = size = size + 1 + for item in reversed(value): + size = _rdumpq(q, size, item) + span = str(size - init_size).encode() + write(b':') + write(span) + return size + 1 + len(span) + elif isinstance(value, dict): + write(b'}') + init_size = size = size + 1 + for (k, v) in value.items(): + size = _rdumpq(q, size, v) + size = _rdumpq(q, size, k) + span = str(size - init_size).encode() + write(b':') + write(span) + return size + 1 + len(span) + else: + raise ValueError("unserializable object: {} ({})".format(value, type(value))) + + +def loads(string: bytes) -> TSerializable: + """ + This function parses a tnetstring into a python object. + """ + return pop(string)[0] + + +def load(file_handle: io.BinaryIO) -> TSerializable: + """load(file) -> object + + This function reads a tnetstring from a file and parses it into a + python object. The file must support the read() method, and this + function promises not to read more data than necessary. + """ + # Read the length prefix one char at a time. + # Note that the netstring spec explicitly forbids padding zeros. + c = file_handle.read(1) + if c == b"": # we want to detect this special case. + raise ValueError("not a tnetstring: empty file") + data_length = b"" + while c.isdigit(): + data_length += c + if len(data_length) > 9: + raise ValueError("not a tnetstring: absurdly large length prefix") + c = file_handle.read(1) + if c != b":": + raise ValueError("not a tnetstring: missing or invalid length prefix") + + data = file_handle.read(int(data_length)) + data_type = file_handle.read(1)[0] + + return parse(data_type, data) + + +def parse(data_type: int, data: bytes) -> TSerializable: + if data_type == ord(b','): + return data + if data_type == ord(b';'): + return data.decode("utf8") + if data_type == ord(b'#'): + try: + return int(data) + except ValueError: + raise ValueError("not a tnetstring: invalid integer literal: {}".format(data)) + if data_type == ord(b'^'): + try: + return float(data) + except ValueError: + raise ValueError("not a tnetstring: invalid float literal: {}".format(data)) + if data_type == ord(b'!'): + if data == b'true': + return True + elif data == b'false': + return False + else: + raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data)) + if data_type == ord(b'~'): + if data: + raise ValueError("not a tnetstring: invalid null literal") + return None + if data_type == ord(b']'): + l = [] + while data: + item, data = pop(data) + l.append(item) + return l + if data_type == ord(b'}'): + d = {} + while data: + key, data = pop(data) + val, data = pop(data) + d[key] = val + return d + raise ValueError("unknown type tag: {}".format(data_type)) + + +def pop(data: bytes) -> Tuple[TSerializable, bytes]: + """ + This function parses a tnetstring into a python object. + It returns a tuple giving the parsed object and a string + containing any unparsed data from the end of the string. + """ + # Parse out data length, type and remaining string. + try: + length, data = data.split(b':', 1) + length = int(length) + except ValueError: + raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(data)) + try: + data, data_type, remain = data[:length], data[length], data[length + 1:] + except IndexError: + # This fires if len(data) < dlen, meaning we don't need + # to further validate that data is the right length. + raise ValueError("not a tnetstring: invalid length prefix: {}".format(length)) + # Parse the data based on the type tag. + return parse(data_type, data), remain + + +__all__ = ["dump", "dumps", "load", "loads", "pop"] diff --git a/mitmproxy/io_compat.py b/mitmproxy/io_compat.py deleted file mode 100644 index 9d95f602..00000000 --- a/mitmproxy/io_compat.py +++ /dev/null @@ -1,214 +0,0 @@ -""" -This module handles the import of mitmproxy flows generated by old versions. -""" -import uuid -from typing import Any, Dict, Mapping, Union # noqa - -from mitmproxy import version -from mitmproxy.utils import strutils - - -def convert_011_012(data): - data[b"version"] = (0, 12) - return data - - -def convert_012_013(data): - data[b"version"] = (0, 13) - return data - - -def convert_013_014(data): - data[b"request"][b"first_line_format"] = data[b"request"].pop(b"form_in") - data[b"request"][b"http_version"] = b"HTTP/" + ".".join( - str(x) for x in data[b"request"].pop(b"httpversion")).encode() - data[b"response"][b"http_version"] = b"HTTP/" + ".".join( - str(x) for x in data[b"response"].pop(b"httpversion")).encode() - data[b"response"][b"status_code"] = data[b"response"].pop(b"code") - data[b"response"][b"body"] = data[b"response"].pop(b"content") - data[b"server_conn"].pop(b"state") - data[b"server_conn"][b"via"] = None - data[b"version"] = (0, 14) - return data - - -def convert_014_015(data): - data[b"version"] = (0, 15) - return data - - -def convert_015_016(data): - for m in (b"request", b"response"): - if b"body" in data[m]: - data[m][b"content"] = data[m].pop(b"body") - if b"msg" in data[b"response"]: - data[b"response"][b"reason"] = data[b"response"].pop(b"msg") - data[b"request"].pop(b"form_out", None) - data[b"version"] = (0, 16) - return data - - -def convert_016_017(data): - data[b"server_conn"][b"peer_address"] = None - data[b"version"] = (0, 17) - return data - - -def convert_017_018(data): - # convert_unicode needs to be called for every dual release and the first py3-only release - data = convert_unicode(data) - - data["server_conn"]["ip_address"] = data["server_conn"].pop("peer_address") - data["marked"] = False - data["version"] = (0, 18) - return data - - -def convert_018_019(data): - # convert_unicode needs to be called for every dual release and the first py3-only release - data = convert_unicode(data) - - data["request"].pop("stickyauth", None) - data["request"].pop("stickycookie", None) - data["client_conn"]["sni"] = None - data["client_conn"]["alpn_proto_negotiated"] = None - data["client_conn"]["cipher_name"] = None - data["client_conn"]["tls_version"] = None - data["server_conn"]["alpn_proto_negotiated"] = None - data["mode"] = "regular" - data["metadata"] = dict() - data["version"] = (0, 19) - return data - - -def convert_019_100(data): - # convert_unicode needs to be called for every dual release and the first py3-only release - data = convert_unicode(data) - - data["version"] = (1, 0, 0) - return data - - -def convert_100_200(data): - data["version"] = (2, 0, 0) - data["client_conn"]["address"] = data["client_conn"]["address"]["address"] - data["server_conn"]["address"] = data["server_conn"]["address"]["address"] - data["server_conn"]["source_address"] = data["server_conn"]["source_address"]["address"] - if data["server_conn"]["ip_address"]: - data["server_conn"]["ip_address"] = data["server_conn"]["ip_address"]["address"] - return data - - -def convert_200_300(data): - data["version"] = (3, 0, 0) - data["client_conn"]["mitmcert"] = None - data["server_conn"]["tls_version"] = None - if data["server_conn"]["via"]: - data["server_conn"]["via"]["tls_version"] = None - return data - - -def convert_300_4(data): - data["version"] = 4 - return data - - -client_connections = {} # type: Mapping[str, str] -server_connections = {} # type: Mapping[str, str] - - -def convert_4_5(data): - data["version"] = 5 - client_conn_key = ( - data["client_conn"]["timestamp_start"], - *data["client_conn"]["address"] - ) - server_conn_key = ( - data["server_conn"]["timestamp_start"], - *data["server_conn"]["source_address"] - ) - data["client_conn"]["id"] = client_connections.setdefault(client_conn_key, str(uuid.uuid4())) - data["server_conn"]["id"] = server_connections.setdefault(server_conn_key, str(uuid.uuid4())) - return data - - -def _convert_dict_keys(o: Any) -> Any: - if isinstance(o, dict): - return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()} - else: - return o - - -def _convert_dict_vals(o: dict, values_to_convert: dict) -> dict: - for k, v in values_to_convert.items(): - if not o or k not in o: - continue - if v is True: - o[k] = strutils.always_str(o[k]) - else: - _convert_dict_vals(o[k], v) - return o - - -def convert_unicode(data: dict) -> dict: - """ - This method converts between Python 3 and Python 2 dumpfiles. - """ - data = _convert_dict_keys(data) - data = _convert_dict_vals( - data, { - "type": True, - "id": True, - "request": { - "first_line_format": True - }, - "error": { - "msg": True - } - } - ) - return data - - -converters = { - (0, 11): convert_011_012, - (0, 12): convert_012_013, - (0, 13): convert_013_014, - (0, 14): convert_014_015, - (0, 15): convert_015_016, - (0, 16): convert_016_017, - (0, 17): convert_017_018, - (0, 18): convert_018_019, - (0, 19): convert_019_100, - (1, 0): convert_100_200, - (2, 0): convert_200_300, - (3, 0): convert_300_4, - 4: convert_4_5, -} - - -def migrate_flow(flow_data: Dict[Union[bytes, str], Any]) -> Dict[Union[bytes, str], Any]: - while True: - flow_version = flow_data.get(b"version", flow_data.get("version")) - - # Historically, we used the mitmproxy minor version tuple as the flow format version. - if not isinstance(flow_version, int): - flow_version = tuple(flow_version)[:2] - - if flow_version == version.FLOW_FORMAT_VERSION: - break - elif flow_version in converters: - flow_data = converters[flow_version](flow_data) - else: - should_upgrade = ( - isinstance(flow_version, int) - and flow_version > version.FLOW_FORMAT_VERSION - ) - raise ValueError( - "{} cannot read files with flow format version {}{}.".format( - version.MITMPROXY, - flow_version, - ", please update mitmproxy" if should_upgrade else "" - ) - ) - return flow_data diff --git a/test/mitmproxy/contrib/test_tnetstring.py b/test/mitmproxy/contrib/test_tnetstring.py deleted file mode 100644 index 05c4a7c9..00000000 --- a/test/mitmproxy/contrib/test_tnetstring.py +++ /dev/null @@ -1,137 +0,0 @@ -import unittest -import random -import math -import io -import struct - -from mitmproxy.contrib import tnetstring - -MAXINT = 2 ** (struct.Struct('i').size * 8 - 1) - 1 - -FORMAT_EXAMPLES = { - b'0:}': {}, - b'0:]': [], - b'51:5:hello,39:11:12345678901#4:this,4:true!0:~4:\x00\x00\x00\x00,]}': - {b'hello': [12345678901, b'this', True, None, b'\x00\x00\x00\x00']}, - b'5:12345#': 12345, - b'12:this is cool,': b'this is cool', - b'19:this is unicode \xe2\x98\x85;': u'this is unicode \u2605', - b'0:,': b'', - b'0:;': u'', - b'0:~': None, - b'4:true!': True, - b'5:false!': False, - b'10:\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00,': b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', - b'24:5:12345#5:67890#5:xxxxx,]': [12345, 67890, b'xxxxx'], - b'18:3:0.1^3:0.2^3:0.3^]': [0.1, 0.2, 0.3], - b'243:238:233:228:223:218:213:208:203:198:193:188:183:178:173:168:163:158:153:148:143:138:133:128:123:118:113:108:103:99:95:91:87:83:79:75:71:67:63:59:55:51:47:43:39:35:31:27:23:19:15:11:hello-there,]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]': [[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[b'hello-there']]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]] # noqa -} - - -def get_random_object(random=random, depth=0): - """Generate a random serializable object.""" - # The probability of generating a scalar value increases as the depth increase. - # This ensures that we bottom out eventually. - if random.randint(depth, 10) <= 4: - what = random.randint(0, 1) - if what == 0: - n = random.randint(0, 10) - l = [] - for _ in range(n): - l.append(get_random_object(random, depth + 1)) - return l - if what == 1: - n = random.randint(0, 10) - d = {} - for _ in range(n): - n = random.randint(0, 100) - k = str([random.randint(32, 126) for _ in range(n)]) - d[k] = get_random_object(random, depth + 1) - return d - else: - what = random.randint(0, 4) - if what == 0: - return None - if what == 1: - return True - if what == 2: - return False - if what == 3: - if random.randint(0, 1) == 0: - return random.randint(0, MAXINT) - else: - return -1 * random.randint(0, MAXINT) - n = random.randint(0, 100) - return bytes([random.randint(32, 126) for _ in range(n)]) - - -class Test_Format(unittest.TestCase): - - def test_roundtrip_format_examples(self): - for data, expect in FORMAT_EXAMPLES.items(): - self.assertEqual(expect, tnetstring.loads(data)) - self.assertEqual( - expect, tnetstring.loads(tnetstring.dumps(expect))) - self.assertEqual((expect, b''), tnetstring.pop(data)) - - def test_roundtrip_format_random(self): - for _ in range(500): - v = get_random_object() - self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v))) - self.assertEqual((v, b""), tnetstring.pop(tnetstring.dumps(v))) - - def test_roundtrip_format_unicode(self): - for _ in range(500): - v = get_random_object() - self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v))) - self.assertEqual((v, b''), tnetstring.pop(tnetstring.dumps(v))) - - def test_roundtrip_big_integer(self): - i1 = math.factorial(30000) - s = tnetstring.dumps(i1) - i2 = tnetstring.loads(s) - self.assertEqual(i1, i2) - - -class Test_FileLoading(unittest.TestCase): - - def test_roundtrip_file_examples(self): - for data, expect in FORMAT_EXAMPLES.items(): - s = io.BytesIO() - s.write(data) - s.write(b'OK') - s.seek(0) - self.assertEqual(expect, tnetstring.load(s)) - self.assertEqual(b'OK', s.read()) - s = io.BytesIO() - tnetstring.dump(expect, s) - s.write(b'OK') - s.seek(0) - self.assertEqual(expect, tnetstring.load(s)) - self.assertEqual(b'OK', s.read()) - - def test_roundtrip_file_random(self): - for _ in range(500): - v = get_random_object() - s = io.BytesIO() - tnetstring.dump(v, s) - s.write(b'OK') - s.seek(0) - self.assertEqual(v, tnetstring.load(s)) - self.assertEqual(b'OK', s.read()) - - def test_error_on_absurd_lengths(self): - s = io.BytesIO() - s.write(b'1000000000:pwned!,') - s.seek(0) - with self.assertRaises(ValueError): - tnetstring.load(s) - self.assertEqual(s.read(1), b':') - - -def suite(): - loader = unittest.TestLoader() - suite = unittest.TestSuite() - suite.addTest(loader.loadTestsFromTestCase(Test_Format)) - suite.addTest(loader.loadTestsFromTestCase(Test_FileLoading)) - return suite diff --git a/test/mitmproxy/io/test_compat.py b/test/mitmproxy/io/test_compat.py new file mode 100644 index 00000000..288de4fc --- /dev/null +++ b/test/mitmproxy/io/test_compat.py @@ -0,0 +1,28 @@ +import pytest + +from mitmproxy import io +from mitmproxy import exceptions +from mitmproxy.test import tutils + + +def test_load(): + with open(tutils.test_data.path("mitmproxy/data/dumpfile-011"), "rb") as f: + flow_reader = io.FlowReader(f) + flows = list(flow_reader.stream()) + assert len(flows) == 1 + assert flows[0].request.url == "https://example.com/" + + +def test_load_018(): + with open(tutils.test_data.path("mitmproxy/data/dumpfile-018"), "rb") as f: + flow_reader = io.FlowReader(f) + flows = list(flow_reader.stream()) + assert len(flows) == 1 + assert flows[0].request.url == "https://www.example.com/" + + +def test_cannot_convert(): + with open(tutils.test_data.path("mitmproxy/data/dumpfile-010"), "rb") as f: + flow_reader = io.FlowReader(f) + with pytest.raises(exceptions.FlowReadException): + list(flow_reader.stream()) diff --git a/test/mitmproxy/io/test_io.py b/test/mitmproxy/io/test_io.py new file mode 100644 index 00000000..777ab4dd --- /dev/null +++ b/test/mitmproxy/io/test_io.py @@ -0,0 +1 @@ +# TODO: write tests diff --git a/test/mitmproxy/io/test_tnetstring.py b/test/mitmproxy/io/test_tnetstring.py new file mode 100644 index 00000000..f7141de0 --- /dev/null +++ b/test/mitmproxy/io/test_tnetstring.py @@ -0,0 +1,137 @@ +import unittest +import random +import math +import io +import struct + +from mitmproxy.io import tnetstring + +MAXINT = 2 ** (struct.Struct('i').size * 8 - 1) - 1 + +FORMAT_EXAMPLES = { + b'0:}': {}, + b'0:]': [], + b'51:5:hello,39:11:12345678901#4:this,4:true!0:~4:\x00\x00\x00\x00,]}': + {b'hello': [12345678901, b'this', True, None, b'\x00\x00\x00\x00']}, + b'5:12345#': 12345, + b'12:this is cool,': b'this is cool', + b'19:this is unicode \xe2\x98\x85;': u'this is unicode \u2605', + b'0:,': b'', + b'0:;': u'', + b'0:~': None, + b'4:true!': True, + b'5:false!': False, + b'10:\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00,': b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', + b'24:5:12345#5:67890#5:xxxxx,]': [12345, 67890, b'xxxxx'], + b'18:3:0.1^3:0.2^3:0.3^]': [0.1, 0.2, 0.3], + b'243:238:233:228:223:218:213:208:203:198:193:188:183:178:173:168:163:158:153:148:143:138:133:128:123:118:113:108:103:99:95:91:87:83:79:75:71:67:63:59:55:51:47:43:39:35:31:27:23:19:15:11:hello-there,]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]': [[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[b'hello-there']]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]] # noqa +} + + +def get_random_object(random=random, depth=0): + """Generate a random serializable object.""" + # The probability of generating a scalar value increases as the depth increase. + # This ensures that we bottom out eventually. + if random.randint(depth, 10) <= 4: + what = random.randint(0, 1) + if what == 0: + n = random.randint(0, 10) + l = [] + for _ in range(n): + l.append(get_random_object(random, depth + 1)) + return l + if what == 1: + n = random.randint(0, 10) + d = {} + for _ in range(n): + n = random.randint(0, 100) + k = str([random.randint(32, 126) for _ in range(n)]) + d[k] = get_random_object(random, depth + 1) + return d + else: + what = random.randint(0, 4) + if what == 0: + return None + if what == 1: + return True + if what == 2: + return False + if what == 3: + if random.randint(0, 1) == 0: + return random.randint(0, MAXINT) + else: + return -1 * random.randint(0, MAXINT) + n = random.randint(0, 100) + return bytes([random.randint(32, 126) for _ in range(n)]) + + +class Test_Format(unittest.TestCase): + + def test_roundtrip_format_examples(self): + for data, expect in FORMAT_EXAMPLES.items(): + self.assertEqual(expect, tnetstring.loads(data)) + self.assertEqual( + expect, tnetstring.loads(tnetstring.dumps(expect))) + self.assertEqual((expect, b''), tnetstring.pop(data)) + + def test_roundtrip_format_random(self): + for _ in range(500): + v = get_random_object() + self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v))) + self.assertEqual((v, b""), tnetstring.pop(tnetstring.dumps(v))) + + def test_roundtrip_format_unicode(self): + for _ in range(500): + v = get_random_object() + self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v))) + self.assertEqual((v, b''), tnetstring.pop(tnetstring.dumps(v))) + + def test_roundtrip_big_integer(self): + i1 = math.factorial(30000) + s = tnetstring.dumps(i1) + i2 = tnetstring.loads(s) + self.assertEqual(i1, i2) + + +class Test_FileLoading(unittest.TestCase): + + def test_roundtrip_file_examples(self): + for data, expect in FORMAT_EXAMPLES.items(): + s = io.BytesIO() + s.write(data) + s.write(b'OK') + s.seek(0) + self.assertEqual(expect, tnetstring.load(s)) + self.assertEqual(b'OK', s.read()) + s = io.BytesIO() + tnetstring.dump(expect, s) + s.write(b'OK') + s.seek(0) + self.assertEqual(expect, tnetstring.load(s)) + self.assertEqual(b'OK', s.read()) + + def test_roundtrip_file_random(self): + for _ in range(500): + v = get_random_object() + s = io.BytesIO() + tnetstring.dump(v, s) + s.write(b'OK') + s.seek(0) + self.assertEqual(v, tnetstring.load(s)) + self.assertEqual(b'OK', s.read()) + + def test_error_on_absurd_lengths(self): + s = io.BytesIO() + s.write(b'1000000000:pwned!,') + s.seek(0) + with self.assertRaises(ValueError): + tnetstring.load(s) + self.assertEqual(s.read(1), b':') + + +def suite(): + loader = unittest.TestLoader() + suite = unittest.TestSuite() + suite.addTest(loader.loadTestsFromTestCase(Test_Format)) + suite.addTest(loader.loadTestsFromTestCase(Test_FileLoading)) + return suite diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index 630fc7e4..78f893c0 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -6,7 +6,7 @@ import mitmproxy.io from mitmproxy import flowfilter from mitmproxy import options from mitmproxy.proxy import config -from mitmproxy.contrib import tnetstring +from mitmproxy.io import tnetstring from mitmproxy.exceptions import FlowReadException from mitmproxy import flow from mitmproxy import http diff --git a/test/mitmproxy/test_io.py b/test/mitmproxy/test_io.py deleted file mode 100644 index 777ab4dd..00000000 --- a/test/mitmproxy/test_io.py +++ /dev/null @@ -1 +0,0 @@ -# TODO: write tests diff --git a/test/mitmproxy/test_io_compat.py b/test/mitmproxy/test_io_compat.py deleted file mode 100644 index 288de4fc..00000000 --- a/test/mitmproxy/test_io_compat.py +++ /dev/null @@ -1,28 +0,0 @@ -import pytest - -from mitmproxy import io -from mitmproxy import exceptions -from mitmproxy.test import tutils - - -def test_load(): - with open(tutils.test_data.path("mitmproxy/data/dumpfile-011"), "rb") as f: - flow_reader = io.FlowReader(f) - flows = list(flow_reader.stream()) - assert len(flows) == 1 - assert flows[0].request.url == "https://example.com/" - - -def test_load_018(): - with open(tutils.test_data.path("mitmproxy/data/dumpfile-018"), "rb") as f: - flow_reader = io.FlowReader(f) - flows = list(flow_reader.stream()) - assert len(flows) == 1 - assert flows[0].request.url == "https://www.example.com/" - - -def test_cannot_convert(): - with open(tutils.test_data.path("mitmproxy/data/dumpfile-010"), "rb") as f: - flow_reader = io.FlowReader(f) - with pytest.raises(exceptions.FlowReadException): - list(flow_reader.stream()) diff --git a/test/mitmproxy/test_websocket.py b/test/mitmproxy/test_websocket.py index 62f69e2d..7c53a4b0 100644 --- a/test/mitmproxy/test_websocket.py +++ b/test/mitmproxy/test_websocket.py @@ -1,7 +1,7 @@ import io import pytest -from mitmproxy.contrib import tnetstring +from mitmproxy.io import tnetstring from mitmproxy import flowfilter from mitmproxy.test import tflow -- cgit v1.2.3 From cacad8373baade87a160c11dbef728739e6f4848 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 20 Mar 2017 13:09:24 +1300 Subject: Make tnetstrings pass mypy Mypy doesn't support recursive types yet, so we can't properly express TSerializable nested structures. For now, we just disable type checking in the appropriate locations. https://github.com/python/mypy/issues/731 --- mitmproxy/io/tnetstring.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/mitmproxy/io/tnetstring.py b/mitmproxy/io/tnetstring.py index 24ce6ce8..82c92f33 100644 --- a/mitmproxy/io/tnetstring.py +++ b/mitmproxy/io/tnetstring.py @@ -41,9 +41,9 @@ all other strings are returned as plain bytes. """ import collections -from typing import io, Union, Tuple +import typing -TSerializable = Union[None, bool, int, float, bytes, list, tuple, dict] +TSerializable = typing.Union[None, str, bool, int, float, bytes, list, tuple, dict] def dumps(value: TSerializable) -> bytes: @@ -53,12 +53,12 @@ def dumps(value: TSerializable) -> bytes: # This uses a deque to collect output fragments in reverse order, # then joins them together at the end. It's measurably faster # than creating all the intermediate strings. - q = collections.deque() + q = collections.deque() # type: collections.deque _rdumpq(q, 0, value) return b''.join(q) -def dump(value: TSerializable, file_handle: io.BinaryIO) -> None: +def dump(value: TSerializable, file_handle: typing.BinaryIO) -> None: """ This function dumps a python object as a tnetstring and writes it to the given file. @@ -156,7 +156,7 @@ def loads(string: bytes) -> TSerializable: return pop(string)[0] -def load(file_handle: io.BinaryIO) -> TSerializable: +def load(file_handle: typing.BinaryIO) -> TSerializable: """load(file) -> object This function reads a tnetstring from a file and parses it into a @@ -213,19 +213,19 @@ def parse(data_type: int, data: bytes) -> TSerializable: l = [] while data: item, data = pop(data) - l.append(item) + l.append(item) # type: ignore return l if data_type == ord(b'}'): d = {} while data: key, data = pop(data) val, data = pop(data) - d[key] = val + d[key] = val # type: ignore return d raise ValueError("unknown type tag: {}".format(data_type)) -def pop(data: bytes) -> Tuple[TSerializable, bytes]: +def pop(data: bytes) -> typing.Tuple[TSerializable, bytes]: """ This function parses a tnetstring into a python object. It returns a tuple giving the parsed object and a string @@ -233,8 +233,8 @@ def pop(data: bytes) -> Tuple[TSerializable, bytes]: """ # Parse out data length, type and remaining string. try: - length, data = data.split(b':', 1) - length = int(length) + blength, data = data.split(b':', 1) + length = int(blength) except ValueError: raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(data)) try: -- cgit v1.2.3 From 95d9ec88ac3570984c46fab0b3e5fc5ea7500d78 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 20 Mar 2017 13:13:40 +1300 Subject: tox: mypy checking for entire codebase Also fix a few linting errors. --- mitmproxy/io/__init__.py | 2 +- mitmproxy/tools/console/flowlist.py | 2 +- mitmproxy/tools/console/palettes.py | 2 +- setup.cfg | 7 ++++--- tox.ini | 12 +----------- 5 files changed, 8 insertions(+), 17 deletions(-) diff --git a/mitmproxy/io/__init__.py b/mitmproxy/io/__init__.py index a82f729f..540e6871 100644 --- a/mitmproxy/io/__init__.py +++ b/mitmproxy/io/__init__.py @@ -4,4 +4,4 @@ from .io import FlowWriter, FlowReader, FilteredFlowWriter, read_flows_from_path __all__ = [ "FlowWriter", "FlowReader", "FilteredFlowWriter", "read_flows_from_paths" -] \ No newline at end of file +] diff --git a/mitmproxy/tools/console/flowlist.py b/mitmproxy/tools/console/flowlist.py index 45377b2c..31d48ee3 100644 --- a/mitmproxy/tools/console/flowlist.py +++ b/mitmproxy/tools/console/flowlist.py @@ -308,7 +308,7 @@ class FlowListBox(urwid.ListBox): def __init__( self, master: "mitmproxy.tools.console.master.ConsoleMaster" - ) -> None: + ) -> None: self.master = master # type: "mitmproxy.tools.console.master.ConsoleMaster" super().__init__(FlowListWalker(master)) diff --git a/mitmproxy/tools/console/palettes.py b/mitmproxy/tools/console/palettes.py index e349f702..7fbdcfd8 100644 --- a/mitmproxy/tools/console/palettes.py +++ b/mitmproxy/tools/console/palettes.py @@ -1,4 +1,4 @@ -import typing +import typing # noqa # Low-color themes should ONLY use the standard foreground and background # colours listed here: # diff --git a/setup.cfg b/setup.cfg index 7fbb7f73..8e231f28 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,7 +37,7 @@ exclude = mitmproxy/controller.py mitmproxy/export.py mitmproxy/flow.py - mitmproxy/io_compat.py + mitmproxy/io/compat.py mitmproxy/master.py pathod/pathoc.py pathod/pathod.py @@ -57,8 +57,9 @@ exclude = mitmproxy/exceptions.py mitmproxy/export.py mitmproxy/flow.py - mitmproxy/io.py - mitmproxy/io_compat.py + mitmproxy/io/io.py + mitmproxy/io/compat.py + mitmproxy/io/tnetstring.py mitmproxy/log.py mitmproxy/master.py mitmproxy/net/check.py diff --git a/tox.ini b/tox.ini index a1ed53f7..fafb455e 100644 --- a/tox.ini +++ b/tox.ini @@ -27,17 +27,7 @@ commands = flake8 --jobs 8 mitmproxy pathod examples test release python3 test/filename_matching.py rstcheck README.rst - mypy --ignore-missing-imports --follow-imports=skip \ - mitmproxy/addons/ \ - mitmproxy/addonmanager.py \ - mitmproxy/optmanager.py \ - mitmproxy/proxy/protocol/ \ - mitmproxy/log.py \ - mitmproxy/tools/dump.py \ - mitmproxy/tools/web/ \ - mitmproxy/contentviews/ - mypy --ignore-missing-imports \ - mitmproxy/master.py + mypy --ignore-missing-imports ./mitmproxy [testenv:individual_coverage] deps = -- cgit v1.2.3