aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMaximilian Hils <git@maximilianhils.com>2016-04-29 20:34:12 -0700
committerMaximilian Hils <git@maximilianhils.com>2016-04-29 20:59:26 -0700
commit74cfd7a4e2a64aa5ee3c98c3c7a0e2e668779618 (patch)
treefc3bba451a43b9e34b6dca88bc60ac5dd1af4427
parentcb1119f3eebc57914fc6093f0afcc7b3cd88fcc7 (diff)
downloadmitmproxy-74cfd7a4e2a64aa5ee3c98c3c7a0e2e668779618.tar.gz
mitmproxy-74cfd7a4e2a64aa5ee3c98c3c7a0e2e668779618.tar.bz2
mitmproxy-74cfd7a4e2a64aa5ee3c98c3c7a0e2e668779618.zip
stateobject: support lists
-rw-r--r--mitmproxy/stateobject.py17
-rw-r--r--setup.py1
-rw-r--r--test/mitmproxy/test_stateobject.py63
3 files changed, 80 insertions, 1 deletions
diff --git a/mitmproxy/stateobject.py b/mitmproxy/stateobject.py
index fff6e116..765c35d6 100644
--- a/mitmproxy/stateobject.py
+++ b/mitmproxy/stateobject.py
@@ -1,10 +1,18 @@
from __future__ import absolute_import
import six
+from typing import List, Any
from netlib.utils import Serializable
+def _is_list(cls):
+ # The typing module backport is somewhat broken.
+ # Python 3.5 or 3.6 should fix this.
+ is_list_bugfix = getattr(cls, "__origin__", False) == getattr(List[Any], "__origin__", True)
+ return issubclass(cls, List) or is_list_bugfix
+
+
class StateObject(Serializable):
"""
@@ -28,8 +36,12 @@ class StateObject(Serializable):
state = {}
for attr, cls in six.iteritems(self._stateobject_attributes):
val = getattr(self, attr)
- if hasattr(val, "get_state"):
+ if val is None:
+ state[attr] = None
+ elif hasattr(val, "get_state"):
state[attr] = val.get_state()
+ elif _is_list(cls):
+ state[attr] = [x.get_state() for x in val]
else:
state[attr] = val
return state
@@ -49,6 +61,9 @@ class StateObject(Serializable):
elif hasattr(cls, "from_state"):
obj = cls.from_state(state.pop(attr))
setattr(self, attr, obj)
+ elif _is_list(cls):
+ cls = cls.__parameters__[0]
+ setattr(self, attr, [cls.from_state(x) for x in state.pop(attr)])
else: # primitive types such as int, str, ...
setattr(self, attr, cls(state.pop(attr)))
if state:
diff --git a/setup.py b/setup.py
index f2007329..a0777d02 100644
--- a/setup.py
+++ b/setup.py
@@ -81,6 +81,7 @@ setup(
"requests>=2.9.1, <2.10",
"six>=1.10, <1.11",
"tornado>=4.3, <4.4",
+ "typing==3.5.1.0",
"urwid>=1.3.1, <1.4",
"watchdog>=0.8.3, <0.9",
],
diff --git a/test/mitmproxy/test_stateobject.py b/test/mitmproxy/test_stateobject.py
new file mode 100644
index 00000000..b9ffe7ae
--- /dev/null
+++ b/test/mitmproxy/test_stateobject.py
@@ -0,0 +1,63 @@
+from typing import List
+
+from mitmproxy.stateobject import StateObject
+
+
+class Child(StateObject):
+ def __init__(self, x):
+ self.x = x
+
+ _stateobject_attributes = dict(
+ x=int
+ )
+
+ @classmethod
+ def from_state(cls, state):
+ obj = cls(None)
+ obj.set_state(state)
+ return obj
+
+
+class Container(StateObject):
+ def __init__(self):
+ self.child = None
+ self.children = None
+
+ _stateobject_attributes = dict(
+ child=Child,
+ children=List[Child],
+ )
+
+ @classmethod
+ def from_state(cls, state):
+ obj = cls()
+ obj.set_state(state)
+ return obj
+
+
+def test_simple():
+ a = Child(42)
+ b = a.copy()
+ assert b.get_state() == {"x": 42}
+ a.set_state({"x": 44})
+ assert a.x == 44
+ assert b.x == 42
+
+
+def test_container():
+ a = Container()
+ a.child = Child(42)
+ b = a.copy()
+ assert a.child.x == b.child.x
+ b.child.x = 44
+ assert a.child.x != b.child.x
+
+
+def test_container_list():
+ a = Container()
+ a.children = [Child(42), Child(44)]
+ assert a.get_state() == {
+ "child": None,
+ "children": [{"x": 42}, {"x": 44}]
+ }
+ assert len(a.copy().children) == 2