diff options
author | Aldo Cortesi <aldo@nullcube.com> | 2016-05-21 15:01:19 +1200 |
---|---|---|
committer | Aldo Cortesi <aldo@nullcube.com> | 2016-05-21 15:01:19 +1200 |
commit | 97f3077082371a078380ec28e1425de3dc2a0021 (patch) | |
tree | 9bb2e4dfeacec85209174672fee3e78df32dcfaf /netlib/multidict.py | |
parent | 14fb2eeb1ea734d161e84be3920691591f6531a3 (diff) | |
parent | 43d796553292a9bbf2c174e31e5e0e39e1068be1 (diff) | |
download | mitmproxy-97f3077082371a078380ec28e1425de3dc2a0021.tar.gz mitmproxy-97f3077082371a078380ec28e1425de3dc2a0021.tar.bz2 mitmproxy-97f3077082371a078380ec28e1425de3dc2a0021.zip |
Merge branch 'mhils-multidict'
Diffstat (limited to 'netlib/multidict.py')
-rw-r--r-- | netlib/multidict.py | 279 |
1 files changed, 279 insertions, 0 deletions
diff --git a/netlib/multidict.py b/netlib/multidict.py new file mode 100644 index 00000000..3af7979b --- /dev/null +++ b/netlib/multidict.py @@ -0,0 +1,279 @@ +from __future__ import absolute_import, print_function, division + +from abc import ABCMeta, abstractmethod + +from typing import Tuple, TypeVar + +try: + from collections.abc import MutableMapping +except ImportError: # pragma: no cover + from collections import MutableMapping # Workaround for Python < 3.3 + +import six + +from .utils import Serializable + + +@six.add_metaclass(ABCMeta) +class _MultiDict(MutableMapping, Serializable): + def __repr__(self): + fields = tuple( + repr(field) + for field in self.fields + ) + return "{cls}[{fields}]".format( + cls=type(self).__name__, + fields=", ".join(fields) + ) + + @staticmethod + @abstractmethod + def _reduce_values(values): + """ + If a user accesses multidict["foo"], this method + reduces all values for "foo" to a single value that is returned. + For example, HTTP headers are folded, whereas we will just take + the first cookie we found with that name. + """ + + @staticmethod + @abstractmethod + def _kconv(key): + """ + This method converts a key to its canonical representation. + For example, HTTP headers are case-insensitive, so this method returns key.lower(). + """ + + def __getitem__(self, key): + values = self.get_all(key) + if not values: + raise KeyError(key) + return self._reduce_values(values) + + def __setitem__(self, key, value): + self.set_all(key, [value]) + + def __delitem__(self, key): + if key not in self: + raise KeyError(key) + key = self._kconv(key) + self.fields = tuple( + field for field in self.fields + if key != self._kconv(field[0]) + ) + + def __iter__(self): + seen = set() + for key, _ in self.fields: + key_kconv = self._kconv(key) + if key_kconv not in seen: + seen.add(key_kconv) + yield key + + def __len__(self): + return len(set(self._kconv(key) for key, _ in self.fields)) + + def __eq__(self, other): + if isinstance(other, MultiDict): + return self.fields == other.fields + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def get_all(self, key): + """ + Return the list of all values for a given key. + If that key is not in the MultiDict, the return value will be an empty list. + """ + key = self._kconv(key) + return [ + value + for k, value in self.fields + if self._kconv(k) == key + ] + + def set_all(self, key, values): + """ + Remove the old values for a key and add new ones. + """ + key_kconv = self._kconv(key) + + new_fields = [] + for field in self.fields: + if self._kconv(field[0]) == key_kconv: + if values: + new_fields.append( + (key, values.pop(0)) + ) + else: + new_fields.append(field) + while values: + new_fields.append( + (key, values.pop(0)) + ) + self.fields = tuple(new_fields) + + def add(self, key, value): + """ + Add an additional value for the given key at the bottom. + """ + self.insert(len(self.fields), key, value) + + def insert(self, index, key, value): + """ + Insert an additional value for the given key at the specified position. + """ + item = (key, value) + self.fields = self.fields[:index] + (item,) + self.fields[index:] + + def keys(self, multi=False): + """ + Get all keys. + + Args: + multi(bool): + If True, one key per value will be returned. + If False, duplicate keys will only be returned once. + """ + return ( + k + for k, _ in self.items(multi) + ) + + def values(self, multi=False): + """ + Get all values. + + Args: + multi(bool): + If True, all values will be returned. + If False, only the first value per key will be returned. + """ + return ( + v + for _, v in self.items(multi) + ) + + def items(self, multi=False): + """ + Get all (key, value) tuples. + + Args: + multi(bool): + If True, all (key, value) pairs will be returned + If False, only the first (key, value) pair per unique key will be returned. + """ + if multi: + return self.fields + else: + return super(_MultiDict, self).items() + + def to_dict(self): + """ + Get the MultiDict as a plain Python dict. + Keys with multiple values are returned as lists. + + Example: + + .. code-block:: python + + # Simple dict with duplicate values. + >>> d + MultiDictView[("name", "value"), ("a", "false"), ("a", "42")] + >>> d.to_dict() + { + "name": "value", + "a": ["false", "42"] + } + """ + d = {} + for key in self: + values = self.get_all(key) + if len(values) == 1: + d[key] = values[0] + else: + d[key] = values + return d + + def get_state(self): + return self.fields + + def set_state(self, state): + self.fields = tuple(tuple(x) for x in state) + + @classmethod + def from_state(cls, state): + return cls(tuple(x) for x in state) + + +class MultiDict(_MultiDict): + def __init__(self, fields=None): + super(MultiDict, self).__init__() + self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...] + + +@six.add_metaclass(ABCMeta) +class ImmutableMultiDict(MultiDict): + def _immutable(self, *_): + raise TypeError('{} objects are immutable'.format(self.__class__.__name__)) + + __delitem__ = set_all = insert = _immutable + + def with_delitem(self, key): + """ + Returns: + An updated ImmutableMultiDict. The original object will not be modified. + """ + ret = self.copy() + super(ImmutableMultiDict, ret).__delitem__(key) + return ret + + def with_set_all(self, key, values): + """ + Returns: + An updated ImmutableMultiDict. The original object will not be modified. + """ + ret = self.copy() + super(ImmutableMultiDict, ret).set_all(key, values) + return ret + + def with_insert(self, index, key, value): + """ + Returns: + An updated ImmutableMultiDict. The original object will not be modified. + """ + ret = self.copy() + super(ImmutableMultiDict, ret).insert(index, key, value) + return ret + + +class MultiDictView(_MultiDict): + """ + The MultiDictView provides the MultiDict interface over calculated data. + The view itself contains no state - data is retrieved from the parent on + request, and stored back to the parent on change. + """ + def __init__(self, getter, setter): + self._getter = getter + self._setter = setter + super(MultiDictView, self).__init__() + + @staticmethod + def _kconv(key): + # All request-attributes are case-sensitive. + return key + + @staticmethod + def _reduce_values(values): + # We just return the first element if + # multiple elements exist with the same key. + return values[0] + + @property + def fields(self): + return self._getter() + + @fields.setter + def fields(self, value): + return self._setter(value) |