aboutsummaryrefslogtreecommitdiffstats
path: root/netlib
diff options
context:
space:
mode:
author Maximilian Hils <git@maximilianhils.com> 2016-05-18 18:46:42 -0700
committer Maximilian Hils <git@maximilianhils.com> 2016-05-18 18:46:42 -0700
commit 44ac64aa7235362acbb96e0f12aa27534580e575 (patch)
tree c03b8c3519c273a4f42b60cb2bce8cc0dd524925 /netlib
parent 4c3fb8f5097fad2c5de96104dae3f8026b0b4666 (diff)
download mitmproxy-44ac64aa7235362acbb96e0f12aa27534580e575.tar.gz
mitmproxy-44ac64aa7235362acbb96e0f12aa27534580e575.tar.bz2
mitmproxy-44ac64aa7235362acbb96e0f12aa27534580e575.zip
add MultiDict
This commit introduces MultiDict, a multi-dictionary similar to ODict, but with improved semantics (as in the Headers class). MultiDict fixes a few issues that were present in the Request/Response API. In particular, `request.cookies["foo"] = "bar"` was previously a no-op, because the cookies property returned a mutable _copy_ of the cookies.
Diffstat (limited to 'netlib')
-rw-r--r-- netlib/encoding.py 1
-rw-r--r-- netlib/http/cookies.py 17
-rw-r--r-- netlib/http/headers.py 132
-rw-r--r-- netlib/http/http1/read.py 4
-rw-r--r-- netlib/http/http2/connections.py 12
-rw-r--r-- netlib/http/message.py 35
-rw-r--r-- netlib/http/request.py 71
-rw-r--r-- netlib/http/response.py 2
-rw-r--r-- netlib/multidict.py 163
-rw-r--r-- netlib/utils.py 11
10 files changed, 300 insertions, 148 deletions
diff --git a/netlib/encoding.py b/netlib/encoding.py
index 14479e00..98502451 100644
--- a/netlib/encoding.py
+++ b/netlib/encoding.py
@@ -5,7 +5,6 @@ from __future__ import absolute_import
from io import BytesIO
import gzip
import zlib
-from .utils import always_byte_args
ENCODINGS = {"identity", "gzip", "deflate"}
diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py
index 4451f1da..fd531146 100644
--- a/netlib/http/cookies.py
+++ b/netlib/http/cookies.py
@@ -1,6 +1,4 @@
-from six.moves import http_cookies as Cookie
import re
-import string
from email.utils import parsedate_tz, formatdate, mktime_tz
from .. import odict
@@ -179,20 +177,27 @@ def format_set_cookie_header(name, value, attrs):
return _format_set_cookie_pairs(pairs)
+def parse_cookie_headers(cookie_headers):
+ cookie_list = []
+ for header in cookie_headers:
+ cookie_list.extend(parse_cookie_header(header))
+ return cookie_list
+
+
def parse_cookie_header(line):
"""
Parse a Cookie header value.
- Returns a (possibly empty) ODict object.
+ Returns a list of (lhs, rhs) tuples.
"""
pairs, off_ = _read_pairs(line)
- return odict.ODict(pairs)
+ return pairs
-def format_cookie_header(od):
+def format_cookie_header(lst):
"""
Formats a Cookie header value.
"""
- return _format_pairs(od.lst)
+ return _format_pairs(lst)
def refresh_set_cookie_header(c, delta):
diff --git a/netlib/http/headers.py b/netlib/http/headers.py
index 72739f90..7e39c371 100644
--- a/netlib/http/headers.py
+++ b/netlib/http/headers.py
@@ -1,9 +1,3 @@
-"""
-
-Unicode Handling
-----------------
-See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/
-"""
from __future__ import absolute_import, print_function, division
import re
@@ -13,23 +7,22 @@ try:
except ImportError: # pragma: no cover
from collections import MutableMapping # Workaround for Python < 3.3
-
import six
+from ..multidict import MultiDict
+from ..utils import always_bytes
-from netlib.utils import always_byte_args, always_bytes, Serializable
+# See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/
if six.PY2: # pragma: no cover
_native = lambda x: x
_always_bytes = lambda x: x
- _always_byte_args = lambda x: x
else:
# While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded.
_native = lambda x: x.decode("utf-8", "surrogateescape")
_always_bytes = lambda x: always_bytes(x, "utf-8", "surrogateescape")
- _always_byte_args = always_byte_args("utf-8", "surrogateescape")
-class Headers(MutableMapping, Serializable):
+class Headers(MultiDict):
"""
Header class which allows both convenient access to individual headers as well as
direct access to the underlying raw data. Provides a full dictionary interface.
@@ -49,7 +42,7 @@ class Headers(MutableMapping, Serializable):
>>> h["host"]
"example.com"
- # Headers can also be creatd from a list of raw (header_name, header_value) byte tuples
+ # Headers can also be created from a list of raw (header_name, header_value) byte tuples
>>> h = Headers([
[b"Host",b"example.com"],
[b"Accept",b"text/html"],
@@ -77,7 +70,6 @@ class Headers(MutableMapping, Serializable):
For use with the "Set-Cookie" header, see :py:meth:`get_all`.
"""
- @_always_byte_args
def __init__(self, fields=None, **headers):
"""
Args:
@@ -89,19 +81,25 @@ class Headers(MutableMapping, Serializable):
If ``**headers`` contains multiple keys that have equal ``.lower()`` s,
the behavior is undefined.
"""
- self.fields = fields or []
-
- for name, value in self.fields:
- if not isinstance(name, bytes) or not isinstance(value, bytes):
- raise ValueError("Headers passed as fields must be bytes.")
+ super(Headers, self).__init__(fields)
# content_type -> content-type
headers = {
- _always_bytes(name).replace(b"_", b"-"): value
+ _always_bytes(name).replace(b"_", b"-"): _always_bytes(value)
for name, value in six.iteritems(headers)
}
self.update(headers)
+ @staticmethod
+ def _reduce_values(values):
+ # Headers can be folded
+ return ", ".join(values)
+
+ @staticmethod
+ def _kconv(key):
+ # Headers are case-insensitive
+ return key.lower()
+
def __bytes__(self):
if self.fields:
return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n"
@@ -111,98 +109,40 @@ class Headers(MutableMapping, Serializable):
if six.PY2: # pragma: no cover
__str__ = __bytes__
- @_always_byte_args
- def __getitem__(self, name):
- values = self.get_all(name)
- if not values:
- raise KeyError(name)
- return ", ".join(values)
-
- @_always_byte_args
- def __setitem__(self, name, value):
- idx = self._index(name)
-
- # To please the human eye, we insert at the same position the first existing header occured.
- if idx is not None:
- del self[name]
- self.fields.insert(idx, [name, value])
- else:
- self.fields.append([name, value])
-
- @_always_byte_args
- def __delitem__(self, name):
- if name not in self:
- raise KeyError(name)
- name = name.lower()
- self.fields = [
- field for field in self.fields
- if name != field[0].lower()
- ]
+ def __delitem__(self, key):
+ key = _always_bytes(key)
+ super(Headers, self).__delitem__(key)
def __iter__(self):
- seen = set()
- for name, _ in self.fields:
- name_lower = name.lower()
- if name_lower not in seen:
- seen.add(name_lower)
- yield _native(name)
-
- def __len__(self):
- return len(set(name.lower() for name, _ in self.fields))
-
- # __hash__ = object.__hash__
-
- def _index(self, name):
- name = name.lower()
- for i, field in enumerate(self.fields):
- if field[0].lower() == name:
- return i
- return None
-
- def __eq__(self, other):
- if isinstance(other, Headers):
- return self.fields == other.fields
- return False
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- @_always_byte_args
+ for x in super(Headers, self).__iter__():
+ yield _native(x)
+
def get_all(self, name):
"""
Like :py:meth:`get`, but does not fold multiple headers into a single one.
This is useful for Set-Cookie headers, which do not support folding.
-
See also: https://tools.ietf.org/html/rfc7230#section-3.2.2
"""
- name_lower = name.lower()
- values = [_native(value) for n, value in self.fields if n.lower() == name_lower]
- return values
+ name = _always_bytes(name)
+ return [
+ _native(x) for x in
+ super(Headers, self).get_all(name)
+ ]
- @_always_byte_args
def set_all(self, name, values):
"""
Explicitly set multiple headers for the given key.
See: :py:meth:`get_all`
"""
- values = map(_always_bytes, values) # _always_byte_args does not fix lists
- if name in self:
- del self[name]
- self.fields.extend(
- [name, value] for value in values
- )
-
- def get_state(self):
- return tuple(tuple(field) for field in self.fields)
-
- def set_state(self, state):
- self.fields = [list(field) for field in state]
+ name = _always_bytes(name)
+ values = [_always_bytes(x) for x in values]
+ return super(Headers, self).set_all(name, values)
- @classmethod
- def from_state(cls, state):
- return cls([list(field) for field in state])
+ def insert(self, index, key, value):
+ key = _always_bytes(key)
+ value = _always_bytes(value)
+ super(Headers, self).insert(index, key, value)
- @_always_byte_args
def replace(self, pattern, repl, flags=0):
"""
Replaces a regular expression pattern with repl in each "name: value"
@@ -211,6 +151,8 @@ class Headers(MutableMapping, Serializable):
Returns:
The number of replacements made.
"""
+ pattern = _always_bytes(pattern)
+ repl = _always_bytes(repl)
pattern = re.compile(pattern, flags)
replacements = 0
diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py
index 6e3a1b93..d30976bd 100644
--- a/netlib/http/http1/read.py
+++ b/netlib/http/http1/read.py
@@ -316,14 +316,14 @@ def _read_headers(rfile):
if not ret:
raise HttpSyntaxException("Invalid headers")
# continued header
- ret[-1][1] = ret[-1][1] + b'\r\n ' + line.strip()
+ ret[-1] = (ret[-1][0], ret[-1][1] + b'\r\n ' + line.strip())
else:
try:
name, value = line.split(b":", 1)
value = value.strip()
if not name:
raise ValueError()
- ret.append([name, value])
+ ret.append((name, value))
except ValueError:
raise HttpSyntaxException("Invalid headers")
return Headers(ret)
diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py
index f900b67c..6643b6b9 100644
--- a/netlib/http/http2/connections.py
+++ b/netlib/http/http2/connections.py
@@ -201,13 +201,13 @@ class HTTP2Protocol(object):
headers = request.headers.copy()
if ':authority' not in headers:
- headers.fields.insert(0, (b':authority', authority.encode('ascii')))
+ headers.insert(0, b':authority', authority.encode('ascii'))
if ':scheme' not in headers:
- headers.fields.insert(0, (b':scheme', request.scheme.encode('ascii')))
+ headers.insert(0, b':scheme', request.scheme.encode('ascii'))
if ':path' not in headers:
- headers.fields.insert(0, (b':path', request.path.encode('ascii')))
+ headers.insert(0, b':path', request.path.encode('ascii'))
if ':method' not in headers:
- headers.fields.insert(0, (b':method', request.method.encode('ascii')))
+ headers.insert(0, b':method', request.method.encode('ascii'))
if hasattr(request, 'stream_id'):
stream_id = request.stream_id
@@ -224,7 +224,7 @@ class HTTP2Protocol(object):
headers = response.headers.copy()
if ':status' not in headers:
- headers.fields.insert(0, (b':status', str(response.status_code).encode('ascii')))
+ headers.insert(0, b':status', str(response.status_code).encode('ascii'))
if hasattr(response, 'stream_id'):
stream_id = response.stream_id
@@ -420,7 +420,7 @@ class HTTP2Protocol(object):
self._handle_unexpected_frame(frm)
headers = Headers(
- [[k.encode('ascii'), v.encode('ascii')] for k, v in self.decoder.decode(header_blocks)]
+ (k.encode('ascii'), v.encode('ascii')) for k, v in self.decoder.decode(header_blocks)
)
return stream_id, headers, body
diff --git a/netlib/http/message.py b/netlib/http/message.py
index da9681a0..262ef3e1 100644
--- a/netlib/http/message.py
+++ b/netlib/http/message.py
@@ -4,6 +4,7 @@ import warnings
import six
+from ..multidict import MultiDict
from .headers import Headers
from .. import encoding, utils
@@ -235,3 +236,37 @@ class decoded(object):
def __exit__(self, type, value, tb):
if self.ce:
self.message.encode(self.ce)
+
+
+class MessageMultiDict(MultiDict):
+ """
+ A MultiDict that provides a proxy view to the underlying message.
+ """
+
+ def __init__(self, attr, message):
+ if False:
+ # We do not want to call the parent constructor here as that
+ # would cause an unnecessary parse/unparse pass.
+ # This is here to silence linters. Message
+ super(MessageMultiDict, self).__init__(None)
+ self._attr = attr
+ self._message = message # type: Message
+
+ @staticmethod
+ def _kconv(key):
+ # All request-attributes are case-sensitive.
+ return key
+
+ @staticmethod
+ def _reduce_values(values):
+ # We just return the first element if
+ # multiple elements exist with the same key.
+ return values[0]
+
+ @property
+ def fields(self):
+ return getattr(self._message, "_" + self._attr)
+
+ @fields.setter
+ def fields(self, value):
+ setattr(self._message, self._attr, value)
diff --git a/netlib/http/request.py b/netlib/http/request.py
index a42150ff..26ec12cf 100644
--- a/netlib/http/request.py
+++ b/netlib/http/request.py
@@ -11,7 +11,7 @@ from netlib.http import cookies
from netlib.odict import ODict
from .. import encoding
from .headers import Headers
-from .message import Message, _native, _always_bytes, MessageData
+from .message import Message, _native, _always_bytes, MessageData, MessageMultiDict
# This regex extracts & splits the host header into host and port.
# Handles the edge case of IPv6 addresses containing colons.
@@ -224,45 +224,54 @@ class Request(Message):
@property
def query(self):
+ # type: () -> MessageMultiDict
"""
- The request query string as an :py:class:`ODict` object.
- None, if there is no query.
+ The request query string as an :py:class:`MessageMultiDict` object.
"""
+ return MessageMultiDict("query", self)
+
+ @property
+ def _query(self):
_, _, _, _, query, _ = urllib.parse.urlparse(self.url)
- if query:
- return ODict(utils.urldecode(query))
- return None
+ return tuple(utils.urldecode(query))
@query.setter
- def query(self, odict):
- query = utils.urlencode(odict.lst)
+ def query(self, value):
+ query = utils.urlencode(value)
scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url)
_, _, _, self.path = utils.parse_url(
urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]))
@property
def cookies(self):
+ # type: () -> MessageMultiDict
"""
The request cookies.
- An empty :py:class:`ODict` object if the cookie monster ate them all.
+
+ An empty :py:class:`MessageMultiDict` object if the cookie monster ate them all.
"""
- ret = ODict()
- for i in self.headers.get_all("Cookie"):
- ret.extend(cookies.parse_cookie_header(i))
- return ret
+ return MessageMultiDict("cookies", self)
+
+ @property
+ def _cookies(self):
+ h = self.headers.get_all("Cookie")
+ return tuple(cookies.parse_cookie_headers(h))
@cookies.setter
- def cookies(self, odict):
- self.headers["cookie"] = cookies.format_cookie_header(odict)
+ def cookies(self, value):
+ self.headers["cookie"] = cookies.format_cookie_header(value)
@property
def path_components(self):
"""
- The URL's path components as a list of strings.
+ The URL's path components as a tuple of strings.
Components are unquoted.
"""
_, _, path, _, _, _ = urllib.parse.urlparse(self.url)
- return [urllib.parse.unquote(i) for i in path.split("/") if i]
+ # This needs to be a tuple so that it's immutable.
+ # Otherwise, this would fail silently:
+ # request.path_components.append("foo")
+ return tuple(urllib.parse.unquote(i) for i in path.split("/") if i)
@path_components.setter
def path_components(self, components):
@@ -309,34 +318,42 @@ class Request(Message):
@property
def urlencoded_form(self):
"""
- The URL-encoded form data as an :py:class:`ODict` object.
- None if there is no data or the content-type indicates non-form data.
+ The URL-encoded form data as an :py:class:`MessageMultiDict` object.
+ None if the content-type indicates non-form data.
"""
is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower()
- if self.content and is_valid_content_type:
- return ODict(utils.urldecode(self.content))
+ if is_valid_content_type:
+ return MessageMultiDict("urlencoded_form", self)
return None
+ @property
+ def _urlencoded_form(self):
+ return tuple(utils.urldecode(self.content))
+
@urlencoded_form.setter
- def urlencoded_form(self, odict):
+ def urlencoded_form(self, value):
"""
Sets the body to the URL-encoded form data, and adds the appropriate content-type header.
This will overwrite the existing content if there is one.
"""
self.headers["content-type"] = "application/x-www-form-urlencoded"
- self.content = utils.urlencode(odict.lst)
+ self.content = utils.urlencode(value)
@property
def multipart_form(self):
"""
- The multipart form data as an :py:class:`ODict` object.
- None if there is no data or the content-type indicates non-form data.
+ The multipart form data as an :py:class:`MultipartFormDict` object.
+ None if the content-type indicates non-form data.
"""
is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower()
- if self.content and is_valid_content_type:
- return ODict(utils.multipartdecode(self.headers,self.content))
+ if is_valid_content_type:
+ return MessageMultiDict("multipart_form", self)
return None
+ @property
+ def _multipart_form(self):
+ return utils.multipartdecode(self.headers, self.content)
+
@multipart_form.setter
def multipart_form(self, value):
raise NotImplementedError()
diff --git a/netlib/http/response.py b/netlib/http/response.py
index 2f06149e..20074dca 100644
--- a/netlib/http/response.py
+++ b/netlib/http/response.py
@@ -70,6 +70,7 @@ class Response(Message):
def reason(self, reason):
self.data.reason = _always_bytes(reason)
+ # FIXME
@property
def cookies(self):
"""
@@ -88,6 +89,7 @@ class Response(Message):
ret.append([name, [value, attrs]])
return ODict(ret)
+ # FIXME
@cookies.setter
def cookies(self, odict):
values = []
diff --git a/netlib/multidict.py b/netlib/multidict.py
new file mode 100644
index 00000000..a7158bc5
--- /dev/null
+++ b/netlib/multidict.py
@@ -0,0 +1,163 @@
+from __future__ import absolute_import, print_function, division
+
+from abc import ABCMeta, abstractmethod
+
+from typing import Tuple
+
+try:
+ from collections.abc import MutableMapping
+except ImportError: # pragma: no cover
+ from collections import MutableMapping # Workaround for Python < 3.3
+
+import six
+
+from .utils import Serializable
+
+
+@six.add_metaclass(ABCMeta)
+class MultiDict(MutableMapping, Serializable):
+ def __init__(self, fields=None):
+
+ # it is important for us that .fields is immutable, so that we can easily
+ # detect changes to it.
+ self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...]
+
+ for key, value in self.fields:
+ if not isinstance(key, bytes) or not isinstance(value, bytes):
+ raise TypeError("MultiDict fields must be bytes.")
+
+ def __repr__(self):
+ fields = tuple(
+ repr(field)
+ for field in self.fields
+ )
+ return "{cls}[{fields}]".format(
+ cls=type(self).__name__,
+ fields=", ".join(fields)
+ )
+
+ @staticmethod
+ @abstractmethod
+ def _reduce_values(values):
+ pass
+
+ @staticmethod
+ @abstractmethod
+ def _kconv(v):
+ pass
+
+ def __getitem__(self, key):
+ values = self.get_all(key)
+ if not values:
+ raise KeyError(key)
+ return self._reduce_values(values)
+
+ def __setitem__(self, key, value):
+ self.set_all(key, [value])
+
+ def __delitem__(self, key):
+ if key not in self:
+ raise KeyError(key)
+ key = self._kconv(key)
+ self.fields = tuple(
+ field for field in self.fields
+ if key != self._kconv(field[0])
+ )
+
+ def __iter__(self):
+ seen = set()
+ for key, _ in self.fields:
+ key_kconv = self._kconv(key)
+ if key_kconv not in seen:
+ seen.add(key_kconv)
+ yield key
+
+ def __len__(self):
+ return len(set(self._kconv(key) for key, _ in self.fields))
+
+ def __eq__(self, other):
+ if isinstance(other, MultiDict):
+ return self.fields == other.fields
+ return False
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def get_all(self, key):
+ """
+ Return the list of items for a given key.
+ If that key is not in the MultiDict,
+ the return value will be an empty list.
+ """
+ key = self._kconv(key)
+ return [
+ value
+ for k, value in self.fields
+ if self._kconv(k) == key
+ ]
+
+ def set_all(self, key, values):
+ """
+ Remove the old values for a key and add new ones.
+ """
+ key_kconv = self._kconv(key)
+
+ new_fields = []
+ for field in self.fields:
+ if self._kconv(field[0]) == key_kconv:
+ if values:
+ new_fields.append(
+ (key, values.pop(0))
+ )
+ else:
+ new_fields.append(field)
+ while values:
+ new_fields.append(
+ (key, values.pop(0))
+ )
+ self.fields = tuple(new_fields)
+
+ def add(self, key, value):
+ self.insert(len(self.fields), key, value)
+
+ def insert(self, index, key, value):
+ item = (key, value)
+ self.fields = self.fields[:index] + (item,) + self.fields[index:]
+
+ def keys(self, multi=False):
+ return (
+ k
+ for k, _ in self.items(multi)
+ )
+
+ def values(self, multi=False):
+ return (
+ v
+ for _, v in self.items(multi)
+ )
+
+ def items(self, multi=False):
+ if multi:
+ return self.fields
+ else:
+ return super(MultiDict, self).items()
+
+ def to_dict(self):
+ d = {}
+ for key in self:
+ values = self.get_all(key)
+ if len(values) == 1:
+ d[key] = values[0]
+ else:
+ d[key] = values
+ return d
+
+ def get_state(self):
+ return self.fields
+
+ def set_state(self, state):
+ self.fields = tuple(tuple(x) for x in state)
+
+ @classmethod
+ def from_state(cls, state):
+ return cls(tuple(x) for x in state)
diff --git a/netlib/utils.py b/netlib/utils.py
index be2701a0..7499f71f 100644
--- a/netlib/utils.py
+++ b/netlib/utils.py
@@ -51,17 +51,6 @@ def always_bytes(unicode_or_bytes, *encode_args):
return unicode_or_bytes
-def always_byte_args(*encode_args):
- """Decorator that transparently encodes all arguments passed as unicode"""
- def decorator(fun):
- def _fun(*args, **kwargs):
- args = [always_bytes(arg, *encode_args) for arg in args]
- kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)}
- return fun(*args, **kwargs)
- return _fun
- return decorator
-
-
def native(s, *encoding_opts):
"""
Convert :py:class:`bytes` or :py:class:`unicode` to the native