From 44ac64aa7235362acbb96e0f12aa27534580e575 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 18 May 2016 18:46:42 -0700 Subject: add MultiDict This commit introduces MultiDict, a multi-dictionary similar to ODict, but with improved semantics (as in the Headers class). MultiDict fixes a few issues that were present in the Request/Response API. In particular, `request.cookies["foo"] = "bar"` has previously been a no-op, as the cookies property returned a mutable _copy_ of the cookies. --- netlib/multidict.py | 163 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100644 netlib/multidict.py (limited to 'netlib/multidict.py') diff --git a/netlib/multidict.py b/netlib/multidict.py new file mode 100644 index 00000000..a7158bc5 --- /dev/null +++ b/netlib/multidict.py @@ -0,0 +1,163 @@ +from __future__ import absolute_import, print_function, division + +from abc import ABCMeta, abstractmethod + +from typing import Tuple + +try: + from collections.abc import MutableMapping +except ImportError: # pragma: no cover + from collections import MutableMapping # Workaround for Python < 3.3 + +import six + +from .utils import Serializable + + +@six.add_metaclass(ABCMeta) +class MultiDict(MutableMapping, Serializable): + def __init__(self, fields=None): + + # it is important for us that .fields is immutable, so that we can easily + # detect changes to it. + self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...] + + for key, value in self.fields: + if not isinstance(key, bytes) or not isinstance(value, bytes): + raise TypeError("MultiDict fields must be bytes.") + + def __repr__(self): + fields = tuple( + repr(field) + for field in self.fields + ) + return "{cls}[{fields}]".format( + cls=type(self).__name__, + fields=", ".join(fields) + ) + + @staticmethod + @abstractmethod + def _reduce_values(values): + pass + + @staticmethod + @abstractmethod + def _kconv(v): + pass + + def __getitem__(self, key): + values = self.get_all(key) + if not values: + raise KeyError(key) + return self._reduce_values(values) + + def __setitem__(self, key, value): + self.set_all(key, [value]) + + def __delitem__(self, key): + if key not in self: + raise KeyError(key) + key = self._kconv(key) + self.fields = tuple( + field for field in self.fields + if key != self._kconv(field[0]) + ) + + def __iter__(self): + seen = set() + for key, _ in self.fields: + key_kconv = self._kconv(key) + if key_kconv not in seen: + seen.add(key_kconv) + yield key + + def __len__(self): + return len(set(self._kconv(key) for key, _ in self.fields)) + + def __eq__(self, other): + if isinstance(other, MultiDict): + return self.fields == other.fields + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def get_all(self, key): + """ + Return the list of items for a given key. + If that key is not in the MultiDict, + the return value will be an empty list. + """ + key = self._kconv(key) + return [ + value + for k, value in self.fields + if self._kconv(k) == key + ] + + def set_all(self, key, values): + """ + Remove the old values for a key and add new ones. + """ + key_kconv = self._kconv(key) + + new_fields = [] + for field in self.fields: + if self._kconv(field[0]) == key_kconv: + if values: + new_fields.append( + (key, values.pop(0)) + ) + else: + new_fields.append(field) + while values: + new_fields.append( + (key, values.pop(0)) + ) + self.fields = tuple(new_fields) + + def add(self, key, value): + self.insert(len(self.fields), key, value) + + def insert(self, index, key, value): + item = (key, value) + self.fields = self.fields[:index] + (item,) + self.fields[index:] + + def keys(self, multi=False): + return ( + k + for k, _ in self.items(multi) + ) + + def values(self, multi=False): + return ( + v + for _, v in self.items(multi) + ) + + def items(self, multi=False): + if multi: + return self.fields + else: + return super(MultiDict, self).items() + + def to_dict(self): + d = {} + for key in self: + values = self.get_all(key) + if len(values) == 1: + d[key] = values[0] + else: + d[key] = values + return d + + def get_state(self): + return self.fields + + def set_state(self, state): + self.fields = tuple(tuple(x) for x in state) + + @classmethod + def from_state(cls, state): + return cls(tuple(x) for x in state) -- cgit v1.2.3