diff options
| author | Aldo Cortesi <aldo@nullcube.com> | 2016-05-21 10:15:37 +1200 | 
|---|---|---|
| committer | Aldo Cortesi <aldo@nullcube.com> | 2016-05-21 10:15:37 +1200 | 
| commit | 96d8ec1ee33b076a472afc3053fdd8256559fcc3 (patch) | |
| tree | 933549b94c497b70eb6165f90bef191eebca4cc7 /netlib | |
| parent | 84144ca0c635f4a42c8ba8a13e779fe127a81d45 (diff) | |
| parent | b538138ead1dc8550f2d4e4a3f30ff70abb95f53 (diff) | |
| download | mitmproxy-96d8ec1ee33b076a472afc3053fdd8256559fcc3.tar.gz mitmproxy-96d8ec1ee33b076a472afc3053fdd8256559fcc3.tar.bz2 mitmproxy-96d8ec1ee33b076a472afc3053fdd8256559fcc3.zip | |
Merge branch 'multidict' of https://github.com/mhils/mitmproxy into mhils-multidict
Diffstat (limited to 'netlib')
| -rw-r--r-- | netlib/encoding.py | 1 | ||||
| -rw-r--r-- | netlib/http/__init__.py | 8 | ||||
| -rw-r--r-- | netlib/http/cookies.py | 60 | ||||
| -rw-r--r-- | netlib/http/headers.py | 140 | ||||
| -rw-r--r-- | netlib/http/http1/read.py | 4 | ||||
| -rw-r--r-- | netlib/http/http2/connections.py | 12 | ||||
| -rw-r--r-- | netlib/http/message.py | 70 | ||||
| -rw-r--r-- | netlib/http/request.py | 106 | ||||
| -rw-r--r-- | netlib/http/response.py | 41 | ||||
| -rw-r--r-- | netlib/multidict.py | 248 | ||||
| -rw-r--r-- | netlib/utils.py | 11 | 
11 files changed, 487 insertions, 214 deletions
| diff --git a/netlib/encoding.py b/netlib/encoding.py index 14479e00..98502451 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -5,7 +5,6 @@ from __future__ import absolute_import  from io import BytesIO  import gzip  import zlib -from .utils import always_byte_args  ENCODINGS = {"identity", "gzip", "deflate"} diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 917080f7..9fafa28f 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -2,13 +2,13 @@ from __future__ import absolute_import, print_function, division  from .request import Request  from .response import Response  from .headers import Headers -from .message import decoded -from . import http1, http2 +from .message import MultiDictView, decoded +from . import http1, http2, status_codes  __all__ = [      "Request",      "Response",      "Headers", -    "decoded", -    "http1", "http2", +    "MultiDictView", "decoded", +    "http1", "http2", "status_codes",  ] diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py index 4451f1da..88c76870 100644 --- a/netlib/http/cookies.py +++ b/netlib/http/cookies.py @@ -1,8 +1,8 @@ -from six.moves import http_cookies as Cookie +import collections  import re -import string  from email.utils import parsedate_tz, formatdate, mktime_tz +from netlib.multidict import ImmutableMultiDict  from .. import odict  """ @@ -157,42 +157,76 @@ def _parse_set_cookie_pairs(s):      return pairs +def parse_set_cookie_headers(headers): +    ret = [] +    for header in headers: +        v = parse_set_cookie_header(header) +        if v: +            name, value, attrs = v +            ret.append((name, SetCookie(value, attrs))) +    return ret + + +class CookieAttrs(ImmutableMultiDict): +    @staticmethod +    def _kconv(key): +        return key.lower() + +    @staticmethod +    def _reduce_values(values): +        # See the StickyCookieTest for a weird cookie that only makes sense +        # if we take the last part. +        return values[-1] + + +SetCookie = collections.namedtuple("SetCookie", ["value", "attrs"]) + +  def parse_set_cookie_header(line):      """          Parse a Set-Cookie header value          Returns a (name, value, attrs) tuple, or None, where attrs is an -        ODictCaseless set of attributes. No attempt is made to parse attribute +        CookieAttrs dict of attributes. No attempt is made to parse attribute          values - they are treated purely as strings.      """      pairs = _parse_set_cookie_pairs(line)      if pairs: -        return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) +        return pairs[0][0], pairs[0][1], CookieAttrs(tuple(x) for x in pairs[1:])  def format_set_cookie_header(name, value, attrs):      """          Formats a Set-Cookie header value.      """ -    pairs = [[name, value]] -    pairs.extend(attrs.lst) +    pairs = [(name, value)] +    pairs.extend( +        attrs.fields if hasattr(attrs, "fields") else attrs +    )      return _format_set_cookie_pairs(pairs) +def parse_cookie_headers(cookie_headers): +    cookie_list = [] +    for header in cookie_headers: +        cookie_list.extend(parse_cookie_header(header)) +    return cookie_list + +  def parse_cookie_header(line):      """          Parse a Cookie header value. -        Returns a (possibly empty) ODict object. +        Returns a list of (lhs, rhs) tuples.      """      pairs, off_ = _read_pairs(line) -    return odict.ODict(pairs) +    return pairs -def format_cookie_header(od): +def format_cookie_header(lst):      """          Formats a Cookie header value.      """ -    return _format_pairs(od.lst) +    return _format_pairs(lst)  def refresh_set_cookie_header(c, delta): @@ -209,10 +243,10 @@ def refresh_set_cookie_header(c, delta):          raise ValueError("Invalid Cookie")      if "expires" in attrs: -        e = parsedate_tz(attrs["expires"][-1]) +        e = parsedate_tz(attrs["expires"])          if e:              f = mktime_tz(e) + delta -            attrs["expires"] = [formatdate(f)] +            attrs = attrs.with_set_all("expires", [formatdate(f)])          else:              # This can happen when the expires tag is invalid.              # reddit.com sends a an expires tag like this: "Thu, 31 Dec @@ -220,7 +254,7 @@ def refresh_set_cookie_header(c, delta):              # strictly correct according to the cookie spec. Browsers              # appear to parse this tolerantly - maybe we should too.              # For now, we just ignore this. -            del attrs["expires"] +            attrs = attrs.with_delitem("expires")      ret = format_set_cookie_header(name, value, attrs)      if not ret: diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 72739f90..60d3f429 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -1,9 +1,3 @@ -""" - -Unicode Handling ----------------- -See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ -"""  from __future__ import absolute_import, print_function, division  import re @@ -13,23 +7,22 @@ try:  except ImportError:  # pragma: no cover      from collections import MutableMapping  # Workaround for Python < 3.3 -  import six +from ..multidict import MultiDict +from ..utils import always_bytes -from netlib.utils import always_byte_args, always_bytes, Serializable +# See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/  if six.PY2:  # pragma: no cover      _native = lambda x: x      _always_bytes = lambda x: x -    _always_byte_args = lambda x: x  else:      # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded.      _native = lambda x: x.decode("utf-8", "surrogateescape")      _always_bytes = lambda x: always_bytes(x, "utf-8", "surrogateescape") -    _always_byte_args = always_byte_args("utf-8", "surrogateescape") -class Headers(MutableMapping, Serializable): +class Headers(MultiDict):      """      Header class which allows both convenient access to individual headers as well as      direct access to the underlying raw data. Provides a full dictionary interface. @@ -49,11 +42,11 @@ class Headers(MutableMapping, Serializable):          >>> h["host"]          "example.com" -        # Headers can also be creatd from a list of raw (header_name, header_value) byte tuples +        # Headers can also be created from a list of raw (header_name, header_value) byte tuples          >>> h = Headers([ -            [b"Host",b"example.com"], -            [b"Accept",b"text/html"], -            [b"accept",b"application/xml"] +            (b"Host",b"example.com"), +            (b"Accept",b"text/html"), +            (b"accept",b"application/xml")          ])          # Multiple headers are folded into a single header as per RFC7230 @@ -77,7 +70,6 @@ class Headers(MutableMapping, Serializable):          For use with the "Set-Cookie" header, see :py:meth:`get_all`.      """ -    @_always_byte_args      def __init__(self, fields=None, **headers):          """          Args: @@ -89,19 +81,29 @@ class Headers(MutableMapping, Serializable):                  If ``**headers`` contains multiple keys that have equal ``.lower()`` s,                  the behavior is undefined.          """ -        self.fields = fields or [] +        super(Headers, self).__init__(fields) -        for name, value in self.fields: -            if not isinstance(name, bytes) or not isinstance(value, bytes): -                raise ValueError("Headers passed as fields must be bytes.") +        for key, value in self.fields: +            if not isinstance(key, bytes) or not isinstance(value, bytes): +                raise TypeError("Header fields must be bytes.")          # content_type -> content-type          headers = { -            _always_bytes(name).replace(b"_", b"-"): value +            _always_bytes(name).replace(b"_", b"-"): _always_bytes(value)              for name, value in six.iteritems(headers)              }          self.update(headers) +    @staticmethod +    def _reduce_values(values): +        # Headers can be folded +        return ", ".join(values) + +    @staticmethod +    def _kconv(key): +        # Headers are case-insensitive +        return key.lower() +      def __bytes__(self):          if self.fields:              return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" @@ -111,98 +113,40 @@ class Headers(MutableMapping, Serializable):      if six.PY2:  # pragma: no cover          __str__ = __bytes__ -    @_always_byte_args -    def __getitem__(self, name): -        values = self.get_all(name) -        if not values: -            raise KeyError(name) -        return ", ".join(values) - -    @_always_byte_args -    def __setitem__(self, name, value): -        idx = self._index(name) - -        # To please the human eye, we insert at the same position the first existing header occured. -        if idx is not None: -            del self[name] -            self.fields.insert(idx, [name, value]) -        else: -            self.fields.append([name, value]) - -    @_always_byte_args -    def __delitem__(self, name): -        if name not in self: -            raise KeyError(name) -        name = name.lower() -        self.fields = [ -            field for field in self.fields -            if name != field[0].lower() -        ] +    def __delitem__(self, key): +        key = _always_bytes(key) +        super(Headers, self).__delitem__(key)      def __iter__(self): -        seen = set() -        for name, _ in self.fields: -            name_lower = name.lower() -            if name_lower not in seen: -                seen.add(name_lower) -                yield _native(name) - -    def __len__(self): -        return len(set(name.lower() for name, _ in self.fields)) - -    # __hash__ = object.__hash__ - -    def _index(self, name): -        name = name.lower() -        for i, field in enumerate(self.fields): -            if field[0].lower() == name: -                return i -        return None - -    def __eq__(self, other): -        if isinstance(other, Headers): -            return self.fields == other.fields -        return False - -    def __ne__(self, other): -        return not self.__eq__(other) - -    @_always_byte_args +        for x in super(Headers, self).__iter__(): +            yield _native(x) +      def get_all(self, name):          """          Like :py:meth:`get`, but does not fold multiple headers into a single one.          This is useful for Set-Cookie headers, which do not support folding. -          See also: https://tools.ietf.org/html/rfc7230#section-3.2.2          """ -        name_lower = name.lower() -        values = [_native(value) for n, value in self.fields if n.lower() == name_lower] -        return values +        name = _always_bytes(name) +        return [ +            _native(x) for x in +            super(Headers, self).get_all(name) +        ] -    @_always_byte_args      def set_all(self, name, values):          """          Explicitly set multiple headers for the given key.          See: :py:meth:`get_all`          """ -        values = map(_always_bytes, values)  # _always_byte_args does not fix lists -        if name in self: -            del self[name] -        self.fields.extend( -            [name, value] for value in values -        ) - -    def get_state(self): -        return tuple(tuple(field) for field in self.fields) - -    def set_state(self, state): -        self.fields = [list(field) for field in state] +        name = _always_bytes(name) +        values = [_always_bytes(x) for x in values] +        return super(Headers, self).set_all(name, values) -    @classmethod -    def from_state(cls, state): -        return cls([list(field) for field in state]) +    def insert(self, index, key, value): +        key = _always_bytes(key) +        value = _always_bytes(value) +        super(Headers, self).insert(index, key, value) -    @_always_byte_args      def replace(self, pattern, repl, flags=0):          """          Replaces a regular expression pattern with repl in each "name: value" @@ -211,6 +155,8 @@ class Headers(MutableMapping, Serializable):          Returns:              The number of replacements made.          """ +        pattern = _always_bytes(pattern) +        repl = _always_bytes(repl)          pattern = re.compile(pattern, flags)          replacements = 0 diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 6e3a1b93..d30976bd 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -316,14 +316,14 @@ def _read_headers(rfile):              if not ret:                  raise HttpSyntaxException("Invalid headers")              # continued header -            ret[-1][1] = ret[-1][1] + b'\r\n ' + line.strip() +            ret[-1] = (ret[-1][0], ret[-1][1] + b'\r\n ' + line.strip())          else:              try:                  name, value = line.split(b":", 1)                  value = value.strip()                  if not name:                      raise ValueError() -                ret.append([name, value]) +                ret.append((name, value))              except ValueError:                  raise HttpSyntaxException("Invalid headers")      return Headers(ret) diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index f900b67c..6643b6b9 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -201,13 +201,13 @@ class HTTP2Protocol(object):          headers = request.headers.copy()          if ':authority' not in headers: -            headers.fields.insert(0, (b':authority', authority.encode('ascii'))) +            headers.insert(0, b':authority', authority.encode('ascii'))          if ':scheme' not in headers: -            headers.fields.insert(0, (b':scheme', request.scheme.encode('ascii'))) +            headers.insert(0, b':scheme', request.scheme.encode('ascii'))          if ':path' not in headers: -            headers.fields.insert(0, (b':path', request.path.encode('ascii'))) +            headers.insert(0, b':path', request.path.encode('ascii'))          if ':method' not in headers: -            headers.fields.insert(0, (b':method', request.method.encode('ascii'))) +            headers.insert(0, b':method', request.method.encode('ascii'))          if hasattr(request, 'stream_id'):              stream_id = request.stream_id @@ -224,7 +224,7 @@ class HTTP2Protocol(object):          headers = response.headers.copy()          if ':status' not in headers: -            headers.fields.insert(0, (b':status', str(response.status_code).encode('ascii'))) +            headers.insert(0, b':status', str(response.status_code).encode('ascii'))          if hasattr(response, 'stream_id'):              stream_id = response.stream_id @@ -420,7 +420,7 @@ class HTTP2Protocol(object):                  self._handle_unexpected_frame(frm)          headers = Headers( -            [[k.encode('ascii'), v.encode('ascii')] for k, v in self.decoder.decode(header_blocks)] +            (k.encode('ascii'), v.encode('ascii')) for k, v in self.decoder.decode(header_blocks)          )          return stream_id, headers, body diff --git a/netlib/http/message.py b/netlib/http/message.py index da9681a0..db4054b1 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -4,6 +4,7 @@ import warnings  import six +from ..multidict import MultiDict  from .headers import Headers  from .. import encoding, utils @@ -235,3 +236,72 @@ class decoded(object):      def __exit__(self, type, value, tb):          if self.ce:              self.message.encode(self.ce) + + +class MultiDictView(MultiDict): +    """ +    Some parts in HTTP (Cookies, URL query strings, ...) require a specific data structure: A MultiDict. +    It behaves mostly like an ordered dict but it can have several values for the same key. + +    The MultiDictView provides a MultiDict *view* on an :py:class:`Request` or :py:class:`Response`. +    That is, it represents a part of the request as a MultiDict, but doesn't contain state/data themselves. + +    For example, ``request.cookies`` provides a view on the ``Cookie: ...`` header. +    Any change to ``request.cookies`` will also modify the ``Cookie`` header. +    Any change to the ``Cookie`` header will also modify ``request.cookies``. + +    Example: + +    .. code-block:: python + +        # Cookies are represented as a MultiDict. +        >>> request.cookies +        MultiDictView[("name", "value"), ("a", "false"), ("a", "42")] + +        # MultiDicts mostly behave like a normal dict. +        >>> request.cookies["name"] +        "value" + +        # If there is more than one value, only the first value is returned. +        >>> request.cookies["a"] +        "false" + +        # `.get_all(key)` returns a list of all values. +        >>> request.cookies.get_all("a") +        ["false", "42"] + +        # Changes to the headers are immediately reflected in the cookies. +        >>> request.cookies +        MultiDictView[("name", "value"), ...] +        >>> del request.headers["Cookie"] +        >>> request.cookies +        MultiDictView[]  # empty now +    """ + +    def __init__(self, attr, message): +        if False:  # pragma: no cover +            # We do not want to call the parent constructor here as that +            # would cause an unnecessary parse/unparse pass. +            # This is here to silence linters. Message +            super(MultiDictView, self).__init__(None) +        self._attr = attr +        self._message = message  # type: Message + +    @staticmethod +    def _kconv(key): +        # All request-attributes are case-sensitive. +        return key + +    @staticmethod +    def _reduce_values(values): +        # We just return the first element if +        # multiple elements exist with the same key. +        return values[0] + +    @property +    def fields(self): +        return getattr(self._message, "_" + self._attr) + +    @fields.setter +    def fields(self, value): +        setattr(self._message, self._attr, value) diff --git a/netlib/http/request.py b/netlib/http/request.py index a42150ff..ae28084b 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -11,7 +11,7 @@ from netlib.http import cookies  from netlib.odict import ODict  from .. import encoding  from .headers import Headers -from .message import Message, _native, _always_bytes, MessageData +from .message import Message, _native, _always_bytes, MessageData, MultiDictView  # This regex extracts & splits the host header into host and port.  # Handles the edge case of IPv6 addresses containing colons. @@ -224,45 +224,54 @@ class Request(Message):      @property      def query(self): +        # type: () -> MultiDictView          """ -        The request query string as an :py:class:`ODict` object. -        None, if there is no query. +        The request query string as an :py:class:`MultiDictView` object.          """ +        return MultiDictView("query", self) + +    @property +    def _query(self):          _, _, _, _, query, _ = urllib.parse.urlparse(self.url) -        if query: -            return ODict(utils.urldecode(query)) -        return None +        return tuple(utils.urldecode(query))      @query.setter -    def query(self, odict): -        query = utils.urlencode(odict.lst) +    def query(self, value): +        query = utils.urlencode(value)          scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url)          _, _, _, self.path = utils.parse_url(                  urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]))      @property      def cookies(self): +        # type: () -> MultiDictView          """          The request cookies. -        An empty :py:class:`ODict` object if the cookie monster ate them all. + +        An empty :py:class:`MultiDictView` object if the cookie monster ate them all.          """ -        ret = ODict() -        for i in self.headers.get_all("Cookie"): -            ret.extend(cookies.parse_cookie_header(i)) -        return ret +        return MultiDictView("cookies", self) + +    @property +    def _cookies(self): +        h = self.headers.get_all("Cookie") +        return tuple(cookies.parse_cookie_headers(h))      @cookies.setter -    def cookies(self, odict): -        self.headers["cookie"] = cookies.format_cookie_header(odict) +    def cookies(self, value): +        self.headers["cookie"] = cookies.format_cookie_header(value)      @property      def path_components(self):          """ -        The URL's path components as a list of strings. +        The URL's path components as a tuple of strings.          Components are unquoted.          """          _, _, path, _, _, _ = urllib.parse.urlparse(self.url) -        return [urllib.parse.unquote(i) for i in path.split("/") if i] +        # This needs to be a tuple so that it's immutable. +        # Otherwise, this would fail silently: +        #   request.path_components.append("foo") +        return tuple(urllib.parse.unquote(i) for i in path.split("/") if i)      @path_components.setter      def path_components(self, components): @@ -309,64 +318,43 @@ class Request(Message):      @property      def urlencoded_form(self):          """ -        The URL-encoded form data as an :py:class:`ODict` object. -        None if there is no data or the content-type indicates non-form data. +        The URL-encoded form data as an :py:class:`MultiDictView` object. +        An empty MultiDictView if the content-type indicates non-form data +        or the content could not be parsed.          """ +        return MultiDictView("urlencoded_form", self) + +    @property +    def _urlencoded_form(self):          is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() -        if self.content and is_valid_content_type: -            return ODict(utils.urldecode(self.content)) -        return None +        if is_valid_content_type: +            return tuple(utils.urldecode(self.content)) +        return ()      @urlencoded_form.setter -    def urlencoded_form(self, odict): +    def urlencoded_form(self, value):          """          Sets the body to the URL-encoded form data, and adds the appropriate content-type header.          This will overwrite the existing content if there is one.          """          self.headers["content-type"] = "application/x-www-form-urlencoded" -        self.content = utils.urlencode(odict.lst) +        self.content = utils.urlencode(value)      @property      def multipart_form(self):          """ -        The multipart form data as an :py:class:`ODict` object. -        None if there is no data or the content-type indicates non-form data. +        The multipart form data as an :py:class:`MultipartFormDict` object. +        None if the content-type indicates non-form data.          """ +        return MultiDictView("multipart_form", self) + +    @property +    def _multipart_form(self):          is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() -        if self.content and is_valid_content_type: -            return ODict(utils.multipartdecode(self.headers,self.content)) -        return None +        if is_valid_content_type: +            return utils.multipartdecode(self.headers, self.content) +        return ()      @multipart_form.setter      def multipart_form(self, value):          raise NotImplementedError() - -    # Legacy - -    def get_query(self):  # pragma: no cover -        warnings.warn(".get_query is deprecated, use .query instead.", DeprecationWarning) -        return self.query or ODict([]) - -    def set_query(self, odict):  # pragma: no cover -        warnings.warn(".set_query is deprecated, use .query instead.", DeprecationWarning) -        self.query = odict - -    def get_path_components(self):  # pragma: no cover -        warnings.warn(".get_path_components is deprecated, use .path_components instead.", DeprecationWarning) -        return self.path_components - -    def set_path_components(self, lst):  # pragma: no cover -        warnings.warn(".set_path_components is deprecated, use .path_components instead.", DeprecationWarning) -        self.path_components = lst - -    def get_form_urlencoded(self):  # pragma: no cover -        warnings.warn(".get_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) -        return self.urlencoded_form or ODict([]) - -    def set_form_urlencoded(self, odict):  # pragma: no cover -        warnings.warn(".set_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) -        self.urlencoded_form = odict - -    def get_form_multipart(self):  # pragma: no cover -        warnings.warn(".get_form_multipart is deprecated, use .multipart_form instead.", DeprecationWarning) -        return self.multipart_form or ODict([]) diff --git a/netlib/http/response.py b/netlib/http/response.py index 2f06149e..6d56fc1f 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -1,14 +1,12 @@  from __future__ import absolute_import, print_function, division -import warnings  from email.utils import parsedate_tz, formatdate, mktime_tz  import time  from . import cookies  from .headers import Headers -from .message import Message, _native, _always_bytes, MessageData +from .message import Message, _native, _always_bytes, MessageData, MultiDictView  from .. import utils -from ..odict import ODict  class ResponseData(MessageData): @@ -72,29 +70,30 @@ class Response(Message):      @property      def cookies(self): +        # type: () -> MultiDictView          """ -        Get the contents of all Set-Cookie headers. +        The response cookies. A possibly empty :py:class:`MultiDictView`, where the keys are +        cookie name strings, and values are (value, attr) tuples. Value is a string, and attr is +        an ODictCaseless containing cookie attributes. Within attrs, unary attributes (e.g. HTTPOnly) +        are indicated by a Null value. -        A possibly empty :py:class:`ODict`, where keys are cookie name strings, -        and values are [value, attr] lists. Value is a string, and attr is -        an ODictCaseless containing cookie attributes. Within attrs, unary -        attributes (e.g. HTTPOnly) are indicated by a Null value. +        Caveats: +            Updating the attr          """ -        ret = [] -        for header in self.headers.get_all("set-cookie"): -            v = cookies.parse_set_cookie_header(header) -            if v: -                name, value, attrs = v -                ret.append([name, [value, attrs]]) -        return ODict(ret) +        return MultiDictView("cookies", self) + +    @property +    def _cookies(self): +        h = self.headers.get_all("set-cookie") +        return tuple(cookies.parse_set_cookie_headers(h))      @cookies.setter -    def cookies(self, odict): -        values = [] -        for i in odict.lst: -            header = cookies.format_set_cookie_header(i[0], i[1][0], i[1][1]) -            values.append(header) -        self.headers.set_all("set-cookie", values) +    def cookies(self, all_cookies): +        cookie_headers = [] +        for k, v in all_cookies: +            header = cookies.format_set_cookie_header(k, v[0], v[1]) +            cookie_headers.append(header) +        self.headers.set_all("set-cookie", cookie_headers)      def refresh(self, now=None):          """ diff --git a/netlib/multidict.py b/netlib/multidict.py new file mode 100644 index 00000000..a359d46b --- /dev/null +++ b/netlib/multidict.py @@ -0,0 +1,248 @@ +from __future__ import absolute_import, print_function, division + +from abc import ABCMeta, abstractmethod + +from typing import Tuple, TypeVar + +try: +    from collections.abc import MutableMapping +except ImportError:  # pragma: no cover +    from collections import MutableMapping  # Workaround for Python < 3.3 + +import six + +from .utils import Serializable + + +@six.add_metaclass(ABCMeta) +class MultiDict(MutableMapping, Serializable): +    def __init__(self, fields=None): + +        # it is important for us that .fields is immutable, so that we can easily +        # detect changes to it. +        self.fields = tuple(fields) if fields else tuple()  # type: Tuple[Tuple[bytes, bytes], ...] + +    def __repr__(self): +        fields = tuple( +            repr(field) +            for field in self.fields +        ) +        return "{cls}[{fields}]".format( +            cls=type(self).__name__, +            fields=", ".join(fields) +        ) + +    @staticmethod +    @abstractmethod +    def _reduce_values(values): +        """ +        If a user accesses multidict["foo"], this method +        reduces all values for "foo" to a single value that is returned. +        For example, HTTP headers are folded, whereas we will just take +        the first cookie we found with that name. +        """ + +    @staticmethod +    @abstractmethod +    def _kconv(key): +        """ +        This method converts a key to its canonical representation. +        For example, HTTP headers are case-insensitive, so this method returns key.lower(). +        """ + +    def __getitem__(self, key): +        values = self.get_all(key) +        if not values: +            raise KeyError(key) +        return self._reduce_values(values) + +    def __setitem__(self, key, value): +        self.set_all(key, [value]) + +    def __delitem__(self, key): +        if key not in self: +            raise KeyError(key) +        key = self._kconv(key) +        self.fields = tuple( +            field for field in self.fields +            if key != self._kconv(field[0]) +        ) + +    def __iter__(self): +        seen = set() +        for key, _ in self.fields: +            key_kconv = self._kconv(key) +            if key_kconv not in seen: +                seen.add(key_kconv) +                yield key + +    def __len__(self): +        return len(set(self._kconv(key) for key, _ in self.fields)) + +    def __eq__(self, other): +        if isinstance(other, MultiDict): +            return self.fields == other.fields +        return False + +    def __ne__(self, other): +        return not self.__eq__(other) + +    def get_all(self, key): +        """ +        Return the list of all values for a given key. +        If that key is not in the MultiDict, the return value will be an empty list. +        """ +        key = self._kconv(key) +        return [ +            value +            for k, value in self.fields +            if self._kconv(k) == key +            ] + +    def set_all(self, key, values): +        """ +        Remove the old values for a key and add new ones. +        """ +        key_kconv = self._kconv(key) + +        new_fields = [] +        for field in self.fields: +            if self._kconv(field[0]) == key_kconv: +                if values: +                    new_fields.append( +                        (key, values.pop(0)) +                    ) +            else: +                new_fields.append(field) +        while values: +            new_fields.append( +                (key, values.pop(0)) +            ) +        self.fields = tuple(new_fields) + +    def add(self, key, value): +        """ +        Add an additional value for the given key at the bottom. +        """ +        self.insert(len(self.fields), key, value) + +    def insert(self, index, key, value): +        """ +        Insert an additional value for the given key at the specified position. +        """ +        item = (key, value) +        self.fields = self.fields[:index] + (item,) + self.fields[index:] + +    def keys(self, multi=False): +        """ +        Get all keys. + +        Args: +            multi(bool): +                If True, one key per value will be returned. +                If False, duplicate keys will only be returned once. +        """ +        return ( +            k +            for k, _ in self.items(multi) +        ) + +    def values(self, multi=False): +        """ +        Get all values. + +        Args: +            multi(bool): +                If True, all values will be returned. +                If False, only the first value per key will be returned. +        """ +        return ( +            v +            for _, v in self.items(multi) +        ) + +    def items(self, multi=False): +        """ +        Get all (key, value) tuples. + +        Args: +            multi(bool): +                If True, all (key, value) pairs will be returned +                If False, only the first (key, value) pair per unique key will be returned. +        """ +        if multi: +            return self.fields +        else: +            return super(MultiDict, self).items() + +    def to_dict(self): +        """ +        Get the MultiDict as a plain Python dict. +        Keys with multiple values are returned as lists. + +        Example: + +        .. code-block:: python + +            # Simple dict with duplicate values. +            >>> d +            MultiDictView[("name", "value"), ("a", "false"), ("a", "42")] +            >>> d.to_dict() +            { +                "name": "value", +                "a": ["false", "42"] +            } +        """ +        d = {} +        for key in self: +            values = self.get_all(key) +            if len(values) == 1: +                d[key] = values[0] +            else: +                d[key] = values +        return d + +    def get_state(self): +        return self.fields + +    def set_state(self, state): +        self.fields = tuple(tuple(x) for x in state) + +    @classmethod +    def from_state(cls, state): +        return cls(tuple(x) for x in state) + + +@six.add_metaclass(ABCMeta) +class ImmutableMultiDict(MultiDict): +    def _immutable(self, *_): +        raise TypeError('{} objects are immutable'.format(self.__class__.__name__)) + +    __delitem__ = set_all = insert = _immutable + +    def with_delitem(self, key): +        """ +        Returns: +            An updated ImmutableMultiDict. The original object will not be modified. +        """ +        ret = self.copy() +        super(ImmutableMultiDict, ret).__delitem__(key) +        return ret + +    def with_set_all(self, key, values): +        """ +        Returns: +            An updated ImmutableMultiDict. The original object will not be modified. +        """ +        ret = self.copy() +        super(ImmutableMultiDict, ret).set_all(key, values) +        return ret + +    def with_insert(self, index, key, value): +        """ +        Returns: +            An updated ImmutableMultiDict. The original object will not be modified. +        """ +        ret = self.copy() +        super(ImmutableMultiDict, ret).insert(index, key, value) +        return ret diff --git a/netlib/utils.py b/netlib/utils.py index be2701a0..7499f71f 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -51,17 +51,6 @@ def always_bytes(unicode_or_bytes, *encode_args):      return unicode_or_bytes -def always_byte_args(*encode_args): -    """Decorator that transparently encodes all arguments passed as unicode""" -    def decorator(fun): -        def _fun(*args, **kwargs): -            args = [always_bytes(arg, *encode_args) for arg in args] -            kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} -            return fun(*args, **kwargs) -        return _fun -    return decorator - -  def native(s, *encoding_opts):      """      Convert :py:class:`bytes` or :py:class:`unicode` to the native | 
