aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMaximilian Hils <git@maximilianhils.com>2018-01-12 22:42:02 +0100
committerMaximilian Hils <git@maximilianhils.com>2018-01-13 00:33:37 +0100
commit69726f180a70f42f8233b673aea209b0dabaa161 (patch)
treef0c757a864f4f3021bac1c538a728888194fdce2
parentb7db304dde0daf2b410dc36d33a24856aa22ba59 (diff)
downloadmitmproxy-69726f180a70f42f8233b673aea209b0dabaa161.tar.gz
mitmproxy-69726f180a70f42f8233b673aea209b0dabaa161.tar.bz2
mitmproxy-69726f180a70f42f8233b673aea209b0dabaa161.zip
stateobject: use typing, enable tuples and more complex datatypes
-rw-r--r--mitmproxy/flow.py2
-rw-r--r--mitmproxy/stateobject.py83
-rw-r--r--mitmproxy/utils/typecheck.py56
-rw-r--r--setup.cfg1
-rw-r--r--test/mitmproxy/test_stateobject.py149
-rw-r--r--test/mitmproxy/utils/test_typecheck.py5
6 files changed, 190 insertions, 106 deletions
diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py
index 944c032d..6a27a4a8 100644
--- a/mitmproxy/flow.py
+++ b/mitmproxy/flow.py
@@ -87,7 +87,7 @@ class Flow(stateobject.StateObject):
type=str,
intercepted=bool,
marked=bool,
- metadata=dict,
+ metadata=typing.Dict[str, typing.Any],
)
def get_state(self):
diff --git a/mitmproxy/stateobject.py b/mitmproxy/stateobject.py
index 007339e8..ffaf285f 100644
--- a/mitmproxy/stateobject.py
+++ b/mitmproxy/stateobject.py
@@ -1,18 +1,12 @@
-from typing import Any
-from typing import List
+import typing
+from typing import Any # noqa
from typing import MutableMapping # noqa
from mitmproxy.coretypes import serializable
-
-
-def _is_list(cls):
- # The typing module is broken on Python 3.5.0, fixed on 3.5.1.
- is_list_bugfix = getattr(cls, "__origin__", False) == getattr(List[Any], "__origin__", True)
- return issubclass(cls, List) or is_list_bugfix
+from mitmproxy.utils import typecheck
class StateObject(serializable.Serializable):
-
"""
An object with serializable state.
@@ -34,22 +28,7 @@ class StateObject(serializable.Serializable):
state = {}
for attr, cls in self._stateobject_attributes.items():
val = getattr(self, attr)
- if val is None:
- state[attr] = None
- elif hasattr(val, "get_state"):
- state[attr] = val.get_state()
- elif _is_list(cls):
- state[attr] = [x.get_state() for x in val]
- elif isinstance(val, dict):
- s = {}
- for k, v in val.items():
- if hasattr(v, "get_state"):
- s[k] = v.get_state()
- else:
- s[k] = v
- state[attr] = s
- else:
- state[attr] = val
+ state[attr] = get_state(cls, val)
return state
def set_state(self, state):
@@ -65,13 +44,51 @@ class StateObject(serializable.Serializable):
curr = getattr(self, attr)
if hasattr(curr, "set_state"):
curr.set_state(val)
- elif hasattr(cls, "from_state"):
- obj = cls.from_state(val)
- setattr(self, attr, obj)
- elif _is_list(cls):
- cls = cls.__parameters__[0] if cls.__parameters__ else cls.__args__[0]
- setattr(self, attr, [cls.from_state(x) for x in val])
- else: # primitive types such as int, str, ...
- setattr(self, attr, cls(val))
+ else:
+ setattr(self, attr, make_object(cls, val))
if state:
raise RuntimeWarning("Unexpected State in __setstate__: {}".format(state))
+
+
+def _process(typeinfo: typecheck.Type, val: typing.Any, make: bool) -> typing.Any:
+ if val is None:
+ return None
+ elif make and hasattr(typeinfo, "from_state"):
+ return typeinfo.from_state(val)
+ elif not make and hasattr(val, "get_state"):
+ return val.get_state()
+
+ typename = str(typeinfo)
+
+ if typename.startswith("typing.List"):
+ T = typecheck.sequence_type(typeinfo)
+ return [_process(T, x, make) for x in val]
+ elif typename.startswith("typing.Tuple"):
+ Ts = typecheck.tuple_types(typeinfo)
+ if len(Ts) != len(val):
+ raise ValueError("Invalid data. Expected {}, got {}.".format(Ts, val))
+ return tuple(
+ _process(T, x, make) for T, x in zip(Ts, val)
+ )
+ elif typename.startswith("typing.Dict"):
+ k_cls, v_cls = typecheck.mapping_types(typeinfo)
+ return {
+ _process(k_cls, k, make): _process(v_cls, v, make)
+ for k, v in val.items()
+ }
+ elif typename.startswith("typing.Any"):
+ # FIXME: Remove this when we remove flow.metadata
+ assert isinstance(val, (int, str, bool, bytes))
+ return val
+ else:
+ return typeinfo(val)
+
+
+def make_object(typeinfo: typecheck.Type, val: typing.Any) -> typing.Any:
+ """Create an object based on the state given in val."""
+ return _process(typeinfo, val, True)
+
+
+def get_state(typeinfo: typecheck.Type, val: typing.Any) -> typing.Any:
+ """Get the state of the object given as val."""
+ return _process(typeinfo, val, False)
diff --git a/mitmproxy/utils/typecheck.py b/mitmproxy/utils/typecheck.py
index 1070fad0..22db68f5 100644
--- a/mitmproxy/utils/typecheck.py
+++ b/mitmproxy/utils/typecheck.py
@@ -1,7 +1,40 @@
import typing
+Type = typing.Union[
+ typing.Any # anything more elaborate really fails with mypy at the moment.
+]
-def check_option_type(name: str, value: typing.Any, typeinfo: typing.Any) -> None:
+
+def sequence_type(typeinfo: typing.Type[typing.List]) -> Type:
+ """Return the type of a sequence, e.g. typing.List"""
+ try:
+ return typeinfo.__args__[0] # type: ignore
+ except AttributeError: # Python 3.5.0
+ return typeinfo.__parameters__[0] # type: ignore
+
+
+def tuple_types(typeinfo: typing.Type[typing.Tuple]) -> typing.Sequence[Type]:
+ """Return the types of a typing.Tuple"""
+ try:
+ return typeinfo.__args__ # type: ignore
+ except AttributeError: # Python 3.5.x
+ return typeinfo.__tuple_params__ # type: ignore
+
+
+def union_types(typeinfo: typing.Type[typing.Tuple]) -> typing.Sequence[Type]:
+ """return the types of a typing.Union"""
+ try:
+ return typeinfo.__args__ # type: ignore
+ except AttributeError: # Python 3.5.x
+ return typeinfo.__union_params__ # type: ignore
+
+
+def mapping_types(typeinfo: typing.Type[typing.Mapping]) -> typing.Tuple[Type, Type]:
+ """return the types of a mapping, e.g. typing.Dict"""
+ return typeinfo.__args__ # type: ignore
+
+
+def check_option_type(name: str, value: typing.Any, typeinfo: Type) -> None:
"""
Check if the provided value is an instance of typeinfo and raises a
TypeError otherwise. This function supports only those types required for
@@ -16,13 +49,7 @@ def check_option_type(name: str, value: typing.Any, typeinfo: typing.Any) -> Non
typename = str(typeinfo)
if typename.startswith("typing.Union"):
- try:
- types = typeinfo.__args__ # type: ignore
- except AttributeError:
- # Python 3.5.x
- types = typeinfo.__union_params__ # type: ignore
-
- for T in types:
+ for T in union_types(typeinfo):
try:
check_option_type(name, value, T)
except TypeError:
@@ -31,12 +58,7 @@ def check_option_type(name: str, value: typing.Any, typeinfo: typing.Any) -> Non
return
raise e
elif typename.startswith("typing.Tuple"):
- try:
- types = typeinfo.__args__ # type: ignore
- except AttributeError:
- # Python 3.5.x
- types = typeinfo.__tuple_params__ # type: ignore
-
+ types = tuple_types(typeinfo)
if not isinstance(value, (tuple, list)):
raise e
if len(types) != len(value):
@@ -45,11 +67,7 @@ def check_option_type(name: str, value: typing.Any, typeinfo: typing.Any) -> Non
check_option_type("{}[{}]".format(name, i), x, T)
return
elif typename.startswith("typing.Sequence"):
- try:
- T = typeinfo.__args__[0] # type: ignore
- except AttributeError:
- # Python 3.5.0
- T = typeinfo.__parameters__[0] # type: ignore
+ T = sequence_type(typeinfo)
if not isinstance(value, (tuple, list)):
raise e
for v in value:
diff --git a/setup.cfg b/setup.cfg
index 7c754722..592cc2e3 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -75,7 +75,6 @@ exclude =
mitmproxy/proxy/protocol/tls.py
mitmproxy/proxy/root_context.py
mitmproxy/proxy/server.py
- mitmproxy/stateobject.py
mitmproxy/utils/bits.py
pathod/language/actions.py
pathod/language/base.py
diff --git a/test/mitmproxy/test_stateobject.py b/test/mitmproxy/test_stateobject.py
index d8c7a8e9..bd5d1792 100644
--- a/test/mitmproxy/test_stateobject.py
+++ b/test/mitmproxy/test_stateobject.py
@@ -1,101 +1,146 @@
-from typing import List
+import typing
+
import pytest
from mitmproxy.stateobject import StateObject
-class Child(StateObject):
+class TObject(StateObject):
def __init__(self, x):
self.x = x
- _stateobject_attributes = dict(
- x=int
- )
-
@classmethod
def from_state(cls, state):
obj = cls(None)
obj.set_state(state)
return obj
+
+class Child(TObject):
+ _stateobject_attributes = dict(
+ x=int
+ )
+
def __eq__(self, other):
return isinstance(other, Child) and self.x == other.x
-class Container(StateObject):
- def __init__(self):
- self.child = None
- self.children = None
- self.dictionary = None
+class TTuple(TObject):
+ _stateobject_attributes = dict(
+ x=typing.Tuple[int, Child]
+ )
+
+
+class TList(TObject):
+ _stateobject_attributes = dict(
+ x=typing.List[Child]
+ )
+
+class TDict(TObject):
_stateobject_attributes = dict(
- child=Child,
- children=List[Child],
- dictionary=dict,
+ x=typing.Dict[str, Child]
)
- @classmethod
- def from_state(cls, state):
- obj = cls()
- obj.set_state(state)
- return obj
+
+class TAny(TObject):
+ _stateobject_attributes = dict(
+ x=typing.Any
+ )
+
+
+class TSerializableChild(TObject):
+ _stateobject_attributes = dict(
+ x=Child
+ )
def test_simple():
a = Child(42)
+ assert a.get_state() == {"x": 42}
b = a.copy()
- assert b.get_state() == {"x": 42}
a.set_state({"x": 44})
assert a.x == 44
assert b.x == 42
-def test_container():
- a = Container()
- a.child = Child(42)
+def test_serializable_child():
+ child = Child(42)
+ a = TSerializableChild(child)
+ assert a.get_state() == {
+ "x": {"x": 42}
+ }
+ a.set_state({
+ "x": {"x": 43}
+ })
+ assert a.x.x == 43
+ assert a.x is child
b = a.copy()
- assert a.child.x == b.child.x
- b.child.x = 44
- assert a.child.x != b.child.x
+ assert a.x == b.x
+ assert a.x is not b.x
-def test_container_list():
- a = Container()
- a.children = [Child(42), Child(44)]
+def test_tuple():
+ a = TTuple((42, Child(43)))
assert a.get_state() == {
- "child": None,
- "children": [{"x": 42}, {"x": 44}],
- "dictionary": None,
+ "x": (42, {"x": 43})
}
- copy = a.copy()
- assert len(copy.children) == 2
- assert copy.children is not a.children
- assert copy.children[0] is not a.children[0]
- assert Container.from_state(a.get_state())
+ b = a.copy()
+ a.set_state({"x": (44, {"x": 45})})
+ assert a.x == (44, Child(45))
+ assert b.x == (42, Child(43))
+
+def test_tuple_err():
+ a = TTuple(None)
+ with pytest.raises(ValueError, msg="Invalid data"):
+ a.set_state({"x": (42,)})
-def test_container_dict():
- a = Container()
- a.dictionary = dict()
- a.dictionary['foo'] = 'bar'
- a.dictionary['bar'] = Child(44)
+
+def test_list():
+ a = TList([Child(1), Child(2)])
assert a.get_state() == {
- "child": None,
- "children": None,
- "dictionary": {'bar': {'x': 44}, 'foo': 'bar'},
+ "x": [{"x": 1}, {"x": 2}],
}
copy = a.copy()
- assert len(copy.dictionary) == 2
- assert copy.dictionary is not a.dictionary
- assert copy.dictionary['bar'] is not a.dictionary['bar']
+ assert len(copy.x) == 2
+ assert copy.x is not a.x
+ assert copy.x[0] is not a.x[0]
+
+
+def test_dict():
+ a = TDict({"foo": Child(42)})
+ assert a.get_state() == {
+ "x": {"foo": {"x": 42}}
+ }
+ b = a.copy()
+ assert list(a.x.items()) == list(b.x.items())
+ assert a.x is not b.x
+ assert a.x["foo"] is not b.x["foo"]
+
+
+def test_any():
+ a = TAny(42)
+ b = a.copy()
+ assert a.x == b.x
+
+ a = TAny(object())
+ with pytest.raises(AssertionError):
+ a.get_state()
def test_too_much_state():
- a = Container()
- a.child = Child(42)
+ a = Child(42)
s = a.get_state()
s['foo'] = 'bar'
- b = Container()
with pytest.raises(RuntimeWarning):
- b.set_state(s)
+ a.set_state(s)
+
+
+def test_none():
+ a = Child(None)
+ assert a.get_state() == {"x": None}
+ a = Child(42)
+ a.set_state({"x": None})
+ assert a.x is None
diff --git a/test/mitmproxy/utils/test_typecheck.py b/test/mitmproxy/utils/test_typecheck.py
index 5295fff5..9cb4334e 100644
--- a/test/mitmproxy/utils/test_typecheck.py
+++ b/test/mitmproxy/utils/test_typecheck.py
@@ -93,3 +93,8 @@ def test_typesec_to_str():
assert(typecheck.typespec_to_str(typing.Optional[str])) == "optional str"
with pytest.raises(NotImplementedError):
typecheck.typespec_to_str(dict)
+
+
+def test_mapping_types():
+ # this is not covered by check_option_type, but still belongs in this module
+ assert (str, int) == typecheck.mapping_types(typing.Mapping[str, int])