diff options
| author | Maximilian Hils <git@maximilianhils.com> | 2018-01-13 00:35:49 +0100 | 
|---|---|---|
| committer | GitHub <noreply@github.com> | 2018-01-13 00:35:49 +0100 | 
| commit | 96a5ed9dff60c245cc74c98e87caf37942cf0ee2 (patch) | |
| tree | 9b50bc84ba17808cbb2d93f01a9edb7e5f202672 | |
| parent | 37527a1da3629fc311139363e0bb3938081f7811 (diff) | |
| parent | 69726f180a70f42f8233b673aea209b0dabaa161 (diff) | |
| download | mitmproxy-96a5ed9dff60c245cc74c98e87caf37942cf0ee2.tar.gz mitmproxy-96a5ed9dff60c245cc74c98e87caf37942cf0ee2.tar.bz2 mitmproxy-96a5ed9dff60c245cc74c98e87caf37942cf0ee2.zip | |
Merge pull request #2790 from mhils/stateobject-improvements
stateobject: use typing, enable tuples and more complex datatypes
| -rw-r--r-- | mitmproxy/flow.py | 2 | ||||
| -rw-r--r-- | mitmproxy/stateobject.py | 83 | ||||
| -rw-r--r-- | mitmproxy/utils/typecheck.py | 56 | ||||
| -rw-r--r-- | setup.cfg | 1 | ||||
| -rw-r--r-- | test/mitmproxy/test_stateobject.py | 149 | ||||
| -rw-r--r-- | test/mitmproxy/utils/test_typecheck.py | 5 | 
6 files changed, 190 insertions, 106 deletions
| diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py index 944c032d..6a27a4a8 100644 --- a/mitmproxy/flow.py +++ b/mitmproxy/flow.py @@ -87,7 +87,7 @@ class Flow(stateobject.StateObject):          type=str,          intercepted=bool,          marked=bool, -        metadata=dict, +        metadata=typing.Dict[str, typing.Any],      )      def get_state(self): diff --git a/mitmproxy/stateobject.py b/mitmproxy/stateobject.py index 007339e8..ffaf285f 100644 --- a/mitmproxy/stateobject.py +++ b/mitmproxy/stateobject.py @@ -1,18 +1,12 @@ -from typing import Any -from typing import List +import typing +from typing import Any  # noqa  from typing import MutableMapping  # noqa  from mitmproxy.coretypes import serializable - - -def _is_list(cls): -    # The typing module is broken on Python 3.5.0, fixed on 3.5.1. -    is_list_bugfix = getattr(cls, "__origin__", False) == getattr(List[Any], "__origin__", True) -    return issubclass(cls, List) or is_list_bugfix +from mitmproxy.utils import typecheck  class StateObject(serializable.Serializable): -      """      An object with serializable state. @@ -34,22 +28,7 @@ class StateObject(serializable.Serializable):          state = {}          for attr, cls in self._stateobject_attributes.items():              val = getattr(self, attr) -            if val is None: -                state[attr] = None -            elif hasattr(val, "get_state"): -                state[attr] = val.get_state() -            elif _is_list(cls): -                state[attr] = [x.get_state() for x in val] -            elif isinstance(val, dict): -                s = {} -                for k, v in val.items(): -                    if hasattr(v, "get_state"): -                        s[k] = v.get_state() -                    else: -                        s[k] = v -                state[attr] = s -            else: -                state[attr] = val +            state[attr] = get_state(cls, val)          return state      def set_state(self, state): @@ -65,13 +44,51 @@ class StateObject(serializable.Serializable):                  curr = getattr(self, attr)                  if hasattr(curr, "set_state"):                      curr.set_state(val) -                elif hasattr(cls, "from_state"): -                    obj = cls.from_state(val) -                    setattr(self, attr, obj) -                elif _is_list(cls): -                    cls = cls.__parameters__[0] if cls.__parameters__ else cls.__args__[0] -                    setattr(self, attr, [cls.from_state(x) for x in val]) -                else:  # primitive types such as int, str, ... -                    setattr(self, attr, cls(val)) +                else: +                    setattr(self, attr, make_object(cls, val))          if state:              raise RuntimeWarning("Unexpected State in __setstate__: {}".format(state)) + + +def _process(typeinfo: typecheck.Type, val: typing.Any, make: bool) -> typing.Any: +    if val is None: +        return None +    elif make and hasattr(typeinfo, "from_state"): +        return typeinfo.from_state(val) +    elif not make and hasattr(val, "get_state"): +        return val.get_state() + +    typename = str(typeinfo) + +    if typename.startswith("typing.List"): +        T = typecheck.sequence_type(typeinfo) +        return [_process(T, x, make) for x in val] +    elif typename.startswith("typing.Tuple"): +        Ts = typecheck.tuple_types(typeinfo) +        if len(Ts) != len(val): +            raise ValueError("Invalid data. Expected {}, got {}.".format(Ts, val)) +        return tuple( +            _process(T, x, make) for T, x in zip(Ts, val) +        ) +    elif typename.startswith("typing.Dict"): +        k_cls, v_cls = typecheck.mapping_types(typeinfo) +        return { +            _process(k_cls, k, make): _process(v_cls, v, make) +            for k, v in val.items() +        } +    elif typename.startswith("typing.Any"): +        # FIXME: Remove this when we remove flow.metadata +        assert isinstance(val, (int, str, bool, bytes)) +        return val +    else: +        return typeinfo(val) + + +def make_object(typeinfo: typecheck.Type, val: typing.Any) -> typing.Any: +    """Create an object based on the state given in val.""" +    return _process(typeinfo, val, True) + + +def get_state(typeinfo: typecheck.Type, val: typing.Any) -> typing.Any: +    """Get the state of the object given as val.""" +    return _process(typeinfo, val, False) diff --git a/mitmproxy/utils/typecheck.py b/mitmproxy/utils/typecheck.py index 1070fad0..22db68f5 100644 --- a/mitmproxy/utils/typecheck.py +++ b/mitmproxy/utils/typecheck.py @@ -1,7 +1,40 @@  import typing +Type = typing.Union[ +    typing.Any  # anything more elaborate really fails with mypy at the moment. +] -def check_option_type(name: str, value: typing.Any, typeinfo: typing.Any) -> None: + +def sequence_type(typeinfo: typing.Type[typing.List]) -> Type: +    """Return the type of a sequence, e.g. typing.List""" +    try: +        return typeinfo.__args__[0]  # type: ignore +    except AttributeError:  # Python 3.5.0 +        return typeinfo.__parameters__[0]  # type: ignore + + +def tuple_types(typeinfo: typing.Type[typing.Tuple]) -> typing.Sequence[Type]: +    """Return the types of a typing.Tuple""" +    try: +        return typeinfo.__args__  # type: ignore +    except AttributeError:  # Python 3.5.x +        return typeinfo.__tuple_params__  # type: ignore + + +def union_types(typeinfo: typing.Type[typing.Tuple]) -> typing.Sequence[Type]: +    """return the types of a typing.Union""" +    try: +        return typeinfo.__args__  # type: ignore +    except AttributeError:  # Python 3.5.x +        return typeinfo.__union_params__  # type: ignore + + +def mapping_types(typeinfo: typing.Type[typing.Mapping]) -> typing.Tuple[Type, Type]: +    """return the types of a mapping, e.g. typing.Dict""" +    return typeinfo.__args__  # type: ignore + + +def check_option_type(name: str, value: typing.Any, typeinfo: Type) -> None:      """      Check if the provided value is an instance of typeinfo and raises a      TypeError otherwise. This function supports only those types required for @@ -16,13 +49,7 @@ def check_option_type(name: str, value: typing.Any, typeinfo: typing.Any) -> Non      typename = str(typeinfo)      if typename.startswith("typing.Union"): -        try: -            types = typeinfo.__args__  # type: ignore -        except AttributeError: -            # Python 3.5.x -            types = typeinfo.__union_params__  # type: ignore - -        for T in types: +        for T in union_types(typeinfo):              try:                  check_option_type(name, value, T)              except TypeError: @@ -31,12 +58,7 @@ def check_option_type(name: str, value: typing.Any, typeinfo: typing.Any) -> Non                  return          raise e      elif typename.startswith("typing.Tuple"): -        try: -            types = typeinfo.__args__  # type: ignore -        except AttributeError: -            # Python 3.5.x -            types = typeinfo.__tuple_params__  # type: ignore - +        types = tuple_types(typeinfo)          if not isinstance(value, (tuple, list)):              raise e          if len(types) != len(value): @@ -45,11 +67,7 @@ def check_option_type(name: str, value: typing.Any, typeinfo: typing.Any) -> Non              check_option_type("{}[{}]".format(name, i), x, T)          return      elif typename.startswith("typing.Sequence"): -        try: -            T = typeinfo.__args__[0]  # type: ignore -        except AttributeError: -            # Python 3.5.0 -            T = typeinfo.__parameters__[0]  # type: ignore +        T = sequence_type(typeinfo)          if not isinstance(value, (tuple, list)):              raise e          for v in value: @@ -75,7 +75,6 @@ exclude =      mitmproxy/proxy/protocol/tls.py      mitmproxy/proxy/root_context.py      mitmproxy/proxy/server.py -    mitmproxy/stateobject.py      mitmproxy/utils/bits.py      pathod/language/actions.py      pathod/language/base.py diff --git a/test/mitmproxy/test_stateobject.py b/test/mitmproxy/test_stateobject.py index d8c7a8e9..bd5d1792 100644 --- a/test/mitmproxy/test_stateobject.py +++ b/test/mitmproxy/test_stateobject.py @@ -1,101 +1,146 @@ -from typing import List +import typing +  import pytest  from mitmproxy.stateobject import StateObject -class Child(StateObject): +class TObject(StateObject):      def __init__(self, x):          self.x = x -    _stateobject_attributes = dict( -        x=int -    ) -      @classmethod      def from_state(cls, state):          obj = cls(None)          obj.set_state(state)          return obj + +class Child(TObject): +    _stateobject_attributes = dict( +        x=int +    ) +      def __eq__(self, other):          return isinstance(other, Child) and self.x == other.x -class Container(StateObject): -    def __init__(self): -        self.child = None -        self.children = None -        self.dictionary = None +class TTuple(TObject): +    _stateobject_attributes = dict( +        x=typing.Tuple[int, Child] +    ) + + +class TList(TObject): +    _stateobject_attributes = dict( +        x=typing.List[Child] +    ) + +class TDict(TObject):      _stateobject_attributes = dict( -        child=Child, -        children=List[Child], -        dictionary=dict, +        x=typing.Dict[str, Child]      ) -    @classmethod -    def from_state(cls, state): -        obj = cls() -        obj.set_state(state) -        return obj + +class TAny(TObject): +    _stateobject_attributes = dict( +        x=typing.Any +    ) + + +class TSerializableChild(TObject): +    _stateobject_attributes = dict( +        x=Child +    )  def test_simple():      a = Child(42) +    assert a.get_state() == {"x": 42}      b = a.copy() -    assert b.get_state() == {"x": 42}      a.set_state({"x": 44})      assert a.x == 44      assert b.x == 42 -def test_container(): -    a = Container() -    a.child = Child(42) +def test_serializable_child(): +    child = Child(42) +    a = TSerializableChild(child) +    assert a.get_state() == { +        "x": {"x": 42} +    } +    a.set_state({ +        "x": {"x": 43} +    }) +    assert a.x.x == 43 +    assert a.x is child      b = a.copy() -    assert a.child.x == b.child.x -    b.child.x = 44 -    assert a.child.x != b.child.x +    assert a.x == b.x +    assert a.x is not b.x -def test_container_list(): -    a = Container() -    a.children = [Child(42), Child(44)] +def test_tuple(): +    a = TTuple((42, Child(43)))      assert a.get_state() == { -        "child": None, -        "children": [{"x": 42}, {"x": 44}], -        "dictionary": None, +        "x": (42, {"x": 43})      } -    copy = a.copy() -    assert len(copy.children) == 2 -    assert copy.children is not a.children -    assert copy.children[0] is not a.children[0] -    assert Container.from_state(a.get_state()) +    b = a.copy() +    a.set_state({"x": (44, {"x": 45})}) +    assert a.x == (44, Child(45)) +    assert b.x == (42, Child(43)) + +def test_tuple_err(): +    a = TTuple(None) +    with pytest.raises(ValueError, msg="Invalid data"): +        a.set_state({"x": (42,)}) -def test_container_dict(): -    a = Container() -    a.dictionary = dict() -    a.dictionary['foo'] = 'bar' -    a.dictionary['bar'] = Child(44) + +def test_list(): +    a = TList([Child(1), Child(2)])      assert a.get_state() == { -        "child": None, -        "children": None, -        "dictionary": {'bar': {'x': 44}, 'foo': 'bar'}, +        "x": [{"x": 1}, {"x": 2}],      }      copy = a.copy() -    assert len(copy.dictionary) == 2 -    assert copy.dictionary is not a.dictionary -    assert copy.dictionary['bar'] is not a.dictionary['bar'] +    assert len(copy.x) == 2 +    assert copy.x is not a.x +    assert copy.x[0] is not a.x[0] + + +def test_dict(): +    a = TDict({"foo": Child(42)}) +    assert a.get_state() == { +        "x": {"foo": {"x": 42}} +    } +    b = a.copy() +    assert list(a.x.items()) == list(b.x.items()) +    assert a.x is not b.x +    assert a.x["foo"] is not b.x["foo"] + + +def test_any(): +    a = TAny(42) +    b = a.copy() +    assert a.x == b.x + +    a = TAny(object()) +    with pytest.raises(AssertionError): +        a.get_state()  def test_too_much_state(): -    a = Container() -    a.child = Child(42) +    a = Child(42)      s = a.get_state()      s['foo'] = 'bar' -    b = Container()      with pytest.raises(RuntimeWarning): -        b.set_state(s) +        a.set_state(s) + + +def test_none(): +    a = Child(None) +    assert a.get_state() == {"x": None} +    a = Child(42) +    a.set_state({"x": None}) +    assert a.x is None diff --git a/test/mitmproxy/utils/test_typecheck.py b/test/mitmproxy/utils/test_typecheck.py index 5295fff5..9cb4334e 100644 --- a/test/mitmproxy/utils/test_typecheck.py +++ b/test/mitmproxy/utils/test_typecheck.py @@ -93,3 +93,8 @@ def test_typesec_to_str():      assert(typecheck.typespec_to_str(typing.Optional[str])) == "optional str"      with pytest.raises(NotImplementedError):          typecheck.typespec_to_str(dict) + + +def test_mapping_types(): +    # this is not covered by check_option_type, but still belongs in this module +    assert (str, int) == typecheck.mapping_types(typing.Mapping[str, int]) | 
