From 978b8d095c3106e973258376e4a15264288d20f2 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 17 Dec 2017 13:31:36 +1300 Subject: mitmproxy.types -> mitmproxy.coretypes The types name is valuable, and we have a better use for it in collecting and exposing types for options and commands. The coretypes module should probably be split up anyway - it contains a threading base class, a few container objects, and the defintion of our serialization protocol. I was tempted to rename it to "uncagegorized" for the sake of honesty. --- mitmproxy/certs.py | 2 +- mitmproxy/contentviews/image/view.py | 2 +- mitmproxy/contentviews/multipart.py | 2 +- mitmproxy/contentviews/urlencoded.py | 2 +- mitmproxy/coretypes/__init__.py | 0 mitmproxy/coretypes/basethread.py | 14 ++ mitmproxy/coretypes/bidi.py | 29 ++++ mitmproxy/coretypes/multidict.py | 216 ++++++++++++++++++++++++++ mitmproxy/coretypes/serializable.py | 36 +++++ mitmproxy/master.py | 2 +- mitmproxy/net/http/cookies.py | 2 +- mitmproxy/net/http/headers.py | 2 +- mitmproxy/net/http/message.py | 2 +- mitmproxy/net/http/request.py | 2 +- mitmproxy/net/http/response.py | 2 +- mitmproxy/net/socks.py | 2 +- mitmproxy/net/tcp.py | 2 +- mitmproxy/net/websockets/frame.py | 2 +- mitmproxy/proxy/protocol/http2.py | 2 +- mitmproxy/proxy/protocol/http_replay.py | 2 +- mitmproxy/script/concurrent.py | 2 +- mitmproxy/stateobject.py | 2 +- mitmproxy/tcp.py | 2 +- mitmproxy/types/__init__.py | 0 mitmproxy/types/basethread.py | 14 -- mitmproxy/types/bidi.py | 29 ---- mitmproxy/types/multidict.py | 216 -------------------------- mitmproxy/types/serializable.py | 36 ----- mitmproxy/websocket.py | 2 +- pathod/pathoc.py | 2 +- pathod/protocols/http2.py | 2 +- pathod/test.py | 2 +- test/mitmproxy/contentviews/test_auto.py | 2 +- test/mitmproxy/contentviews/test_query.py | 2 +- test/mitmproxy/coretypes/__init__.py | 0 test/mitmproxy/coretypes/test_basethread.py | 7 + test/mitmproxy/coretypes/test_bidi.py | 13 ++ test/mitmproxy/coretypes/test_multidict.py | 211 +++++++++++++++++++++++++ test/mitmproxy/coretypes/test_serializable.py | 39 +++++ test/mitmproxy/types/__init__.py | 0 test/mitmproxy/types/test_basethread.py | 7 - test/mitmproxy/types/test_bidi.py | 13 -- test/mitmproxy/types/test_multidict.py | 211 ------------------------- test/mitmproxy/types/test_serializable.py | 39 ----- 44 files changed, 589 insertions(+), 589 deletions(-) create mode 100644 mitmproxy/coretypes/__init__.py create mode 100644 mitmproxy/coretypes/basethread.py create mode 100644 mitmproxy/coretypes/bidi.py create mode 100644 mitmproxy/coretypes/multidict.py create mode 100644 mitmproxy/coretypes/serializable.py delete mode 100644 mitmproxy/types/__init__.py delete mode 100644 mitmproxy/types/basethread.py delete mode 100644 mitmproxy/types/bidi.py delete mode 100644 mitmproxy/types/multidict.py delete mode 100644 mitmproxy/types/serializable.py create mode 100644 test/mitmproxy/coretypes/__init__.py create mode 100644 test/mitmproxy/coretypes/test_basethread.py create mode 100644 test/mitmproxy/coretypes/test_bidi.py create mode 100644 test/mitmproxy/coretypes/test_multidict.py create mode 100644 test/mitmproxy/coretypes/test_serializable.py delete mode 100644 test/mitmproxy/types/__init__.py delete mode 100644 test/mitmproxy/types/test_basethread.py delete mode 100644 test/mitmproxy/types/test_bidi.py delete mode 100644 test/mitmproxy/types/test_multidict.py delete mode 100644 test/mitmproxy/types/test_serializable.py diff --git a/mitmproxy/certs.py b/mitmproxy/certs.py index 572a12d0..c29d67f3 100644 --- a/mitmproxy/certs.py +++ b/mitmproxy/certs.py @@ -11,7 +11,7 @@ from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL -from mitmproxy.types import serializable +from mitmproxy.coretypes import serializable # Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815 DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3 diff --git a/mitmproxy/contentviews/image/view.py b/mitmproxy/contentviews/image/view.py index 6f75473b..fde9b39d 100644 --- a/mitmproxy/contentviews/image/view.py +++ b/mitmproxy/contentviews/image/view.py @@ -1,7 +1,7 @@ import imghdr from mitmproxy.contentviews import base -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict from . import image_parser diff --git a/mitmproxy/contentviews/multipart.py b/mitmproxy/contentviews/multipart.py index 0b0e51e2..be3dc135 100644 --- a/mitmproxy/contentviews/multipart.py +++ b/mitmproxy/contentviews/multipart.py @@ -1,5 +1,5 @@ from mitmproxy.net import http -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict from . import base diff --git a/mitmproxy/contentviews/urlencoded.py b/mitmproxy/contentviews/urlencoded.py index 79fe9c1c..a24f342a 100644 --- a/mitmproxy/contentviews/urlencoded.py +++ b/mitmproxy/contentviews/urlencoded.py @@ -1,5 +1,5 @@ from mitmproxy.net.http import url -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict from . import base diff --git a/mitmproxy/coretypes/__init__.py b/mitmproxy/coretypes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mitmproxy/coretypes/basethread.py b/mitmproxy/coretypes/basethread.py new file mode 100644 index 00000000..a3c81d19 --- /dev/null +++ b/mitmproxy/coretypes/basethread.py @@ -0,0 +1,14 @@ +import time +import threading + + +class BaseThread(threading.Thread): + def __init__(self, name, *args, **kwargs): + super().__init__(name=name, *args, **kwargs) + self._thread_started = time.time() + + def _threadinfo(self): + return "%s - age: %is" % ( + self.name, + int(time.time() - self._thread_started) + ) diff --git a/mitmproxy/coretypes/bidi.py b/mitmproxy/coretypes/bidi.py new file mode 100644 index 00000000..0982a34a --- /dev/null +++ b/mitmproxy/coretypes/bidi.py @@ -0,0 +1,29 @@ + + +class BiDi: + + """ + A wee utility class for keeping bi-directional mappings, like field + constants in protocols. Names are attributes on the object, dict-like + access maps values to names: + + CONST = BiDi(a=1, b=2) + assert CONST.a == 1 + assert CONST.get_name(1) == "a" + """ + + def __init__(self, **kwargs): + self.names = kwargs + self.values = {} + for k, v in kwargs.items(): + self.values[v] = k + if len(self.names) != len(self.values): + raise ValueError("Duplicate values not allowed.") + + def __getattr__(self, k): + if k in self.names: + return self.names[k] + raise AttributeError("No such attribute: %s", k) + + def get_name(self, n, default=None): + return self.values.get(n, default) diff --git a/mitmproxy/coretypes/multidict.py b/mitmproxy/coretypes/multidict.py new file mode 100644 index 00000000..90f3013e --- /dev/null +++ b/mitmproxy/coretypes/multidict.py @@ -0,0 +1,216 @@ +from abc import ABCMeta, abstractmethod + +from collections.abc import MutableMapping +from mitmproxy.coretypes import serializable + + +class _MultiDict(MutableMapping, metaclass=ABCMeta): + def __repr__(self): + fields = ( + repr(field) + for field in self.fields + ) + return "{cls}[{fields}]".format( + cls=type(self).__name__, + fields=", ".join(fields) + ) + + @staticmethod + @abstractmethod + def _reduce_values(values): + """ + If a user accesses multidict["foo"], this method + reduces all values for "foo" to a single value that is returned. + For example, HTTP headers are folded, whereas we will just take + the first cookie we found with that name. + """ + + @staticmethod + @abstractmethod + def _kconv(key): + """ + This method converts a key to its canonical representation. + For example, HTTP headers are case-insensitive, so this method returns key.lower(). + """ + + def __getitem__(self, key): + values = self.get_all(key) + if not values: + raise KeyError(key) + return self._reduce_values(values) + + def __setitem__(self, key, value): + self.set_all(key, [value]) + + def __delitem__(self, key): + if key not in self: + raise KeyError(key) + key = self._kconv(key) + self.fields = tuple( + field for field in self.fields + if key != self._kconv(field[0]) + ) + + def __iter__(self): + seen = set() + for key, _ in self.fields: + key_kconv = self._kconv(key) + if key_kconv not in seen: + seen.add(key_kconv) + yield key + + def __len__(self): + return len(set(self._kconv(key) for key, _ in self.fields)) + + def __eq__(self, other): + if isinstance(other, MultiDict): + return self.fields == other.fields + return False + + def get_all(self, key): + """ + Return the list of all values for a given key. + If that key is not in the MultiDict, the return value will be an empty list. + """ + key = self._kconv(key) + return [ + value + for k, value in self.fields + if self._kconv(k) == key + ] + + def set_all(self, key, values): + """ + Remove the old values for a key and add new ones. + """ + key_kconv = self._kconv(key) + + new_fields = [] + for field in self.fields: + if self._kconv(field[0]) == key_kconv: + if values: + new_fields.append( + (field[0], values.pop(0)) + ) + else: + new_fields.append(field) + while values: + new_fields.append( + (key, values.pop(0)) + ) + self.fields = tuple(new_fields) + + def add(self, key, value): + """ + Add an additional value for the given key at the bottom. + """ + self.insert(len(self.fields), key, value) + + def insert(self, index, key, value): + """ + Insert an additional value for the given key at the specified position. + """ + item = (key, value) + self.fields = self.fields[:index] + (item,) + self.fields[index:] + + def keys(self, multi=False): + """ + Get all keys. + + Args: + multi(bool): + If True, one key per value will be returned. + If False, duplicate keys will only be returned once. + """ + return ( + k + for k, _ in self.items(multi) + ) + + def values(self, multi=False): + """ + Get all values. + + Args: + multi(bool): + If True, all values will be returned. + If False, only the first value per key will be returned. + """ + return ( + v + for _, v in self.items(multi) + ) + + def items(self, multi=False): + """ + Get all (key, value) tuples. + + Args: + multi(bool): + If True, all (key, value) pairs will be returned + If False, only the first (key, value) pair per unique key will be returned. + """ + if multi: + return self.fields + else: + return super().items() + + +class MultiDict(_MultiDict, serializable.Serializable): + def __init__(self, fields=()): + super().__init__() + self.fields = tuple( + tuple(i) for i in fields + ) + + @staticmethod + def _reduce_values(values): + return values[0] + + @staticmethod + def _kconv(key): + return key + + def get_state(self): + return self.fields + + def set_state(self, state): + self.fields = tuple(tuple(x) for x in state) + + @classmethod + def from_state(cls, state): + return cls(state) + + +class MultiDictView(_MultiDict): + """ + The MultiDictView provides the MultiDict interface over calculated data. + The view itself contains no state - data is retrieved from the parent on + request, and stored back to the parent on change. + """ + def __init__(self, getter, setter): + self._getter = getter + self._setter = setter + super().__init__() + + @staticmethod + def _kconv(key): + # All request-attributes are case-sensitive. + return key + + @staticmethod + def _reduce_values(values): + # We just return the first element if + # multiple elements exist with the same key. + return values[0] + + @property + def fields(self): + return self._getter() + + @fields.setter + def fields(self, value): + self._setter(value) + + def copy(self): + return MultiDict(self.fields) diff --git a/mitmproxy/coretypes/serializable.py b/mitmproxy/coretypes/serializable.py new file mode 100644 index 00000000..cd8539b0 --- /dev/null +++ b/mitmproxy/coretypes/serializable.py @@ -0,0 +1,36 @@ +import abc +import uuid + + +class Serializable(metaclass=abc.ABCMeta): + """ + Abstract Base Class that defines an API to save an object's state and restore it later on. + """ + + @classmethod + @abc.abstractmethod + def from_state(cls, state): + """ + Create a new object from the given state. + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_state(self): + """ + Retrieve object state. + """ + raise NotImplementedError() + + @abc.abstractmethod + def set_state(self, state): + """ + Set object state to the given state. + """ + raise NotImplementedError() + + def copy(self): + state = self.get_state() + if isinstance(state, dict) and "id" in state: + state["id"] = str(uuid.uuid4()) + return self.from_state(state) diff --git a/mitmproxy/master.py b/mitmproxy/master.py index b41e2a8d..5997ff6d 100644 --- a/mitmproxy/master.py +++ b/mitmproxy/master.py @@ -12,7 +12,7 @@ from mitmproxy import http from mitmproxy import log from mitmproxy.net import server_spec from mitmproxy.proxy.protocol import http_replay -from mitmproxy.types import basethread +from mitmproxy.coretypes import basethread from . import ctx as mitmproxy_ctx diff --git a/mitmproxy/net/http/cookies.py b/mitmproxy/net/http/cookies.py index 5b410acc..4824bf56 100644 --- a/mitmproxy/net/http/cookies.py +++ b/mitmproxy/net/http/cookies.py @@ -3,7 +3,7 @@ import re import time from typing import Tuple, List, Iterable -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict """ A flexible module for cookie parsing and manipulation. diff --git a/mitmproxy/net/http/headers.py b/mitmproxy/net/http/headers.py index 8fc0cd43..8a58cbbc 100644 --- a/mitmproxy/net/http/headers.py +++ b/mitmproxy/net/http/headers.py @@ -1,7 +1,7 @@ import re import collections -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict from mitmproxy.utils import strutils # See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ diff --git a/mitmproxy/net/http/message.py b/mitmproxy/net/http/message.py index cb32aee4..65820f67 100644 --- a/mitmproxy/net/http/message.py +++ b/mitmproxy/net/http/message.py @@ -3,7 +3,7 @@ from typing import Optional, Union # noqa from mitmproxy.utils import strutils from mitmproxy.net.http import encoding -from mitmproxy.types import serializable +from mitmproxy.coretypes import serializable from mitmproxy.net.http import headers diff --git a/mitmproxy/net/http/request.py b/mitmproxy/net/http/request.py index 6f366a4f..6b4041f6 100644 --- a/mitmproxy/net/http/request.py +++ b/mitmproxy/net/http/request.py @@ -2,7 +2,7 @@ import re import urllib from typing import Optional, AnyStr, Dict, Iterable, Tuple, Union -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict from mitmproxy.utils import strutils from mitmproxy.net.http import multipart from mitmproxy.net.http import cookies diff --git a/mitmproxy/net/http/response.py b/mitmproxy/net/http/response.py index 18950fc7..48527d63 100644 --- a/mitmproxy/net/http/response.py +++ b/mitmproxy/net/http/response.py @@ -1,7 +1,7 @@ import time from email.utils import parsedate_tz, formatdate, mktime_tz from mitmproxy.utils import human -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict from mitmproxy.net.http import cookies from mitmproxy.net.http import headers as nheaders from mitmproxy.net.http import message diff --git a/mitmproxy/net/socks.py b/mitmproxy/net/socks.py index fdfcfb80..0b2790df 100644 --- a/mitmproxy/net/socks.py +++ b/mitmproxy/net/socks.py @@ -3,7 +3,7 @@ import array import ipaddress from mitmproxy.net import check -from mitmproxy.types import bidi +from mitmproxy.coretypes import bidi class SocksError(Exception): diff --git a/mitmproxy/net/tcp.py b/mitmproxy/net/tcp.py index 47c80e80..d08938c9 100644 --- a/mitmproxy/net/tcp.py +++ b/mitmproxy/net/tcp.py @@ -14,7 +14,7 @@ from OpenSSL import SSL from mitmproxy import certs from mitmproxy import exceptions -from mitmproxy.types import basethread +from mitmproxy.coretypes import basethread socket_fileobject = socket.SocketIO diff --git a/mitmproxy/net/websockets/frame.py b/mitmproxy/net/websockets/frame.py index 28881f64..ac6a0812 100644 --- a/mitmproxy/net/websockets/frame.py +++ b/mitmproxy/net/websockets/frame.py @@ -6,7 +6,7 @@ from mitmproxy.net import tcp from mitmproxy.utils import strutils from mitmproxy.utils import bits from mitmproxy.utils import human -from mitmproxy.types import bidi +from mitmproxy.coretypes import bidi from .masker import Masker diff --git a/mitmproxy/proxy/protocol/http2.py b/mitmproxy/proxy/protocol/http2.py index cf021291..cc99a715 100644 --- a/mitmproxy/proxy/protocol/http2.py +++ b/mitmproxy/proxy/protocol/http2.py @@ -15,7 +15,7 @@ from mitmproxy.proxy.protocol import base from mitmproxy.proxy.protocol import http as httpbase import mitmproxy.net.http from mitmproxy.net import tcp -from mitmproxy.types import basethread +from mitmproxy.coretypes import basethread from mitmproxy.net.http import http2, headers from mitmproxy.utils import human diff --git a/mitmproxy/proxy/protocol/http_replay.py b/mitmproxy/proxy/protocol/http_replay.py index 00bb31c9..cc22c0b7 100644 --- a/mitmproxy/proxy/protocol/http_replay.py +++ b/mitmproxy/proxy/protocol/http_replay.py @@ -11,7 +11,7 @@ from mitmproxy import options from mitmproxy import connections from mitmproxy.net import server_spec from mitmproxy.net.http import http1 -from mitmproxy.types import basethread +from mitmproxy.coretypes import basethread from mitmproxy.utils import human diff --git a/mitmproxy/script/concurrent.py b/mitmproxy/script/concurrent.py index cbb3beb0..1d935585 100644 --- a/mitmproxy/script/concurrent.py +++ b/mitmproxy/script/concurrent.py @@ -4,7 +4,7 @@ offload computations from mitmproxy's main master thread. """ from mitmproxy import eventsequence -from mitmproxy.types import basethread +from mitmproxy.coretypes import basethread class ScriptThread(basethread.BaseThread): diff --git a/mitmproxy/stateobject.py b/mitmproxy/stateobject.py index a0deaec9..007339e8 100644 --- a/mitmproxy/stateobject.py +++ b/mitmproxy/stateobject.py @@ -2,7 +2,7 @@ from typing import Any from typing import List from typing import MutableMapping # noqa -from mitmproxy.types import serializable +from mitmproxy.coretypes import serializable def _is_list(cls): diff --git a/mitmproxy/tcp.py b/mitmproxy/tcp.py index fe9f217b..11de80e9 100644 --- a/mitmproxy/tcp.py +++ b/mitmproxy/tcp.py @@ -3,7 +3,7 @@ import time from typing import List from mitmproxy import flow -from mitmproxy.types import serializable +from mitmproxy.coretypes import serializable class TCPMessage(serializable.Serializable): diff --git a/mitmproxy/types/__init__.py b/mitmproxy/types/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/mitmproxy/types/basethread.py b/mitmproxy/types/basethread.py deleted file mode 100644 index a3c81d19..00000000 --- a/mitmproxy/types/basethread.py +++ /dev/null @@ -1,14 +0,0 @@ -import time -import threading - - -class BaseThread(threading.Thread): - def __init__(self, name, *args, **kwargs): - super().__init__(name=name, *args, **kwargs) - self._thread_started = time.time() - - def _threadinfo(self): - return "%s - age: %is" % ( - self.name, - int(time.time() - self._thread_started) - ) diff --git a/mitmproxy/types/bidi.py b/mitmproxy/types/bidi.py deleted file mode 100644 index 0982a34a..00000000 --- a/mitmproxy/types/bidi.py +++ /dev/null @@ -1,29 +0,0 @@ - - -class BiDi: - - """ - A wee utility class for keeping bi-directional mappings, like field - constants in protocols. Names are attributes on the object, dict-like - access maps values to names: - - CONST = BiDi(a=1, b=2) - assert CONST.a == 1 - assert CONST.get_name(1) == "a" - """ - - def __init__(self, **kwargs): - self.names = kwargs - self.values = {} - for k, v in kwargs.items(): - self.values[v] = k - if len(self.names) != len(self.values): - raise ValueError("Duplicate values not allowed.") - - def __getattr__(self, k): - if k in self.names: - return self.names[k] - raise AttributeError("No such attribute: %s", k) - - def get_name(self, n, default=None): - return self.values.get(n, default) diff --git a/mitmproxy/types/multidict.py b/mitmproxy/types/multidict.py deleted file mode 100644 index bd9766a3..00000000 --- a/mitmproxy/types/multidict.py +++ /dev/null @@ -1,216 +0,0 @@ -from abc import ABCMeta, abstractmethod - -from collections.abc import MutableMapping -from mitmproxy.types import serializable - - -class _MultiDict(MutableMapping, metaclass=ABCMeta): - def __repr__(self): - fields = ( - repr(field) - for field in self.fields - ) - return "{cls}[{fields}]".format( - cls=type(self).__name__, - fields=", ".join(fields) - ) - - @staticmethod - @abstractmethod - def _reduce_values(values): - """ - If a user accesses multidict["foo"], this method - reduces all values for "foo" to a single value that is returned. - For example, HTTP headers are folded, whereas we will just take - the first cookie we found with that name. - """ - - @staticmethod - @abstractmethod - def _kconv(key): - """ - This method converts a key to its canonical representation. - For example, HTTP headers are case-insensitive, so this method returns key.lower(). - """ - - def __getitem__(self, key): - values = self.get_all(key) - if not values: - raise KeyError(key) - return self._reduce_values(values) - - def __setitem__(self, key, value): - self.set_all(key, [value]) - - def __delitem__(self, key): - if key not in self: - raise KeyError(key) - key = self._kconv(key) - self.fields = tuple( - field for field in self.fields - if key != self._kconv(field[0]) - ) - - def __iter__(self): - seen = set() - for key, _ in self.fields: - key_kconv = self._kconv(key) - if key_kconv not in seen: - seen.add(key_kconv) - yield key - - def __len__(self): - return len(set(self._kconv(key) for key, _ in self.fields)) - - def __eq__(self, other): - if isinstance(other, MultiDict): - return self.fields == other.fields - return False - - def get_all(self, key): - """ - Return the list of all values for a given key. - If that key is not in the MultiDict, the return value will be an empty list. - """ - key = self._kconv(key) - return [ - value - for k, value in self.fields - if self._kconv(k) == key - ] - - def set_all(self, key, values): - """ - Remove the old values for a key and add new ones. - """ - key_kconv = self._kconv(key) - - new_fields = [] - for field in self.fields: - if self._kconv(field[0]) == key_kconv: - if values: - new_fields.append( - (field[0], values.pop(0)) - ) - else: - new_fields.append(field) - while values: - new_fields.append( - (key, values.pop(0)) - ) - self.fields = tuple(new_fields) - - def add(self, key, value): - """ - Add an additional value for the given key at the bottom. - """ - self.insert(len(self.fields), key, value) - - def insert(self, index, key, value): - """ - Insert an additional value for the given key at the specified position. - """ - item = (key, value) - self.fields = self.fields[:index] + (item,) + self.fields[index:] - - def keys(self, multi=False): - """ - Get all keys. - - Args: - multi(bool): - If True, one key per value will be returned. - If False, duplicate keys will only be returned once. - """ - return ( - k - for k, _ in self.items(multi) - ) - - def values(self, multi=False): - """ - Get all values. - - Args: - multi(bool): - If True, all values will be returned. - If False, only the first value per key will be returned. - """ - return ( - v - for _, v in self.items(multi) - ) - - def items(self, multi=False): - """ - Get all (key, value) tuples. - - Args: - multi(bool): - If True, all (key, value) pairs will be returned - If False, only the first (key, value) pair per unique key will be returned. - """ - if multi: - return self.fields - else: - return super().items() - - -class MultiDict(_MultiDict, serializable.Serializable): - def __init__(self, fields=()): - super().__init__() - self.fields = tuple( - tuple(i) for i in fields - ) - - @staticmethod - def _reduce_values(values): - return values[0] - - @staticmethod - def _kconv(key): - return key - - def get_state(self): - return self.fields - - def set_state(self, state): - self.fields = tuple(tuple(x) for x in state) - - @classmethod - def from_state(cls, state): - return cls(state) - - -class MultiDictView(_MultiDict): - """ - The MultiDictView provides the MultiDict interface over calculated data. - The view itself contains no state - data is retrieved from the parent on - request, and stored back to the parent on change. - """ - def __init__(self, getter, setter): - self._getter = getter - self._setter = setter - super().__init__() - - @staticmethod - def _kconv(key): - # All request-attributes are case-sensitive. - return key - - @staticmethod - def _reduce_values(values): - # We just return the first element if - # multiple elements exist with the same key. - return values[0] - - @property - def fields(self): - return self._getter() - - @fields.setter - def fields(self, value): - self._setter(value) - - def copy(self): - return MultiDict(self.fields) diff --git a/mitmproxy/types/serializable.py b/mitmproxy/types/serializable.py deleted file mode 100644 index cd8539b0..00000000 --- a/mitmproxy/types/serializable.py +++ /dev/null @@ -1,36 +0,0 @@ -import abc -import uuid - - -class Serializable(metaclass=abc.ABCMeta): - """ - Abstract Base Class that defines an API to save an object's state and restore it later on. - """ - - @classmethod - @abc.abstractmethod - def from_state(cls, state): - """ - Create a new object from the given state. - """ - raise NotImplementedError() - - @abc.abstractmethod - def get_state(self): - """ - Retrieve object state. - """ - raise NotImplementedError() - - @abc.abstractmethod - def set_state(self, state): - """ - Set object state to the given state. - """ - raise NotImplementedError() - - def copy(self): - state = self.get_state() - if isinstance(state, dict) and "id" in state: - state["id"] = str(uuid.uuid4()) - return self.from_state(state) diff --git a/mitmproxy/websocket.py b/mitmproxy/websocket.py index ded09f65..6c1e7000 100644 --- a/mitmproxy/websocket.py +++ b/mitmproxy/websocket.py @@ -3,7 +3,7 @@ from typing import List, Optional from mitmproxy import flow from mitmproxy.net import websockets -from mitmproxy.types import serializable +from mitmproxy.coretypes import serializable from mitmproxy.utils import strutils, human diff --git a/pathod/pathoc.py b/pathod/pathoc.py index 20a915c0..e5fe4c2d 100644 --- a/pathod/pathoc.py +++ b/pathod/pathoc.py @@ -17,7 +17,7 @@ from mitmproxy.net import tcp, tls from mitmproxy.net import websockets from mitmproxy.net import socks from mitmproxy.net import http as net_http -from mitmproxy.types import basethread +from mitmproxy.coretypes import basethread from mitmproxy.utils import strutils from pathod import log diff --git a/pathod/protocols/http2.py b/pathod/protocols/http2.py index cfc71650..c56d304d 100644 --- a/pathod/protocols/http2.py +++ b/pathod/protocols/http2.py @@ -8,7 +8,7 @@ from mitmproxy.net.http import http2 import mitmproxy.net.http.headers import mitmproxy.net.http.response import mitmproxy.net.http.request -from mitmproxy.types import bidi +from mitmproxy.coretypes import bidi from .. import language diff --git a/pathod/test.py b/pathod/test.py index 52f3ba02..819c7a94 100644 --- a/pathod/test.py +++ b/pathod/test.py @@ -2,7 +2,7 @@ import io import time import queue from . import pathod -from mitmproxy.types import basethread +from mitmproxy.coretypes import basethread import typing # noqa diff --git a/test/mitmproxy/contentviews/test_auto.py b/test/mitmproxy/contentviews/test_auto.py index 2ff43139..cd888a2d 100644 --- a/test/mitmproxy/contentviews/test_auto.py +++ b/test/mitmproxy/contentviews/test_auto.py @@ -1,6 +1,6 @@ from mitmproxy.contentviews import auto from mitmproxy.net import http -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict from . import full_eval diff --git a/test/mitmproxy/contentviews/test_query.py b/test/mitmproxy/contentviews/test_query.py index d2bddd05..741b23f1 100644 --- a/test/mitmproxy/contentviews/test_query.py +++ b/test/mitmproxy/contentviews/test_query.py @@ -1,5 +1,5 @@ from mitmproxy.contentviews import query -from mitmproxy.types import multidict +from mitmproxy.coretypes import multidict from . import full_eval diff --git a/test/mitmproxy/coretypes/__init__.py b/test/mitmproxy/coretypes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/mitmproxy/coretypes/test_basethread.py b/test/mitmproxy/coretypes/test_basethread.py new file mode 100644 index 00000000..4a383fea --- /dev/null +++ b/test/mitmproxy/coretypes/test_basethread.py @@ -0,0 +1,7 @@ +import re +from mitmproxy.coretypes import basethread + + +def test_basethread(): + t = basethread.BaseThread('foobar') + assert re.match('foobar - age: \d+s', t._threadinfo()) diff --git a/test/mitmproxy/coretypes/test_bidi.py b/test/mitmproxy/coretypes/test_bidi.py new file mode 100644 index 00000000..3bdad3c2 --- /dev/null +++ b/test/mitmproxy/coretypes/test_bidi.py @@ -0,0 +1,13 @@ +import pytest +from mitmproxy.coretypes import bidi + + +def test_bidi(): + b = bidi.BiDi(a=1, b=2) + assert b.a == 1 + assert b.get_name(1) == "a" + assert b.get_name(5) is None + with pytest.raises(AttributeError): + getattr(b, "c") + with pytest.raises(ValueError): + bidi.BiDi(one=1, two=1) diff --git a/test/mitmproxy/coretypes/test_multidict.py b/test/mitmproxy/coretypes/test_multidict.py new file mode 100644 index 00000000..273d8ca2 --- /dev/null +++ b/test/mitmproxy/coretypes/test_multidict.py @@ -0,0 +1,211 @@ +import pytest + +from mitmproxy.coretypes import multidict + + +class _TMulti: + @staticmethod + def _kconv(key): + return key.lower() + + +class TMultiDict(_TMulti, multidict.MultiDict): + pass + + +class TestMultiDict: + @staticmethod + def _multi(): + return TMultiDict(( + ("foo", "bar"), + ("bar", "baz"), + ("Bar", "bam") + )) + + def test_init(self): + md = TMultiDict() + assert len(md) == 0 + + md = TMultiDict([("foo", "bar")]) + assert len(md) == 1 + assert md.fields == (("foo", "bar"),) + + def test_repr(self): + assert repr(self._multi()) == ( + "TMultiDict[('foo', 'bar'), ('bar', 'baz'), ('Bar', 'bam')]" + ) + + def test_getitem(self): + md = TMultiDict([("foo", "bar")]) + assert "foo" in md + assert "Foo" in md + assert md["foo"] == "bar" + + with pytest.raises(KeyError): + assert md["bar"] + + md_multi = TMultiDict( + [("foo", "a"), ("foo", "b")] + ) + assert md_multi["foo"] == "a" + + def test_setitem(self): + md = TMultiDict() + md["foo"] = "bar" + assert md.fields == (("foo", "bar"),) + + md["foo"] = "baz" + assert md.fields == (("foo", "baz"),) + + md["bar"] = "bam" + assert md.fields == (("foo", "baz"), ("bar", "bam")) + + def test_delitem(self): + md = self._multi() + del md["foo"] + assert "foo" not in md + assert "bar" in md + + with pytest.raises(KeyError): + del md["foo"] + + del md["bar"] + assert md.fields == () + + def test_iter(self): + md = self._multi() + assert list(md.__iter__()) == ["foo", "bar"] + + def test_len(self): + md = TMultiDict() + assert len(md) == 0 + + md = self._multi() + assert len(md) == 2 + + def test_eq(self): + assert TMultiDict() == TMultiDict() + assert not (TMultiDict() == 42) + + md1 = self._multi() + md2 = self._multi() + assert md1 == md2 + md1.fields = md1.fields[1:] + md1.fields[:1] + assert not (md1 == md2) + + def test_hash(self): + """ + If a class defines mutable objects and implements an __eq__() method, + it should not implement __hash__(), since the implementation of hashable + collections requires that a key's hash value is immutable. + """ + with pytest.raises(TypeError): + assert hash(TMultiDict()) + + def test_get_all(self): + md = self._multi() + assert md.get_all("foo") == ["bar"] + assert md.get_all("bar") == ["baz", "bam"] + assert md.get_all("baz") == [] + + def test_set_all(self): + md = TMultiDict() + md.set_all("foo", ["bar", "baz"]) + assert md.fields == (("foo", "bar"), ("foo", "baz")) + + md = TMultiDict(( + ("a", "b"), + ("x", "x"), + ("c", "d"), + ("X", "X"), + ("e", "f"), + )) + md.set_all("x", ["1", "2", "3"]) + assert md.fields == ( + ("a", "b"), + ("x", "1"), + ("c", "d"), + ("X", "2"), + ("e", "f"), + ("x", "3"), + ) + md.set_all("x", ["4"]) + assert md.fields == ( + ("a", "b"), + ("x", "4"), + ("c", "d"), + ("e", "f"), + ) + + def test_add(self): + md = self._multi() + md.add("foo", "foo") + assert md.fields == ( + ("foo", "bar"), + ("bar", "baz"), + ("Bar", "bam"), + ("foo", "foo") + ) + + def test_insert(self): + md = TMultiDict([("b", "b")]) + md.insert(0, "a", "a") + md.insert(2, "c", "c") + assert md.fields == (("a", "a"), ("b", "b"), ("c", "c")) + + def test_keys(self): + md = self._multi() + assert list(md.keys()) == ["foo", "bar"] + assert list(md.keys(multi=True)) == ["foo", "bar", "Bar"] + + def test_values(self): + md = self._multi() + assert list(md.values()) == ["bar", "baz"] + assert list(md.values(multi=True)) == ["bar", "baz", "bam"] + + def test_items(self): + md = self._multi() + assert list(md.items()) == [("foo", "bar"), ("bar", "baz")] + assert list(md.items(multi=True)) == [("foo", "bar"), ("bar", "baz"), ("Bar", "bam")] + + def test_state(self): + md = self._multi() + assert len(md.get_state()) == 3 + assert md == TMultiDict.from_state(md.get_state()) + + md2 = TMultiDict() + assert md != md2 + md2.set_state(md.get_state()) + assert md == md2 + + +class TParent: + def __init__(self): + self.vals = tuple() + + def setter(self, vals): + self.vals = vals + + def getter(self): + return self.vals + + +class TestMultiDictView: + def test_modify(self): + p = TParent() + tv = multidict.MultiDictView(p.getter, p.setter) + assert len(tv) == 0 + tv["a"] = "b" + assert p.vals == (("a", "b"),) + tv["c"] = "b" + assert p.vals == (("a", "b"), ("c", "b")) + assert tv["a"] == "b" + + def test_copy(self): + p = TParent() + tv = multidict.MultiDictView(p.getter, p.setter) + c = tv.copy() + assert isinstance(c, multidict.MultiDict) + assert tv.items() == c.items() + c["foo"] = "bar" + assert tv.items() != c.items() diff --git a/test/mitmproxy/coretypes/test_serializable.py b/test/mitmproxy/coretypes/test_serializable.py new file mode 100644 index 00000000..a316f876 --- /dev/null +++ b/test/mitmproxy/coretypes/test_serializable.py @@ -0,0 +1,39 @@ +import copy + +from mitmproxy.coretypes import serializable + + +class SerializableDummy(serializable.Serializable): + def __init__(self, i): + self.i = i + + def get_state(self): + return copy.copy(self.i) + + def set_state(self, i): + self.i = i + + @classmethod + def from_state(cls, state): + return cls(state) + + +class TestSerializable: + def test_copy(self): + a = SerializableDummy(42) + assert a.i == 42 + b = a.copy() + assert b.i == 42 + + a.set_state(1) + assert a.i == 1 + assert b.i == 42 + + def test_copy_id(self): + a = SerializableDummy({ + "id": "foo", + "foo": 42 + }) + b = a.copy() + assert a.get_state()["id"] != b.get_state()["id"] + assert a.get_state()["foo"] == b.get_state()["foo"] diff --git a/test/mitmproxy/types/__init__.py b/test/mitmproxy/types/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/test/mitmproxy/types/test_basethread.py b/test/mitmproxy/types/test_basethread.py deleted file mode 100644 index a91588eb..00000000 --- a/test/mitmproxy/types/test_basethread.py +++ /dev/null @@ -1,7 +0,0 @@ -import re -from mitmproxy.types import basethread - - -def test_basethread(): - t = basethread.BaseThread('foobar') - assert re.match('foobar - age: \d+s', t._threadinfo()) diff --git a/test/mitmproxy/types/test_bidi.py b/test/mitmproxy/types/test_bidi.py deleted file mode 100644 index e3a259fd..00000000 --- a/test/mitmproxy/types/test_bidi.py +++ /dev/null @@ -1,13 +0,0 @@ -import pytest -from mitmproxy.types import bidi - - -def test_bidi(): - b = bidi.BiDi(a=1, b=2) - assert b.a == 1 - assert b.get_name(1) == "a" - assert b.get_name(5) is None - with pytest.raises(AttributeError): - getattr(b, "c") - with pytest.raises(ValueError): - bidi.BiDi(one=1, two=1) diff --git a/test/mitmproxy/types/test_multidict.py b/test/mitmproxy/types/test_multidict.py deleted file mode 100644 index c76cd753..00000000 --- a/test/mitmproxy/types/test_multidict.py +++ /dev/null @@ -1,211 +0,0 @@ -import pytest - -from mitmproxy.types import multidict - - -class _TMulti: - @staticmethod - def _kconv(key): - return key.lower() - - -class TMultiDict(_TMulti, multidict.MultiDict): - pass - - -class TestMultiDict: - @staticmethod - def _multi(): - return TMultiDict(( - ("foo", "bar"), - ("bar", "baz"), - ("Bar", "bam") - )) - - def test_init(self): - md = TMultiDict() - assert len(md) == 0 - - md = TMultiDict([("foo", "bar")]) - assert len(md) == 1 - assert md.fields == (("foo", "bar"),) - - def test_repr(self): - assert repr(self._multi()) == ( - "TMultiDict[('foo', 'bar'), ('bar', 'baz'), ('Bar', 'bam')]" - ) - - def test_getitem(self): - md = TMultiDict([("foo", "bar")]) - assert "foo" in md - assert "Foo" in md - assert md["foo"] == "bar" - - with pytest.raises(KeyError): - assert md["bar"] - - md_multi = TMultiDict( - [("foo", "a"), ("foo", "b")] - ) - assert md_multi["foo"] == "a" - - def test_setitem(self): - md = TMultiDict() - md["foo"] = "bar" - assert md.fields == (("foo", "bar"),) - - md["foo"] = "baz" - assert md.fields == (("foo", "baz"),) - - md["bar"] = "bam" - assert md.fields == (("foo", "baz"), ("bar", "bam")) - - def test_delitem(self): - md = self._multi() - del md["foo"] - assert "foo" not in md - assert "bar" in md - - with pytest.raises(KeyError): - del md["foo"] - - del md["bar"] - assert md.fields == () - - def test_iter(self): - md = self._multi() - assert list(md.__iter__()) == ["foo", "bar"] - - def test_len(self): - md = TMultiDict() - assert len(md) == 0 - - md = self._multi() - assert len(md) == 2 - - def test_eq(self): - assert TMultiDict() == TMultiDict() - assert not (TMultiDict() == 42) - - md1 = self._multi() - md2 = self._multi() - assert md1 == md2 - md1.fields = md1.fields[1:] + md1.fields[:1] - assert not (md1 == md2) - - def test_hash(self): - """ - If a class defines mutable objects and implements an __eq__() method, - it should not implement __hash__(), since the implementation of hashable - collections requires that a key's hash value is immutable. - """ - with pytest.raises(TypeError): - assert hash(TMultiDict()) - - def test_get_all(self): - md = self._multi() - assert md.get_all("foo") == ["bar"] - assert md.get_all("bar") == ["baz", "bam"] - assert md.get_all("baz") == [] - - def test_set_all(self): - md = TMultiDict() - md.set_all("foo", ["bar", "baz"]) - assert md.fields == (("foo", "bar"), ("foo", "baz")) - - md = TMultiDict(( - ("a", "b"), - ("x", "x"), - ("c", "d"), - ("X", "X"), - ("e", "f"), - )) - md.set_all("x", ["1", "2", "3"]) - assert md.fields == ( - ("a", "b"), - ("x", "1"), - ("c", "d"), - ("X", "2"), - ("e", "f"), - ("x", "3"), - ) - md.set_all("x", ["4"]) - assert md.fields == ( - ("a", "b"), - ("x", "4"), - ("c", "d"), - ("e", "f"), - ) - - def test_add(self): - md = self._multi() - md.add("foo", "foo") - assert md.fields == ( - ("foo", "bar"), - ("bar", "baz"), - ("Bar", "bam"), - ("foo", "foo") - ) - - def test_insert(self): - md = TMultiDict([("b", "b")]) - md.insert(0, "a", "a") - md.insert(2, "c", "c") - assert md.fields == (("a", "a"), ("b", "b"), ("c", "c")) - - def test_keys(self): - md = self._multi() - assert list(md.keys()) == ["foo", "bar"] - assert list(md.keys(multi=True)) == ["foo", "bar", "Bar"] - - def test_values(self): - md = self._multi() - assert list(md.values()) == ["bar", "baz"] - assert list(md.values(multi=True)) == ["bar", "baz", "bam"] - - def test_items(self): - md = self._multi() - assert list(md.items()) == [("foo", "bar"), ("bar", "baz")] - assert list(md.items(multi=True)) == [("foo", "bar"), ("bar", "baz"), ("Bar", "bam")] - - def test_state(self): - md = self._multi() - assert len(md.get_state()) == 3 - assert md == TMultiDict.from_state(md.get_state()) - - md2 = TMultiDict() - assert md != md2 - md2.set_state(md.get_state()) - assert md == md2 - - -class TParent: - def __init__(self): - self.vals = tuple() - - def setter(self, vals): - self.vals = vals - - def getter(self): - return self.vals - - -class TestMultiDictView: - def test_modify(self): - p = TParent() - tv = multidict.MultiDictView(p.getter, p.setter) - assert len(tv) == 0 - tv["a"] = "b" - assert p.vals == (("a", "b"),) - tv["c"] = "b" - assert p.vals == (("a", "b"), ("c", "b")) - assert tv["a"] == "b" - - def test_copy(self): - p = TParent() - tv = multidict.MultiDictView(p.getter, p.setter) - c = tv.copy() - assert isinstance(c, multidict.MultiDict) - assert tv.items() == c.items() - c["foo"] = "bar" - assert tv.items() != c.items() diff --git a/test/mitmproxy/types/test_serializable.py b/test/mitmproxy/types/test_serializable.py deleted file mode 100644 index 390d17e1..00000000 --- a/test/mitmproxy/types/test_serializable.py +++ /dev/null @@ -1,39 +0,0 @@ -import copy - -from mitmproxy.types import serializable - - -class SerializableDummy(serializable.Serializable): - def __init__(self, i): - self.i = i - - def get_state(self): - return copy.copy(self.i) - - def set_state(self, i): - self.i = i - - @classmethod - def from_state(cls, state): - return cls(state) - - -class TestSerializable: - def test_copy(self): - a = SerializableDummy(42) - assert a.i == 42 - b = a.copy() - assert b.i == 42 - - a.set_state(1) - assert a.i == 1 - assert b.i == 42 - - def test_copy_id(self): - a = SerializableDummy({ - "id": "foo", - "foo": 42 - }) - b = a.copy() - assert a.get_state()["id"] != b.get_state()["id"] - assert a.get_state()["foo"] == b.get_state()["foo"] -- cgit v1.2.3