aboutsummaryrefslogtreecommitdiffstats
path: root/netlib/multidict.py
diff options
context:
space:
mode:
authorMaximilian Hils <git@maximilianhils.com>2016-05-18 22:50:19 -0700
committerMaximilian Hils <git@maximilianhils.com>2016-05-18 22:50:19 -0700
commit6f8db2d7eb32684a8328e0ae8bdd73eceb861707 (patch)
tree254d964e9f8b95393b82683f66b9c2f77fb060de /netlib/multidict.py
parent8e39b7bf38e7becd1116dfcded380327fd0228d0 (diff)
downloadmitmproxy-6f8db2d7eb32684a8328e0ae8bdd73eceb861707.tar.gz
mitmproxy-6f8db2d7eb32684a8328e0ae8bdd73eceb861707.tar.bz2
mitmproxy-6f8db2d7eb32684a8328e0ae8bdd73eceb861707.zip
improve MultiDict, add ImmutableMultiDict, adjust response.cookies
Diffstat (limited to 'netlib/multidict.py')
-rw-r--r--netlib/multidict.py403
1 files changed, 240 insertions, 163 deletions
diff --git a/netlib/multidict.py b/netlib/multidict.py
index a7158bc5..32d5bfc2 100644
--- a/netlib/multidict.py
+++ b/netlib/multidict.py
@@ -1,163 +1,240 @@
-from __future__ import absolute_import, print_function, division
-
-from abc import ABCMeta, abstractmethod
-
-from typing import Tuple
-
-try:
- from collections.abc import MutableMapping
-except ImportError: # pragma: no cover
- from collections import MutableMapping # Workaround for Python < 3.3
-
-import six
-
-from .utils import Serializable
-
-
-@six.add_metaclass(ABCMeta)
-class MultiDict(MutableMapping, Serializable):
- def __init__(self, fields=None):
-
- # it is important for us that .fields is immutable, so that we can easily
- # detect changes to it.
- self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...]
-
- for key, value in self.fields:
- if not isinstance(key, bytes) or not isinstance(value, bytes):
- raise TypeError("MultiDict fields must be bytes.")
-
- def __repr__(self):
- fields = tuple(
- repr(field)
- for field in self.fields
- )
- return "{cls}[{fields}]".format(
- cls=type(self).__name__,
- fields=", ".join(fields)
- )
-
- @staticmethod
- @abstractmethod
- def _reduce_values(values):
- pass
-
- @staticmethod
- @abstractmethod
- def _kconv(v):
- pass
-
- def __getitem__(self, key):
- values = self.get_all(key)
- if not values:
- raise KeyError(key)
- return self._reduce_values(values)
-
- def __setitem__(self, key, value):
- self.set_all(key, [value])
-
- def __delitem__(self, key):
- if key not in self:
- raise KeyError(key)
- key = self._kconv(key)
- self.fields = tuple(
- field for field in self.fields
- if key != self._kconv(field[0])
- )
-
- def __iter__(self):
- seen = set()
- for key, _ in self.fields:
- key_kconv = self._kconv(key)
- if key_kconv not in seen:
- seen.add(key_kconv)
- yield key
-
- def __len__(self):
- return len(set(self._kconv(key) for key, _ in self.fields))
-
- def __eq__(self, other):
- if isinstance(other, MultiDict):
- return self.fields == other.fields
- return False
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def get_all(self, key):
- """
- Return the list of items for a given key.
- If that key is not in the MultiDict,
- the return value will be an empty list.
- """
- key = self._kconv(key)
- return [
- value
- for k, value in self.fields
- if self._kconv(k) == key
- ]
-
- def set_all(self, key, values):
- """
- Remove the old values for a key and add new ones.
- """
- key_kconv = self._kconv(key)
-
- new_fields = []
- for field in self.fields:
- if self._kconv(field[0]) == key_kconv:
- if values:
- new_fields.append(
- (key, values.pop(0))
- )
- else:
- new_fields.append(field)
- while values:
- new_fields.append(
- (key, values.pop(0))
- )
- self.fields = tuple(new_fields)
-
- def add(self, key, value):
- self.insert(len(self.fields), key, value)
-
- def insert(self, index, key, value):
- item = (key, value)
- self.fields = self.fields[:index] + (item,) + self.fields[index:]
-
- def keys(self, multi=False):
- return (
- k
- for k, _ in self.items(multi)
- )
-
- def values(self, multi=False):
- return (
- v
- for _, v in self.items(multi)
- )
-
- def items(self, multi=False):
- if multi:
- return self.fields
- else:
- return super(MultiDict, self).items()
-
- def to_dict(self):
- d = {}
- for key in self:
- values = self.get_all(key)
- if len(values) == 1:
- d[key] = values[0]
- else:
- d[key] = values
- return d
-
- def get_state(self):
- return self.fields
-
- def set_state(self, state):
- self.fields = tuple(tuple(x) for x in state)
-
- @classmethod
- def from_state(cls, state):
- return cls(tuple(x) for x in state)
+from __future__ import absolute_import, print_function, division
+
+from abc import ABCMeta, abstractmethod
+
+from typing import Tuple, TypeVar
+
+try:
+ from collections.abc import MutableMapping
+except ImportError: # pragma: no cover
+ from collections import MutableMapping # Workaround for Python < 3.3
+
+import six
+
+from .utils import Serializable
+
+
+@six.add_metaclass(ABCMeta)
+class MultiDict(MutableMapping, Serializable):
+ def __init__(self, fields=None):
+
+ # it is important for us that .fields is immutable, so that we can easily
+ # detect changes to it.
+ self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...]
+
+ def __repr__(self):
+ fields = tuple(
+ repr(field)
+ for field in self.fields
+ )
+ return "{cls}[{fields}]".format(
+ cls=type(self).__name__,
+ fields=", ".join(fields)
+ )
+
+ @staticmethod
+ @abstractmethod
+ def _reduce_values(values):
+ pass
+
+ @staticmethod
+ @abstractmethod
+ def _kconv(v):
+ pass
+
+ def __getitem__(self, key):
+ values = self.get_all(key)
+ if not values:
+ raise KeyError(key)
+ return self._reduce_values(values)
+
+ def __setitem__(self, key, value):
+ self.set_all(key, [value])
+
+ def __delitem__(self, key):
+ if key not in self:
+ raise KeyError(key)
+ key = self._kconv(key)
+ self.fields = tuple(
+ field for field in self.fields
+ if key != self._kconv(field[0])
+ )
+
+ def __iter__(self):
+ seen = set()
+ for key, _ in self.fields:
+ key_kconv = self._kconv(key)
+ if key_kconv not in seen:
+ seen.add(key_kconv)
+ yield key
+
+ def __len__(self):
+ return len(set(self._kconv(key) for key, _ in self.fields))
+
+ def __eq__(self, other):
+ if isinstance(other, MultiDict):
+ return self.fields == other.fields
+ return False
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def get_all(self, key):
+ """
+ Return the list of all values for a given key.
+ If that key is not in the MultiDict, the return value will be an empty list.
+ """
+ key = self._kconv(key)
+ return [
+ value
+ for k, value in self.fields
+ if self._kconv(k) == key
+ ]
+
+ def set_all(self, key, values):
+ """
+ Remove the old values for a key and add new ones.
+ """
+ key_kconv = self._kconv(key)
+
+ new_fields = []
+ for field in self.fields:
+ if self._kconv(field[0]) == key_kconv:
+ if values:
+ new_fields.append(
+ (key, values.pop(0))
+ )
+ else:
+ new_fields.append(field)
+ while values:
+ new_fields.append(
+ (key, values.pop(0))
+ )
+ self.fields = tuple(new_fields)
+
+ def add(self, key, value):
+ """
+ Add an additional value for the given key at the bottom.
+ """
+ self.insert(len(self.fields), key, value)
+
+ def insert(self, index, key, value):
+ """
+ Insert an additional value for the given key at the specified position.
+ """
+ item = (key, value)
+ self.fields = self.fields[:index] + (item,) + self.fields[index:]
+
+ def keys(self, multi=False):
+ """
+ Get all keys.
+
+ Args:
+ multi(bool):
+ If True, one key per value will be returned.
+ If False, duplicate keys will only be returned once.
+ """
+ return (
+ k
+ for k, _ in self.items(multi)
+ )
+
+ def values(self, multi=False):
+ """
+ Get all values.
+
+ Args:
+ multi(bool):
+ If True, all values will be returned.
+ If False, only the first value per key will be returned.
+ """
+ return (
+ v
+ for _, v in self.items(multi)
+ )
+
+ def items(self, multi=False):
+ """
+ Get all (key, value) tuples.
+
+ Args:
+ multi(bool):
+ If True, all (key, value) pairs will be returned
+ If False, only the first (key, value) pair per unique key will be returned.
+ """
+ if multi:
+ return self.fields
+ else:
+ return super(MultiDict, self).items()
+
+ def to_dict(self):
+ """
+ Get the MultiDict as a plain Python dict.
+ Keys with multiple values are returned as lists.
+
+ Example:
+
+ .. code-block:: python
+
+ # Simple dict with duplicate values.
+ >>> d
+ MultiDictView[("name", "value"), ("a", "false"), ("a", "42")]
+ >>> d.to_dict()
+ {
+ "name": "value",
+ "a": ["false", "42"]
+ }
+ """
+ d = {}
+ for key in self:
+ values = self.get_all(key)
+ if len(values) == 1:
+ d[key] = values[0]
+ else:
+ d[key] = values
+ return d
+
+ def get_state(self):
+ return self.fields
+
+ def set_state(self, state):
+ self.fields = tuple(tuple(x) for x in state)
+
+ @classmethod
+ def from_state(cls, state):
+ return cls(tuple(x) for x in state)
+
+
+@six.add_metaclass(ABCMeta)
+class ImmutableMultiDict(MultiDict):
+ def _immutable(self, *_):
+ raise TypeError('{} objects are immutable'.format(self.__class__.__name__))
+
+ __delitem__ = set_all = insert = _immutable
+
+ def with_delitem(self, key):
+ """
+ Returns:
+ An updated ImmutableMultiDict. The original object will not be modified.
+ """
+ ret = self.copy()
+ super(ImmutableMultiDict, ret).__delitem__(key)
+ return ret
+
+ def with_set_all(self, key, values):
+ """
+ Returns:
+ An updated ImmutableMultiDict. The original object will not be modified.
+ """
+ ret = self.copy()
+ super(ImmutableMultiDict, ret).set_all(key, values)
+ return ret
+
+ def with_insert(self, index, key, value):
+ """
+ Returns:
+ An updated ImmutableMultiDict. The original object will not be modified.
+ """
+ ret = self.copy()
+ super(ImmutableMultiDict, ret).insert(index, key, value)
+ return ret