aboutsummaryrefslogtreecommitdiffstats
path: root/mitmproxy/stateobject.py
diff options
context:
space:
mode:
authorMaximilian Hils <git@maximilianhils.com>2018-01-12 22:42:02 +0100
committerMaximilian Hils <git@maximilianhils.com>2018-01-13 00:33:37 +0100
commit69726f180a70f42f8233b673aea209b0dabaa161 (patch)
treef0c757a864f4f3021bac1c538a728888194fdce2 /mitmproxy/stateobject.py
parentb7db304dde0daf2b410dc36d33a24856aa22ba59 (diff)
downloadmitmproxy-69726f180a70f42f8233b673aea209b0dabaa161.tar.gz
mitmproxy-69726f180a70f42f8233b673aea209b0dabaa161.tar.bz2
mitmproxy-69726f180a70f42f8233b673aea209b0dabaa161.zip
stateobject: use typing, enable tuples and more complex datatypes
Diffstat (limited to 'mitmproxy/stateobject.py')
-rw-r--r--mitmproxy/stateobject.py83
1 files changed, 50 insertions, 33 deletions
diff --git a/mitmproxy/stateobject.py b/mitmproxy/stateobject.py
index 007339e8..ffaf285f 100644
--- a/mitmproxy/stateobject.py
+++ b/mitmproxy/stateobject.py
@@ -1,18 +1,12 @@
-from typing import Any
-from typing import List
+import typing
+from typing import Any # noqa
from typing import MutableMapping # noqa
from mitmproxy.coretypes import serializable
-
-
-def _is_list(cls):
- # The typing module is broken on Python 3.5.0, fixed on 3.5.1.
- is_list_bugfix = getattr(cls, "__origin__", False) == getattr(List[Any], "__origin__", True)
- return issubclass(cls, List) or is_list_bugfix
+from mitmproxy.utils import typecheck
class StateObject(serializable.Serializable):
-
"""
An object with serializable state.
@@ -34,22 +28,7 @@ class StateObject(serializable.Serializable):
state = {}
for attr, cls in self._stateobject_attributes.items():
val = getattr(self, attr)
- if val is None:
- state[attr] = None
- elif hasattr(val, "get_state"):
- state[attr] = val.get_state()
- elif _is_list(cls):
- state[attr] = [x.get_state() for x in val]
- elif isinstance(val, dict):
- s = {}
- for k, v in val.items():
- if hasattr(v, "get_state"):
- s[k] = v.get_state()
- else:
- s[k] = v
- state[attr] = s
- else:
- state[attr] = val
+ state[attr] = get_state(cls, val)
return state
def set_state(self, state):
@@ -65,13 +44,51 @@ class StateObject(serializable.Serializable):
curr = getattr(self, attr)
if hasattr(curr, "set_state"):
curr.set_state(val)
- elif hasattr(cls, "from_state"):
- obj = cls.from_state(val)
- setattr(self, attr, obj)
- elif _is_list(cls):
- cls = cls.__parameters__[0] if cls.__parameters__ else cls.__args__[0]
- setattr(self, attr, [cls.from_state(x) for x in val])
- else: # primitive types such as int, str, ...
- setattr(self, attr, cls(val))
+ else:
+ setattr(self, attr, make_object(cls, val))
if state:
raise RuntimeWarning("Unexpected State in __setstate__: {}".format(state))
+
+
+def _process(typeinfo: typecheck.Type, val: typing.Any, make: bool) -> typing.Any:
+ if val is None:
+ return None
+ elif make and hasattr(typeinfo, "from_state"):
+ return typeinfo.from_state(val)
+ elif not make and hasattr(val, "get_state"):
+ return val.get_state()
+
+ typename = str(typeinfo)
+
+ if typename.startswith("typing.List"):
+ T = typecheck.sequence_type(typeinfo)
+ return [_process(T, x, make) for x in val]
+ elif typename.startswith("typing.Tuple"):
+ Ts = typecheck.tuple_types(typeinfo)
+ if len(Ts) != len(val):
+ raise ValueError("Invalid data. Expected {}, got {}.".format(Ts, val))
+ return tuple(
+ _process(T, x, make) for T, x in zip(Ts, val)
+ )
+ elif typename.startswith("typing.Dict"):
+ k_cls, v_cls = typecheck.mapping_types(typeinfo)
+ return {
+ _process(k_cls, k, make): _process(v_cls, v, make)
+ for k, v in val.items()
+ }
+ elif typename.startswith("typing.Any"):
+ # FIXME: Remove this when we remove flow.metadata
+ assert isinstance(val, (int, str, bool, bytes))
+ return val
+ else:
+ return typeinfo(val)
+
+
+def make_object(typeinfo: typecheck.Type, val: typing.Any) -> typing.Any:
+ """Create an object based on the state given in val."""
+ return _process(typeinfo, val, True)
+
+
+def get_state(typeinfo: typecheck.Type, val: typing.Any) -> typing.Any:
+ """Get the state of the object given as val."""
+ return _process(typeinfo, val, False)