aboutsummaryrefslogtreecommitdiffstats
path: root/netlib/multidict.py
diff options
context:
space:
mode:
authorMaximilian Hils <git@maximilianhils.com>2016-05-18 18:46:42 -0700
committerMaximilian Hils <git@maximilianhils.com>2016-05-18 18:46:42 -0700
commit44ac64aa7235362acbb96e0f12aa27534580e575 (patch)
treec03b8c3519c273a4f42b60cb2bce8cc0dd524925 /netlib/multidict.py
parent4c3fb8f5097fad2c5de96104dae3f8026b0b4666 (diff)
downloadmitmproxy-44ac64aa7235362acbb96e0f12aa27534580e575.tar.gz
mitmproxy-44ac64aa7235362acbb96e0f12aa27534580e575.tar.bz2
mitmproxy-44ac64aa7235362acbb96e0f12aa27534580e575.zip
add MultiDict
This commit introduces MultiDict, a multi-dictionary similar to ODict, but with improved semantics (as in the Headers class). MultiDict fixes a few issues that were present in the Request/Response API. In particular, `request.cookies["foo"] = "bar"` has previously been a no-op, as the cookies property returned a mutable _copy_ of the cookies.
Diffstat (limited to 'netlib/multidict.py')
-rw-r--r--netlib/multidict.py163
1 files changed, 163 insertions, 0 deletions
diff --git a/netlib/multidict.py b/netlib/multidict.py
new file mode 100644
index 00000000..a7158bc5
--- /dev/null
+++ b/netlib/multidict.py
@@ -0,0 +1,163 @@
+from __future__ import absolute_import, print_function, division
+
+from abc import ABCMeta, abstractmethod
+
+from typing import Tuple
+
+try:
+ from collections.abc import MutableMapping
+except ImportError: # pragma: no cover
+ from collections import MutableMapping # Workaround for Python < 3.3
+
+import six
+
+from .utils import Serializable
+
+
+@six.add_metaclass(ABCMeta)
+class MultiDict(MutableMapping, Serializable):
+ def __init__(self, fields=None):
+
+ # it is important for us that .fields is immutable, so that we can easily
+ # detect changes to it.
+ self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...]
+
+ for key, value in self.fields:
+ if not isinstance(key, bytes) or not isinstance(value, bytes):
+ raise TypeError("MultiDict fields must be bytes.")
+
+ def __repr__(self):
+ fields = tuple(
+ repr(field)
+ for field in self.fields
+ )
+ return "{cls}[{fields}]".format(
+ cls=type(self).__name__,
+ fields=", ".join(fields)
+ )
+
+ @staticmethod
+ @abstractmethod
+ def _reduce_values(values):
+ pass
+
+ @staticmethod
+ @abstractmethod
+ def _kconv(v):
+ pass
+
+ def __getitem__(self, key):
+ values = self.get_all(key)
+ if not values:
+ raise KeyError(key)
+ return self._reduce_values(values)
+
+ def __setitem__(self, key, value):
+ self.set_all(key, [value])
+
+ def __delitem__(self, key):
+ if key not in self:
+ raise KeyError(key)
+ key = self._kconv(key)
+ self.fields = tuple(
+ field for field in self.fields
+ if key != self._kconv(field[0])
+ )
+
+ def __iter__(self):
+ seen = set()
+ for key, _ in self.fields:
+ key_kconv = self._kconv(key)
+ if key_kconv not in seen:
+ seen.add(key_kconv)
+ yield key
+
+ def __len__(self):
+ return len(set(self._kconv(key) for key, _ in self.fields))
+
+ def __eq__(self, other):
+ if isinstance(other, MultiDict):
+ return self.fields == other.fields
+ return False
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def get_all(self, key):
+ """
+ Return the list of items for a given key.
+ If that key is not in the MultiDict,
+ the return value will be an empty list.
+ """
+ key = self._kconv(key)
+ return [
+ value
+ for k, value in self.fields
+ if self._kconv(k) == key
+ ]
+
+ def set_all(self, key, values):
+ """
+ Remove the old values for a key and add new ones.
+ """
+ key_kconv = self._kconv(key)
+
+ new_fields = []
+ for field in self.fields:
+ if self._kconv(field[0]) == key_kconv:
+ if values:
+ new_fields.append(
+ (key, values.pop(0))
+ )
+ else:
+ new_fields.append(field)
+ while values:
+ new_fields.append(
+ (key, values.pop(0))
+ )
+ self.fields = tuple(new_fields)
+
+ def add(self, key, value):
+ self.insert(len(self.fields), key, value)
+
+ def insert(self, index, key, value):
+ item = (key, value)
+ self.fields = self.fields[:index] + (item,) + self.fields[index:]
+
+ def keys(self, multi=False):
+ return (
+ k
+ for k, _ in self.items(multi)
+ )
+
+ def values(self, multi=False):
+ return (
+ v
+ for _, v in self.items(multi)
+ )
+
+ def items(self, multi=False):
+ if multi:
+ return self.fields
+ else:
+ return super(MultiDict, self).items()
+
+ def to_dict(self):
+ d = {}
+ for key in self:
+ values = self.get_all(key)
+ if len(values) == 1:
+ d[key] = values[0]
+ else:
+ d[key] = values
+ return d
+
+ def get_state(self):
+ return self.fields
+
+ def set_state(self, state):
+ self.fields = tuple(tuple(x) for x in state)
+
+ @classmethod
+ def from_state(cls, state):
+ return cls(tuple(x) for x in state)