diff options
-rw-r--r-- | mitmproxy/types/multidict.py | 27 | ||||
-rw-r--r-- | test/mitmproxy/types/test_multidict.py | 9 |
2 files changed, 24 insertions, 12 deletions
diff --git a/mitmproxy/types/multidict.py b/mitmproxy/types/multidict.py index 0a9b12d9..c4f42580 100644 --- a/mitmproxy/types/multidict.py +++ b/mitmproxy/types/multidict.py @@ -4,7 +4,7 @@ from collections.abc import MutableMapping from mitmproxy.types import serializable -class _MultiDict(MutableMapping, serializable.Serializable, metaclass=ABCMeta): +class _MultiDict(MutableMapping, metaclass=ABCMeta): def __repr__(self): fields = ( repr(field) @@ -171,18 +171,8 @@ class _MultiDict(MutableMapping, serializable.Serializable, metaclass=ABCMeta): coll.append([key, values]) return coll - def get_state(self): - return self.fields - - def set_state(self, state): - self.fields = tuple(tuple(x) for x in state) - @classmethod - def from_state(cls, state): - return cls(state) - - -class MultiDict(_MultiDict): +class MultiDict(_MultiDict, serializable.Serializable): def __init__(self, fields=()): super().__init__() self.fields = tuple( @@ -197,6 +187,16 @@ class MultiDict(_MultiDict): def _kconv(key): return key + def get_state(self): + return self.fields + + def set_state(self, state): + self.fields = tuple(tuple(x) for x in state) + + @classmethod + def from_state(cls, state): + return cls(state) + class MultiDictView(_MultiDict): """ @@ -227,3 +227,6 @@ class MultiDictView(_MultiDict): @fields.setter def fields(self, value): self._setter(value) + + def copy(self): + return MultiDict(self.fields) diff --git a/test/mitmproxy/types/test_multidict.py b/test/mitmproxy/types/test_multidict.py index 3b879ed1..c76cd753 100644 --- a/test/mitmproxy/types/test_multidict.py +++ b/test/mitmproxy/types/test_multidict.py @@ -200,3 +200,12 @@ class TestMultiDictView: tv["c"] = "b" assert p.vals == (("a", "b"), ("c", "b")) assert tv["a"] == "b" + + def test_copy(self): + p = TParent() + tv = multidict.MultiDictView(p.getter, p.setter) + c = tv.copy() + assert isinstance(c, multidict.MultiDict) + assert tv.items() == c.items() + c["foo"] = "bar" + assert tv.items() != c.items() |