From 1ecb25cdc10116c5341dc1024581365bec328b4e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 20 Oct 2016 10:22:23 +1300 Subject: mitmproxy.types.[basethread,multidict,serializable] --- mitmproxy/contentviews.py | 2 +- mitmproxy/master.py | 2 +- mitmproxy/proxy/protocol/http2.py | 2 +- mitmproxy/proxy/protocol/http_replay.py | 2 +- mitmproxy/script/concurrent.py | 2 +- mitmproxy/stateobject.py | 4 +- mitmproxy/tcp.py | 4 +- mitmproxy/types/__init__.py | 0 mitmproxy/types/basethread.py | 14 ++ mitmproxy/types/multidict.py | 298 ++++++++++++++++++++++++++++++ mitmproxy/types/serializable.py | 32 ++++ mitmproxy/utils/__init__.py | 0 netlib/basethread.py | 14 -- netlib/basetypes.py | 32 ---- netlib/certutils.py | 4 +- netlib/http/cookies.py | 2 +- netlib/http/headers.py | 2 +- netlib/http/message.py | 6 +- netlib/http/request.py | 2 +- netlib/http/response.py | 2 +- netlib/multidict.py | 298 ------------------------------ netlib/tcp.py | 6 +- pathod/pathoc.py | 2 +- pathod/test.py | 2 +- test/mitmproxy/test_contentview.py | 2 +- test/mitmproxy/test_types_multidict.py | 247 +++++++++++++++++++++++++ test/mitmproxy/test_types_serializable.py | 28 +++ test/netlib/test_basetypes.py | 28 --- test/netlib/test_multidict.py | 247 ------------------------- 29 files changed, 643 insertions(+), 643 deletions(-) create mode 100644 mitmproxy/types/__init__.py create mode 100644 mitmproxy/types/basethread.py create mode 100644 mitmproxy/types/multidict.py create mode 100644 mitmproxy/types/serializable.py create mode 100644 mitmproxy/utils/__init__.py delete mode 100644 netlib/basethread.py delete mode 100644 netlib/basetypes.py delete mode 100644 netlib/multidict.py create mode 100644 test/mitmproxy/test_types_multidict.py create mode 100644 test/mitmproxy/test_types_serializable.py delete mode 100644 test/netlib/test_basetypes.py delete mode 100644 test/netlib/test_multidict.py diff --git a/mitmproxy/contentviews.py b/mitmproxy/contentviews.py index 07bf09f5..a171f36b 100644 --- a/mitmproxy/contentviews.py +++ b/mitmproxy/contentviews.py @@ -34,7 +34,7 @@ from PIL import Image from mitmproxy import exceptions from mitmproxy.contrib.wbxml import ASCommandResponse from netlib import http -from netlib import multidict +from mitmproxy.types import multidict from mitmproxy.utils import strutils from netlib.http import url diff --git a/mitmproxy/master.py b/mitmproxy/master.py index 1fc00112..2e57e57d 100644 --- a/mitmproxy/master.py +++ b/mitmproxy/master.py @@ -14,7 +14,7 @@ from mitmproxy import http from mitmproxy import log from mitmproxy import io from mitmproxy.proxy.protocol import http_replay -from netlib import basethread +from mitmproxy.types import basethread import netlib.http from . import ctx as mitmproxy_ctx diff --git a/mitmproxy/proxy/protocol/http2.py b/mitmproxy/proxy/protocol/http2.py index cbd8b34c..93ac51bc 100644 --- a/mitmproxy/proxy/protocol/http2.py +++ b/mitmproxy/proxy/protocol/http2.py @@ -15,7 +15,7 @@ from mitmproxy.proxy.protocol import base from mitmproxy.proxy.protocol import http as httpbase import netlib.http from netlib import tcp -from netlib import basethread +from mitmproxy.types import basethread from netlib.http import http2 diff --git a/mitmproxy/proxy/protocol/http_replay.py b/mitmproxy/proxy/protocol/http_replay.py index bf0697be..eef5a109 100644 --- a/mitmproxy/proxy/protocol/http_replay.py +++ b/mitmproxy/proxy/protocol/http_replay.py @@ -8,7 +8,7 @@ from mitmproxy import http from mitmproxy import flow from mitmproxy import connections from netlib.http import http1 -from netlib import basethread +from mitmproxy.types import basethread # TODO: Doesn't really belong into mitmproxy.proxy.protocol... diff --git a/mitmproxy/script/concurrent.py b/mitmproxy/script/concurrent.py index dc72e5b7..2fd7ad8d 100644 --- a/mitmproxy/script/concurrent.py +++ b/mitmproxy/script/concurrent.py @@ -4,7 +4,7 @@ offload computations from mitmproxy's main master thread. """ from mitmproxy import events -from netlib import basethread +from mitmproxy.types import basethread class ScriptThread(basethread.BaseThread): diff --git a/mitmproxy/stateobject.py b/mitmproxy/stateobject.py index f4415ecf..1ab744a5 100644 --- a/mitmproxy/stateobject.py +++ b/mitmproxy/stateobject.py @@ -1,7 +1,7 @@ from typing import Any from typing import List -import netlib.basetypes +from mitmproxy.types import serializable def _is_list(cls): @@ -10,7 +10,7 @@ def _is_list(cls): return issubclass(cls, List) or is_list_bugfix -class StateObject(netlib.basetypes.Serializable): +class StateObject(serializable.Serializable): """ An object with serializable state. diff --git a/mitmproxy/tcp.py b/mitmproxy/tcp.py index af54c9d4..d73be98d 100644 --- a/mitmproxy/tcp.py +++ b/mitmproxy/tcp.py @@ -2,11 +2,11 @@ import time from typing import List -import netlib.basetypes from mitmproxy import flow +from mitmproxy.types import serializable -class TCPMessage(netlib.basetypes.Serializable): +class TCPMessage(serializable.Serializable): def __init__(self, from_client, content, timestamp=None): self.content = content diff --git a/mitmproxy/types/__init__.py b/mitmproxy/types/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mitmproxy/types/basethread.py b/mitmproxy/types/basethread.py new file mode 100644 index 00000000..a3c81d19 --- /dev/null +++ b/mitmproxy/types/basethread.py @@ -0,0 +1,14 @@ +import time +import threading + + +class BaseThread(threading.Thread): + def __init__(self, name, *args, **kwargs): + super().__init__(name=name, *args, **kwargs) + self._thread_started = time.time() + + def _threadinfo(self): + return "%s - age: %is" % ( + self.name, + int(time.time() - self._thread_started) + ) diff --git a/mitmproxy/types/multidict.py b/mitmproxy/types/multidict.py new file mode 100644 index 00000000..d351e48b --- /dev/null +++ b/mitmproxy/types/multidict.py @@ -0,0 +1,298 @@ +from abc import ABCMeta, abstractmethod + + +try: + from collections.abc import MutableMapping +except ImportError: # pragma: no cover + from collections import MutableMapping # Workaround for Python < 3.3 + +from mitmproxy.types import serializable + + +class _MultiDict(MutableMapping, serializable.Serializable, metaclass=ABCMeta): + def __repr__(self): + fields = ( + repr(field) + for field in self.fields + ) + return "{cls}[{fields}]".format( + cls=type(self).__name__, + fields=", ".join(fields) + ) + + @staticmethod + @abstractmethod + def _reduce_values(values): + """ + If a user accesses multidict["foo"], this method + reduces all values for "foo" to a single value that is returned. + For example, HTTP headers are folded, whereas we will just take + the first cookie we found with that name. + """ + + @staticmethod + @abstractmethod + def _kconv(key): + """ + This method converts a key to its canonical representation. + For example, HTTP headers are case-insensitive, so this method returns key.lower(). + """ + + def __getitem__(self, key): + values = self.get_all(key) + if not values: + raise KeyError(key) + return self._reduce_values(values) + + def __setitem__(self, key, value): + self.set_all(key, [value]) + + def __delitem__(self, key): + if key not in self: + raise KeyError(key) + key = self._kconv(key) + self.fields = tuple( + field for field in self.fields + if key != self._kconv(field[0]) + ) + + def __iter__(self): + seen = set() + for key, _ in self.fields: + key_kconv = self._kconv(key) + if key_kconv not in seen: + seen.add(key_kconv) + yield key + + def __len__(self): + return len(set(self._kconv(key) for key, _ in self.fields)) + + def __eq__(self, other): + if isinstance(other, MultiDict): + return self.fields == other.fields + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def get_all(self, key): + """ + Return the list of all values for a given key. + If that key is not in the MultiDict, the return value will be an empty list. + """ + key = self._kconv(key) + return [ + value + for k, value in self.fields + if self._kconv(k) == key + ] + + def set_all(self, key, values): + """ + Remove the old values for a key and add new ones. + """ + key_kconv = self._kconv(key) + + new_fields = [] + for field in self.fields: + if self._kconv(field[0]) == key_kconv: + if values: + new_fields.append( + (field[0], values.pop(0)) + ) + else: + new_fields.append(field) + while values: + new_fields.append( + (key, values.pop(0)) + ) + self.fields = tuple(new_fields) + + def add(self, key, value): + """ + Add an additional value for the given key at the bottom. + """ + self.insert(len(self.fields), key, value) + + def insert(self, index, key, value): + """ + Insert an additional value for the given key at the specified position. + """ + item = (key, value) + self.fields = self.fields[:index] + (item,) + self.fields[index:] + + def keys(self, multi=False): + """ + Get all keys. + + Args: + multi(bool): + If True, one key per value will be returned. + If False, duplicate keys will only be returned once. + """ + return ( + k + for k, _ in self.items(multi) + ) + + def values(self, multi=False): + """ + Get all values. + + Args: + multi(bool): + If True, all values will be returned. + If False, only the first value per key will be returned. + """ + return ( + v + for _, v in self.items(multi) + ) + + def items(self, multi=False): + """ + Get all (key, value) tuples. + + Args: + multi(bool): + If True, all (key, value) pairs will be returned + If False, only the first (key, value) pair per unique key will be returned. + """ + if multi: + return self.fields + else: + return super().items() + + def collect(self): + """ + Returns a list of (key, value) tuples, where values are either + singular if there is only one matching item for a key, or a list + if there are more than one. The order of the keys matches the order + in the underlying fields list. + """ + coll = [] + for key in self: + values = self.get_all(key) + if len(values) == 1: + coll.append([key, values[0]]) + else: + coll.append([key, values]) + return coll + + def to_dict(self): + """ + Get the MultiDict as a plain Python dict. + Keys with multiple values are returned as lists. + + Example: + + .. code-block:: python + + # Simple dict with duplicate values. + >>> d = MultiDict([("name", "value"), ("a", False), ("a", 42)]) + >>> d.to_dict() + { + "name": "value", + "a": [False, 42] + } + """ + return { + k: v for k, v in self.collect() + } + + def get_state(self): + return self.fields + + def set_state(self, state): + self.fields = tuple(tuple(x) for x in state) + + @classmethod + def from_state(cls, state): + return cls(state) + + +class MultiDict(_MultiDict): + def __init__(self, fields=()): + super().__init__() + self.fields = tuple( + tuple(i) for i in fields + ) + + @staticmethod + def _reduce_values(values): + return values[0] + + @staticmethod + def _kconv(key): + return key + + +class ImmutableMultiDict(MultiDict, metaclass=ABCMeta): + def _immutable(self, *_): + raise TypeError('{} objects are immutable'.format(self.__class__.__name__)) + + __delitem__ = set_all = insert = _immutable + + def __hash__(self): + return hash(self.fields) + + def with_delitem(self, key): + """ + Returns: + An updated ImmutableMultiDict. The original object will not be modified. + """ + ret = self.copy() + # FIXME: This is filthy... + super(ImmutableMultiDict, ret).__delitem__(key) + return ret + + def with_set_all(self, key, values): + """ + Returns: + An updated ImmutableMultiDict. The original object will not be modified. + """ + ret = self.copy() + # FIXME: This is filthy... + super(ImmutableMultiDict, ret).set_all(key, values) + return ret + + def with_insert(self, index, key, value): + """ + Returns: + An updated ImmutableMultiDict. The original object will not be modified. + """ + ret = self.copy() + # FIXME: This is filthy... + super(ImmutableMultiDict, ret).insert(index, key, value) + return ret + + +class MultiDictView(_MultiDict): + """ + The MultiDictView provides the MultiDict interface over calculated data. + The view itself contains no state - data is retrieved from the parent on + request, and stored back to the parent on change. + """ + def __init__(self, getter, setter): + self._getter = getter + self._setter = setter + super().__init__() + + @staticmethod + def _kconv(key): + # All request-attributes are case-sensitive. + return key + + @staticmethod + def _reduce_values(values): + # We just return the first element if + # multiple elements exist with the same key. + return values[0] + + @property + def fields(self): + return self._getter() + + @fields.setter + def fields(self, value): + self._setter(value) diff --git a/mitmproxy/types/serializable.py b/mitmproxy/types/serializable.py new file mode 100644 index 00000000..49892ffc --- /dev/null +++ b/mitmproxy/types/serializable.py @@ -0,0 +1,32 @@ +import abc + + +class Serializable(metaclass=abc.ABCMeta): + """ + Abstract Base Class that defines an API to save an object's state and restore it later on. + """ + + @classmethod + @abc.abstractmethod + def from_state(cls, state): + """ + Create a new object from the given state. + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_state(self): + """ + Retrieve object state. + """ + raise NotImplementedError() + + @abc.abstractmethod + def set_state(self, state): + """ + Set object state to the given state. + """ + raise NotImplementedError() + + def copy(self): + return self.from_state(self.get_state()) diff --git a/mitmproxy/utils/__init__.py b/mitmproxy/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/netlib/basethread.py b/netlib/basethread.py deleted file mode 100644 index a3c81d19..00000000 --- a/netlib/basethread.py +++ /dev/null @@ -1,14 +0,0 @@ -import time -import threading - - -class BaseThread(threading.Thread): - def __init__(self, name, *args, **kwargs): - super().__init__(name=name, *args, **kwargs) - self._thread_started = time.time() - - def _threadinfo(self): - return "%s - age: %is" % ( - self.name, - int(time.time() - self._thread_started) - ) diff --git a/netlib/basetypes.py b/netlib/basetypes.py deleted file mode 100644 index 49892ffc..00000000 --- a/netlib/basetypes.py +++ /dev/null @@ -1,32 +0,0 @@ -import abc - - -class Serializable(metaclass=abc.ABCMeta): - """ - Abstract Base Class that defines an API to save an object's state and restore it later on. - """ - - @classmethod - @abc.abstractmethod - def from_state(cls, state): - """ - Create a new object from the given state. - """ - raise NotImplementedError() - - @abc.abstractmethod - def get_state(self): - """ - Retrieve object state. - """ - raise NotImplementedError() - - @abc.abstractmethod - def set_state(self, state): - """ - Set object state to the given state. - """ - raise NotImplementedError() - - def copy(self): - return self.from_state(self.get_state()) diff --git a/netlib/certutils.py b/netlib/certutils.py index 6a97f99e..9cb8a40e 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -10,7 +10,7 @@ from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL -from netlib import basetypes +from mitmproxy.types import serializable # Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815 @@ -373,7 +373,7 @@ class _GeneralNames(univ.SequenceOf): constraint.ValueSizeConstraint(1, 1024) -class SSLCert(basetypes.Serializable): +class SSLCert(serializable.Serializable): def __init__(self, cert): """ diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py index cb816ca0..9f32fa5e 100644 --- a/netlib/http/cookies.py +++ b/netlib/http/cookies.py @@ -3,7 +3,7 @@ import email.utils import re import time -from netlib import multidict +from mitmproxy.types import multidict """ A flexible module for cookie parsing and manipulation. diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 6c30d278..8fc0cd43 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -1,7 +1,7 @@ import re import collections -from netlib import multidict +from mitmproxy.types import multidict from mitmproxy.utils import strutils # See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ diff --git a/netlib/http/message.py b/netlib/http/message.py index 133a53ce..62c3aa38 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -4,7 +4,7 @@ from typing import Optional from mitmproxy.utils import strutils from netlib import encoding -from netlib import basetypes +from mitmproxy.types import serializable from netlib.http import headers @@ -17,7 +17,7 @@ def _always_bytes(x): return strutils.always_bytes(x, "utf-8", "surrogateescape") -class MessageData(basetypes.Serializable): +class MessageData(serializable.Serializable): def __eq__(self, other): if isinstance(other, MessageData): return self.__dict__ == other.__dict__ @@ -43,7 +43,7 @@ class MessageData(basetypes.Serializable): return cls(**state) -class Message(basetypes.Serializable): +class Message(serializable.Serializable): def __eq__(self, other): if isinstance(other, Message): return self.data == other.data diff --git a/netlib/http/request.py b/netlib/http/request.py index 3479fa4c..16b0c986 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -1,7 +1,7 @@ import re import urllib -from netlib import multidict +from mitmproxy.types import multidict from mitmproxy.utils import strutils from netlib.http import multipart from netlib.http import cookies diff --git a/netlib/http/response.py b/netlib/http/response.py index 12dba92a..4d1d5d24 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -1,7 +1,7 @@ import time from email.utils import parsedate_tz, formatdate, mktime_tz from mitmproxy.utils import human -from netlib import multidict +from mitmproxy.types import multidict from netlib.http import cookies from netlib.http import headers as nheaders from netlib.http import message diff --git a/netlib/multidict.py b/netlib/multidict.py deleted file mode 100644 index 191d1cc6..00000000 --- a/netlib/multidict.py +++ /dev/null @@ -1,298 +0,0 @@ -from abc import ABCMeta, abstractmethod - - -try: - from collections.abc import MutableMapping -except ImportError: # pragma: no cover - from collections import MutableMapping # Workaround for Python < 3.3 - -from netlib import basetypes - - -class _MultiDict(MutableMapping, basetypes.Serializable, metaclass=ABCMeta): - def __repr__(self): - fields = ( - repr(field) - for field in self.fields - ) - return "{cls}[{fields}]".format( - cls=type(self).__name__, - fields=", ".join(fields) - ) - - @staticmethod - @abstractmethod - def _reduce_values(values): - """ - If a user accesses multidict["foo"], this method - reduces all values for "foo" to a single value that is returned. - For example, HTTP headers are folded, whereas we will just take - the first cookie we found with that name. - """ - - @staticmethod - @abstractmethod - def _kconv(key): - """ - This method converts a key to its canonical representation. - For example, HTTP headers are case-insensitive, so this method returns key.lower(). - """ - - def __getitem__(self, key): - values = self.get_all(key) - if not values: - raise KeyError(key) - return self._reduce_values(values) - - def __setitem__(self, key, value): - self.set_all(key, [value]) - - def __delitem__(self, key): - if key not in self: - raise KeyError(key) - key = self._kconv(key) - self.fields = tuple( - field for field in self.fields - if key != self._kconv(field[0]) - ) - - def __iter__(self): - seen = set() - for key, _ in self.fields: - key_kconv = self._kconv(key) - if key_kconv not in seen: - seen.add(key_kconv) - yield key - - def __len__(self): - return len(set(self._kconv(key) for key, _ in self.fields)) - - def __eq__(self, other): - if isinstance(other, MultiDict): - return self.fields == other.fields - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def get_all(self, key): - """ - Return the list of all values for a given key. - If that key is not in the MultiDict, the return value will be an empty list. - """ - key = self._kconv(key) - return [ - value - for k, value in self.fields - if self._kconv(k) == key - ] - - def set_all(self, key, values): - """ - Remove the old values for a key and add new ones. - """ - key_kconv = self._kconv(key) - - new_fields = [] - for field in self.fields: - if self._kconv(field[0]) == key_kconv: - if values: - new_fields.append( - (field[0], values.pop(0)) - ) - else: - new_fields.append(field) - while values: - new_fields.append( - (key, values.pop(0)) - ) - self.fields = tuple(new_fields) - - def add(self, key, value): - """ - Add an additional value for the given key at the bottom. - """ - self.insert(len(self.fields), key, value) - - def insert(self, index, key, value): - """ - Insert an additional value for the given key at the specified position. - """ - item = (key, value) - self.fields = self.fields[:index] + (item,) + self.fields[index:] - - def keys(self, multi=False): - """ - Get all keys. - - Args: - multi(bool): - If True, one key per value will be returned. - If False, duplicate keys will only be returned once. - """ - return ( - k - for k, _ in self.items(multi) - ) - - def values(self, multi=False): - """ - Get all values. - - Args: - multi(bool): - If True, all values will be returned. - If False, only the first value per key will be returned. - """ - return ( - v - for _, v in self.items(multi) - ) - - def items(self, multi=False): - """ - Get all (key, value) tuples. - - Args: - multi(bool): - If True, all (key, value) pairs will be returned - If False, only the first (key, value) pair per unique key will be returned. - """ - if multi: - return self.fields - else: - return super().items() - - def collect(self): - """ - Returns a list of (key, value) tuples, where values are either - singular if there is only one matching item for a key, or a list - if there are more than one. The order of the keys matches the order - in the underlying fields list. - """ - coll = [] - for key in self: - values = self.get_all(key) - if len(values) == 1: - coll.append([key, values[0]]) - else: - coll.append([key, values]) - return coll - - def to_dict(self): - """ - Get the MultiDict as a plain Python dict. - Keys with multiple values are returned as lists. - - Example: - - .. code-block:: python - - # Simple dict with duplicate values. - >>> d = MultiDict([("name", "value"), ("a", False), ("a", 42)]) - >>> d.to_dict() - { - "name": "value", - "a": [False, 42] - } - """ - return { - k: v for k, v in self.collect() - } - - def get_state(self): - return self.fields - - def set_state(self, state): - self.fields = tuple(tuple(x) for x in state) - - @classmethod - def from_state(cls, state): - return cls(state) - - -class MultiDict(_MultiDict): - def __init__(self, fields=()): - super().__init__() - self.fields = tuple( - tuple(i) for i in fields - ) - - @staticmethod - def _reduce_values(values): - return values[0] - - @staticmethod - def _kconv(key): - return key - - -class ImmutableMultiDict(MultiDict, metaclass=ABCMeta): - def _immutable(self, *_): - raise TypeError('{} objects are immutable'.format(self.__class__.__name__)) - - __delitem__ = set_all = insert = _immutable - - def __hash__(self): - return hash(self.fields) - - def with_delitem(self, key): - """ - Returns: - An updated ImmutableMultiDict. The original object will not be modified. - """ - ret = self.copy() - # FIXME: This is filthy... - super(ImmutableMultiDict, ret).__delitem__(key) - return ret - - def with_set_all(self, key, values): - """ - Returns: - An updated ImmutableMultiDict. The original object will not be modified. - """ - ret = self.copy() - # FIXME: This is filthy... - super(ImmutableMultiDict, ret).set_all(key, values) - return ret - - def with_insert(self, index, key, value): - """ - Returns: - An updated ImmutableMultiDict. The original object will not be modified. - """ - ret = self.copy() - # FIXME: This is filthy... - super(ImmutableMultiDict, ret).insert(index, key, value) - return ret - - -class MultiDictView(_MultiDict): - """ - The MultiDictView provides the MultiDict interface over calculated data. - The view itself contains no state - data is retrieved from the parent on - request, and stored back to the parent on change. - """ - def __init__(self, getter, setter): - self._getter = getter - self._setter = setter - super().__init__() - - @staticmethod - def _kconv(key): - # All request-attributes are case-sensitive. - return key - - @staticmethod - def _reduce_values(values): - # We just return the first element if - # multiple elements exist with the same key. - return values[0] - - @property - def fields(self): - return self._getter() - - @fields.setter - def fields(self, value): - self._setter(value) diff --git a/netlib/tcp.py b/netlib/tcp.py index aed79388..4fde657f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -19,9 +19,9 @@ from OpenSSL import SSL from netlib import certutils from netlib import version_check -from netlib import basetypes +from mitmproxy.types import serializable from netlib import exceptions -from netlib import basethread +from mitmproxy.types import basethread # This is a rather hackish way to make sure that # the latest version of pyOpenSSL is actually installed. @@ -292,7 +292,7 @@ class Reader(_FileLike): raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") -class Address(basetypes.Serializable): +class Address(serializable.Serializable): """ This class wraps an IPv4/IPv6 tuple to provide named attributes and diff --git a/pathod/pathoc.py b/pathod/pathoc.py index 0cf08a60..caa9accb 100644 --- a/pathod/pathoc.py +++ b/pathod/pathoc.py @@ -16,7 +16,7 @@ from mitmproxy.utils import strutils from netlib import tcp, certutils, websockets, socks from netlib import exceptions from netlib.http import http1 -from netlib import basethread +from mitmproxy.types import basethread from . import log, language from .protocols import http2 diff --git a/pathod/test.py b/pathod/test.py index c92cc50b..b819d723 100644 --- a/pathod/test.py +++ b/pathod/test.py @@ -3,7 +3,7 @@ import time import queue from . import pathod -from netlib import basethread +from mitmproxy.types import basethread class Daemon: diff --git a/test/mitmproxy/test_contentview.py b/test/mitmproxy/test_contentview.py index d63ee50e..f113e294 100644 --- a/test/mitmproxy/test_contentview.py +++ b/test/mitmproxy/test_contentview.py @@ -2,7 +2,7 @@ import mock from mitmproxy.exceptions import ContentViewException from netlib.http import Headers from netlib.http import url -from netlib import multidict +from mitmproxy.types import multidict import mitmproxy.contentviews as cv from . import tutils diff --git a/test/mitmproxy/test_types_multidict.py b/test/mitmproxy/test_types_multidict.py new file mode 100644 index 00000000..ada33bf7 --- /dev/null +++ b/test/mitmproxy/test_types_multidict.py @@ -0,0 +1,247 @@ +from netlib import tutils +from mitmproxy.types import multidict + + +class _TMulti: + @staticmethod + def _kconv(key): + return key.lower() + + +class TMultiDict(_TMulti, multidict.MultiDict): + pass + + +class TImmutableMultiDict(_TMulti, multidict.ImmutableMultiDict): + pass + + +class TestMultiDict: + @staticmethod + def _multi(): + return TMultiDict(( + ("foo", "bar"), + ("bar", "baz"), + ("Bar", "bam") + )) + + def test_init(self): + md = TMultiDict() + assert len(md) == 0 + + md = TMultiDict([("foo", "bar")]) + assert len(md) == 1 + assert md.fields == (("foo", "bar"),) + + def test_repr(self): + assert repr(self._multi()) == ( + "TMultiDict[('foo', 'bar'), ('bar', 'baz'), ('Bar', 'bam')]" + ) + + def test_getitem(self): + md = TMultiDict([("foo", "bar")]) + assert "foo" in md + assert "Foo" in md + assert md["foo"] == "bar" + + with tutils.raises(KeyError): + assert md["bar"] + + md_multi = TMultiDict( + [("foo", "a"), ("foo", "b")] + ) + assert md_multi["foo"] == "a" + + def test_setitem(self): + md = TMultiDict() + md["foo"] = "bar" + assert md.fields == (("foo", "bar"),) + + md["foo"] = "baz" + assert md.fields == (("foo", "baz"),) + + md["bar"] = "bam" + assert md.fields == (("foo", "baz"), ("bar", "bam")) + + def test_delitem(self): + md = self._multi() + del md["foo"] + assert "foo" not in md + assert "bar" in md + + with tutils.raises(KeyError): + del md["foo"] + + del md["bar"] + assert md.fields == () + + def test_iter(self): + md = self._multi() + assert list(md.__iter__()) == ["foo", "bar"] + + def test_len(self): + md = TMultiDict() + assert len(md) == 0 + + md = self._multi() + assert len(md) == 2 + + def test_eq(self): + assert TMultiDict() == TMultiDict() + assert not (TMultiDict() == 42) + + md1 = self._multi() + md2 = self._multi() + assert md1 == md2 + md1.fields = md1.fields[1:] + md1.fields[:1] + assert not (md1 == md2) + + def test_ne(self): + assert not TMultiDict() != TMultiDict() + assert TMultiDict() != self._multi() + assert TMultiDict() != 42 + + def test_hash(self): + """ + If a class defines mutable objects and implements an __eq__() method, + it should not implement __hash__(), since the implementation of hashable + collections requires that a key's hash value is immutable. + """ + with tutils.raises(TypeError): + assert hash(TMultiDict()) + + def test_get_all(self): + md = self._multi() + assert md.get_all("foo") == ["bar"] + assert md.get_all("bar") == ["baz", "bam"] + assert md.get_all("baz") == [] + + def test_set_all(self): + md = TMultiDict() + md.set_all("foo", ["bar", "baz"]) + assert md.fields == (("foo", "bar"), ("foo", "baz")) + + md = TMultiDict(( + ("a", "b"), + ("x", "x"), + ("c", "d"), + ("X", "X"), + ("e", "f"), + )) + md.set_all("x", ["1", "2", "3"]) + assert md.fields == ( + ("a", "b"), + ("x", "1"), + ("c", "d"), + ("X", "2"), + ("e", "f"), + ("x", "3"), + ) + md.set_all("x", ["4"]) + assert md.fields == ( + ("a", "b"), + ("x", "4"), + ("c", "d"), + ("e", "f"), + ) + + def test_add(self): + md = self._multi() + md.add("foo", "foo") + assert md.fields == ( + ("foo", "bar"), + ("bar", "baz"), + ("Bar", "bam"), + ("foo", "foo") + ) + + def test_insert(self): + md = TMultiDict([("b", "b")]) + md.insert(0, "a", "a") + md.insert(2, "c", "c") + assert md.fields == (("a", "a"), ("b", "b"), ("c", "c")) + + def test_keys(self): + md = self._multi() + assert list(md.keys()) == ["foo", "bar"] + assert list(md.keys(multi=True)) == ["foo", "bar", "Bar"] + + def test_values(self): + md = self._multi() + assert list(md.values()) == ["bar", "baz"] + assert list(md.values(multi=True)) == ["bar", "baz", "bam"] + + def test_items(self): + md = self._multi() + assert list(md.items()) == [("foo", "bar"), ("bar", "baz")] + assert list(md.items(multi=True)) == [("foo", "bar"), ("bar", "baz"), ("Bar", "bam")] + + def test_to_dict(self): + md = self._multi() + assert md.to_dict() == { + "foo": "bar", + "bar": ["baz", "bam"] + } + + def test_state(self): + md = self._multi() + assert len(md.get_state()) == 3 + assert md == TMultiDict.from_state(md.get_state()) + + md2 = TMultiDict() + assert md != md2 + md2.set_state(md.get_state()) + assert md == md2 + + +class TestImmutableMultiDict: + def test_modify(self): + md = TImmutableMultiDict() + with tutils.raises(TypeError): + md["foo"] = "bar" + + with tutils.raises(TypeError): + del md["foo"] + + with tutils.raises(TypeError): + md.add("foo", "bar") + + def test_hash(self): + assert hash(TImmutableMultiDict()) + + def test_with_delitem(self): + md = TImmutableMultiDict([("foo", "bar")]) + assert md.with_delitem("foo").fields == () + assert md.fields == (("foo", "bar"),) + + def test_with_set_all(self): + md = TImmutableMultiDict() + assert md.with_set_all("foo", ["bar"]).fields == (("foo", "bar"),) + assert md.fields == () + + def test_with_insert(self): + md = TImmutableMultiDict() + assert md.with_insert(0, "foo", "bar").fields == (("foo", "bar"),) + + +class TParent: + def __init__(self): + self.vals = tuple() + + def setter(self, vals): + self.vals = vals + + def getter(self): + return self.vals + + +class TestMultiDictView: + def test_modify(self): + p = TParent() + tv = multidict.MultiDictView(p.getter, p.setter) + assert len(tv) == 0 + tv["a"] = "b" + assert p.vals == (("a", "b"),) + tv["c"] = "b" + assert p.vals == (("a", "b"), ("c", "b")) + assert tv["a"] == "b" diff --git a/test/mitmproxy/test_types_serializable.py b/test/mitmproxy/test_types_serializable.py new file mode 100644 index 00000000..dd4a3778 --- /dev/null +++ b/test/mitmproxy/test_types_serializable.py @@ -0,0 +1,28 @@ +from mitmproxy.types import serializable + + +class SerializableDummy(serializable.Serializable): + def __init__(self, i): + self.i = i + + def get_state(self): + return self.i + + def set_state(self, i): + self.i = i + + def from_state(self, state): + return type(self)(state) + + +class TestSerializable: + + def test_copy(self): + a = SerializableDummy(42) + assert a.i == 42 + b = a.copy() + assert b.i == 42 + + a.set_state(1) + assert a.i == 1 + assert b.i == 42 diff --git a/test/netlib/test_basetypes.py b/test/netlib/test_basetypes.py deleted file mode 100644 index aa415784..00000000 --- a/test/netlib/test_basetypes.py +++ /dev/null @@ -1,28 +0,0 @@ -from netlib import basetypes - - -class SerializableDummy(basetypes.Serializable): - def __init__(self, i): - self.i = i - - def get_state(self): - return self.i - - def set_state(self, i): - self.i = i - - def from_state(self, state): - return type(self)(state) - - -class TestSerializable: - - def test_copy(self): - a = SerializableDummy(42) - assert a.i == 42 - b = a.copy() - assert b.i == 42 - - a.set_state(1) - assert a.i == 1 - assert b.i == 42 diff --git a/test/netlib/test_multidict.py b/test/netlib/test_multidict.py deleted file mode 100644 index a9523fd9..00000000 --- a/test/netlib/test_multidict.py +++ /dev/null @@ -1,247 +0,0 @@ -from netlib import tutils -from netlib.multidict import MultiDict, ImmutableMultiDict, MultiDictView - - -class _TMulti: - @staticmethod - def _kconv(key): - return key.lower() - - -class TMultiDict(_TMulti, MultiDict): - pass - - -class TImmutableMultiDict(_TMulti, ImmutableMultiDict): - pass - - -class TestMultiDict: - @staticmethod - def _multi(): - return TMultiDict(( - ("foo", "bar"), - ("bar", "baz"), - ("Bar", "bam") - )) - - def test_init(self): - md = TMultiDict() - assert len(md) == 0 - - md = TMultiDict([("foo", "bar")]) - assert len(md) == 1 - assert md.fields == (("foo", "bar"),) - - def test_repr(self): - assert repr(self._multi()) == ( - "TMultiDict[('foo', 'bar'), ('bar', 'baz'), ('Bar', 'bam')]" - ) - - def test_getitem(self): - md = TMultiDict([("foo", "bar")]) - assert "foo" in md - assert "Foo" in md - assert md["foo"] == "bar" - - with tutils.raises(KeyError): - assert md["bar"] - - md_multi = TMultiDict( - [("foo", "a"), ("foo", "b")] - ) - assert md_multi["foo"] == "a" - - def test_setitem(self): - md = TMultiDict() - md["foo"] = "bar" - assert md.fields == (("foo", "bar"),) - - md["foo"] = "baz" - assert md.fields == (("foo", "baz"),) - - md["bar"] = "bam" - assert md.fields == (("foo", "baz"), ("bar", "bam")) - - def test_delitem(self): - md = self._multi() - del md["foo"] - assert "foo" not in md - assert "bar" in md - - with tutils.raises(KeyError): - del md["foo"] - - del md["bar"] - assert md.fields == () - - def test_iter(self): - md = self._multi() - assert list(md.__iter__()) == ["foo", "bar"] - - def test_len(self): - md = TMultiDict() - assert len(md) == 0 - - md = self._multi() - assert len(md) == 2 - - def test_eq(self): - assert TMultiDict() == TMultiDict() - assert not (TMultiDict() == 42) - - md1 = self._multi() - md2 = self._multi() - assert md1 == md2 - md1.fields = md1.fields[1:] + md1.fields[:1] - assert not (md1 == md2) - - def test_ne(self): - assert not TMultiDict() != TMultiDict() - assert TMultiDict() != self._multi() - assert TMultiDict() != 42 - - def test_hash(self): - """ - If a class defines mutable objects and implements an __eq__() method, - it should not implement __hash__(), since the implementation of hashable - collections requires that a key's hash value is immutable. - """ - with tutils.raises(TypeError): - assert hash(TMultiDict()) - - def test_get_all(self): - md = self._multi() - assert md.get_all("foo") == ["bar"] - assert md.get_all("bar") == ["baz", "bam"] - assert md.get_all("baz") == [] - - def test_set_all(self): - md = TMultiDict() - md.set_all("foo", ["bar", "baz"]) - assert md.fields == (("foo", "bar"), ("foo", "baz")) - - md = TMultiDict(( - ("a", "b"), - ("x", "x"), - ("c", "d"), - ("X", "X"), - ("e", "f"), - )) - md.set_all("x", ["1", "2", "3"]) - assert md.fields == ( - ("a", "b"), - ("x", "1"), - ("c", "d"), - ("X", "2"), - ("e", "f"), - ("x", "3"), - ) - md.set_all("x", ["4"]) - assert md.fields == ( - ("a", "b"), - ("x", "4"), - ("c", "d"), - ("e", "f"), - ) - - def test_add(self): - md = self._multi() - md.add("foo", "foo") - assert md.fields == ( - ("foo", "bar"), - ("bar", "baz"), - ("Bar", "bam"), - ("foo", "foo") - ) - - def test_insert(self): - md = TMultiDict([("b", "b")]) - md.insert(0, "a", "a") - md.insert(2, "c", "c") - assert md.fields == (("a", "a"), ("b", "b"), ("c", "c")) - - def test_keys(self): - md = self._multi() - assert list(md.keys()) == ["foo", "bar"] - assert list(md.keys(multi=True)) == ["foo", "bar", "Bar"] - - def test_values(self): - md = self._multi() - assert list(md.values()) == ["bar", "baz"] - assert list(md.values(multi=True)) == ["bar", "baz", "bam"] - - def test_items(self): - md = self._multi() - assert list(md.items()) == [("foo", "bar"), ("bar", "baz")] - assert list(md.items(multi=True)) == [("foo", "bar"), ("bar", "baz"), ("Bar", "bam")] - - def test_to_dict(self): - md = self._multi() - assert md.to_dict() == { - "foo": "bar", - "bar": ["baz", "bam"] - } - - def test_state(self): - md = self._multi() - assert len(md.get_state()) == 3 - assert md == TMultiDict.from_state(md.get_state()) - - md2 = TMultiDict() - assert md != md2 - md2.set_state(md.get_state()) - assert md == md2 - - -class TestImmutableMultiDict: - def test_modify(self): - md = TImmutableMultiDict() - with tutils.raises(TypeError): - md["foo"] = "bar" - - with tutils.raises(TypeError): - del md["foo"] - - with tutils.raises(TypeError): - md.add("foo", "bar") - - def test_hash(self): - assert hash(TImmutableMultiDict()) - - def test_with_delitem(self): - md = TImmutableMultiDict([("foo", "bar")]) - assert md.with_delitem("foo").fields == () - assert md.fields == (("foo", "bar"),) - - def test_with_set_all(self): - md = TImmutableMultiDict() - assert md.with_set_all("foo", ["bar"]).fields == (("foo", "bar"),) - assert md.fields == () - - def test_with_insert(self): - md = TImmutableMultiDict() - assert md.with_insert(0, "foo", "bar").fields == (("foo", "bar"),) - - -class TParent: - def __init__(self): - self.vals = tuple() - - def setter(self, vals): - self.vals = vals - - def getter(self): - return self.vals - - -class TestMultiDictView: - def test_modify(self): - p = TParent() - tv = MultiDictView(p.getter, p.setter) - assert len(tv) == 0 - tv["a"] = "b" - assert p.vals == (("a", "b"),) - tv["c"] = "b" - assert p.vals == (("a", "b"), ("c", "b")) - assert tv["a"] == "b" -- cgit v1.2.3