diff options
author | Maximilian Hils <git@maximilianhils.com> | 2016-05-18 18:46:42 -0700 |
---|---|---|
committer | Maximilian Hils <git@maximilianhils.com> | 2016-05-18 18:46:42 -0700 |
commit | 44ac64aa7235362acbb96e0f12aa27534580e575 (patch) | |
tree | c03b8c3519c273a4f42b60cb2bce8cc0dd524925 | |
parent | 4c3fb8f5097fad2c5de96104dae3f8026b0b4666 (diff) | |
download | mitmproxy-44ac64aa7235362acbb96e0f12aa27534580e575.tar.gz mitmproxy-44ac64aa7235362acbb96e0f12aa27534580e575.tar.bz2 mitmproxy-44ac64aa7235362acbb96e0f12aa27534580e575.zip |
add MultiDict
This commit introduces MultiDict, a multi-dictionary similar to
ODict, but with improved semantics (as in the Headers class).
MultiDict fixes a few issues that were present in the Request/Response
API. In particular, `request.cookies["foo"] = "bar"` has previously been a
no-op, as the cookies property returned a mutable _copy_ of the cookies.
-rw-r--r-- | examples/modify_form.py | 11 | ||||
-rw-r--r-- | examples/modify_querystring.py | 5 | ||||
-rw-r--r-- | mitmproxy/flow.py | 8 | ||||
-rw-r--r-- | mitmproxy/flow_export.py | 4 | ||||
-rw-r--r-- | netlib/encoding.py | 1 | ||||
-rw-r--r-- | netlib/http/cookies.py | 17 | ||||
-rw-r--r-- | netlib/http/headers.py | 132 | ||||
-rw-r--r-- | netlib/http/http1/read.py | 4 | ||||
-rw-r--r-- | netlib/http/http2/connections.py | 12 | ||||
-rw-r--r-- | netlib/http/message.py | 35 | ||||
-rw-r--r-- | netlib/http/request.py | 71 | ||||
-rw-r--r-- | netlib/http/response.py | 2 | ||||
-rw-r--r-- | netlib/multidict.py | 163 | ||||
-rw-r--r-- | netlib/utils.py | 11 | ||||
-rw-r--r-- | test/mitmproxy/test_examples.py | 12 | ||||
-rw-r--r-- | test/mitmproxy/test_flow.py | 54 | ||||
-rw-r--r-- | test/mitmproxy/test_flow_export.py | 2 | ||||
-rw-r--r-- | test/netlib/http/http1/test_read.py | 8 | ||||
-rw-r--r-- | test/netlib/http/http2/test_connections.py | 6 | ||||
-rw-r--r-- | test/netlib/http/test_cookies.py | 4 | ||||
-rw-r--r-- | test/netlib/http/test_headers.py | 10 | ||||
-rw-r--r-- | test/netlib/http/test_request.py | 61 | ||||
-rw-r--r-- | test/netlib/http/test_response.py | 2 |
23 files changed, 370 insertions, 265 deletions
diff --git a/examples/modify_form.py b/examples/modify_form.py index 86188781..c4edb2cd 100644 --- a/examples/modify_form.py +++ b/examples/modify_form.py @@ -1,5 +1,8 @@ def request(context, flow): - form = flow.request.urlencoded_form - if form is not None: - form["mitmproxy"] = ["rocks"] - flow.request.urlencoded_form = form + if flow.request.urlencoded_form is not None: + flow.request.urlencoded_form["mitmproxy"] = "rocks" + else: + # This sets the proper content type and overrides the body. + flow.request.urlencoded_form = [ + ("foo", "bar") + ] diff --git a/examples/modify_querystring.py b/examples/modify_querystring.py index d682df69..b89e5c8d 100644 --- a/examples/modify_querystring.py +++ b/examples/modify_querystring.py @@ -1,5 +1,2 @@ def request(context, flow): - q = flow.request.query - if q: - q["mitmproxy"] = ["rocks"] - flow.request.query = q + flow.request.query["mitmproxy"] = "rocks" diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py index 7fd97af3..4663144d 100644 --- a/mitmproxy/flow.py +++ b/mitmproxy/flow.py @@ -156,9 +156,9 @@ class SetHeaders: for _, header, value, cpatt in self.lst: if cpatt(f): if f.response: - f.response.headers.fields.append((header, value)) + f.response.headers.add(header, value) else: - f.request.headers.fields.append((header, value)) + f.request.headers.add(header, value) class StreamLargeBodies(object): @@ -263,7 +263,7 @@ class ServerPlaybackState: form_contents = r.urlencoded_form or r.multipart_form if self.ignore_payload_params and form_contents: key.extend( - p for p in form_contents + p for p in form_contents.items(multi=True) if p[0] not in self.ignore_payload_params ) else: @@ -354,7 +354,7 @@ class StickyCookieState: ] if all(match): c = self.jar[i] - l.extend([cookies.format_cookie_header(c[name]) for name in c.keys()]) + l.extend([cookies.format_cookie_header(c[name].lst) for name in c.keys()]) if l: f.request.stickycookie = True f.request.headers["cookie"] = "; ".join(l) diff --git a/mitmproxy/flow_export.py b/mitmproxy/flow_export.py index d8e65704..ae282fce 100644 --- a/mitmproxy/flow_export.py +++ b/mitmproxy/flow_export.py @@ -51,7 +51,7 @@ def python_code(flow): params = "" if flow.request.query: - lines = [" '%s': '%s',\n" % (k, v) for k, v in flow.request.query] + lines = [" %s: %s,\n" % (repr(k), repr(v)) for k, v in flow.request.query.to_dict().items()] params = "\nparams = {\n%s}\n" % "".join(lines) args += "\n params=params," @@ -140,7 +140,7 @@ def locust_code(flow): params = "" if flow.request.query: - lines = [" '%s': '%s',\n" % (k, v) for k, v in flow.request.query] + lines = [" %s: %s,\n" % (repr(k), repr(v)) for k, v in flow.request.query.to_dict().items()] params = "\n params = {\n%s }\n" % "".join(lines) args += "\n params=params," diff --git a/netlib/encoding.py b/netlib/encoding.py index 14479e00..98502451 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -5,7 +5,6 @@ from __future__ import absolute_import from io import BytesIO import gzip import zlib -from .utils import always_byte_args ENCODINGS = {"identity", "gzip", "deflate"} diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py index 4451f1da..fd531146 100644 --- a/netlib/http/cookies.py +++ b/netlib/http/cookies.py @@ -1,6 +1,4 @@ -from six.moves import http_cookies as Cookie import re -import string from email.utils import parsedate_tz, formatdate, mktime_tz from .. import odict @@ -179,20 +177,27 @@ def format_set_cookie_header(name, value, attrs): return _format_set_cookie_pairs(pairs) +def parse_cookie_headers(cookie_headers): + cookie_list = [] + for header in cookie_headers: + cookie_list.extend(parse_cookie_header(header)) + return cookie_list + + def parse_cookie_header(line): """ Parse a Cookie header value. - Returns a (possibly empty) ODict object. + Returns a list of (lhs, rhs) tuples. """ pairs, off_ = _read_pairs(line) - return odict.ODict(pairs) + return pairs -def format_cookie_header(od): +def format_cookie_header(lst): """ Formats a Cookie header value. """ - return _format_pairs(od.lst) + return _format_pairs(lst) def refresh_set_cookie_header(c, delta): diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 72739f90..7e39c371 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -1,9 +1,3 @@ -""" - -Unicode Handling ----------------- -See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ -""" from __future__ import absolute_import, print_function, division import re @@ -13,23 +7,22 @@ try: except ImportError: # pragma: no cover from collections import MutableMapping # Workaround for Python < 3.3 - import six +from ..multidict import MultiDict +from ..utils import always_bytes -from netlib.utils import always_byte_args, always_bytes, Serializable +# See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ if six.PY2: # pragma: no cover _native = lambda x: x _always_bytes = lambda x: x - _always_byte_args = lambda x: x else: # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. _native = lambda x: x.decode("utf-8", "surrogateescape") _always_bytes = lambda x: always_bytes(x, "utf-8", "surrogateescape") - _always_byte_args = always_byte_args("utf-8", "surrogateescape") -class Headers(MutableMapping, Serializable): +class Headers(MultiDict): """ Header class which allows both convenient access to individual headers as well as direct access to the underlying raw data. Provides a full dictionary interface. @@ -49,7 +42,7 @@ class Headers(MutableMapping, Serializable): >>> h["host"] "example.com" - # Headers can also be creatd from a list of raw (header_name, header_value) byte tuples + # Headers can also be created from a list of raw (header_name, header_value) byte tuples >>> h = Headers([ [b"Host",b"example.com"], [b"Accept",b"text/html"], @@ -77,7 +70,6 @@ class Headers(MutableMapping, Serializable): For use with the "Set-Cookie" header, see :py:meth:`get_all`. """ - @_always_byte_args def __init__(self, fields=None, **headers): """ Args: @@ -89,19 +81,25 @@ class Headers(MutableMapping, Serializable): If ``**headers`` contains multiple keys that have equal ``.lower()`` s, the behavior is undefined. """ - self.fields = fields or [] - - for name, value in self.fields: - if not isinstance(name, bytes) or not isinstance(value, bytes): - raise ValueError("Headers passed as fields must be bytes.") + super(Headers, self).__init__(fields) # content_type -> content-type headers = { - _always_bytes(name).replace(b"_", b"-"): value + _always_bytes(name).replace(b"_", b"-"): _always_bytes(value) for name, value in six.iteritems(headers) } self.update(headers) + @staticmethod + def _reduce_values(values): + # Headers can be folded + return ", ".join(values) + + @staticmethod + def _kconv(key): + # Headers are case-insensitive + return key.lower() + def __bytes__(self): if self.fields: return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" @@ -111,98 +109,40 @@ class Headers(MutableMapping, Serializable): if six.PY2: # pragma: no cover __str__ = __bytes__ - @_always_byte_args - def __getitem__(self, name): - values = self.get_all(name) - if not values: - raise KeyError(name) - return ", ".join(values) - - @_always_byte_args - def __setitem__(self, name, value): - idx = self._index(name) - - # To please the human eye, we insert at the same position the first existing header occured. - if idx is not None: - del self[name] - self.fields.insert(idx, [name, value]) - else: - self.fields.append([name, value]) - - @_always_byte_args - def __delitem__(self, name): - if name not in self: - raise KeyError(name) - name = name.lower() - self.fields = [ - field for field in self.fields - if name != field[0].lower() - ] + def __delitem__(self, key): + key = _always_bytes(key) + super(Headers, self).__delitem__(key) def __iter__(self): - seen = set() - for name, _ in self.fields: - name_lower = name.lower() - if name_lower not in seen: - seen.add(name_lower) - yield _native(name) - - def __len__(self): - return len(set(name.lower() for name, _ in self.fields)) - - # __hash__ = object.__hash__ - - def _index(self, name): - name = name.lower() - for i, field in enumerate(self.fields): - if field[0].lower() == name: - return i - return None - - def __eq__(self, other): - if isinstance(other, Headers): - return self.fields == other.fields - return False - - def __ne__(self, other): - return not self.__eq__(other) - - @_always_byte_args + for x in super(Headers, self).__iter__(): + yield _native(x) + def get_all(self, name): """ Like :py:meth:`get`, but does not fold multiple headers into a single one. This is useful for Set-Cookie headers, which do not support folding. - See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 """ - name_lower = name.lower() - values = [_native(value) for n, value in self.fields if n.lower() == name_lower] - return values + name = _always_bytes(name) + return [ + _native(x) for x in + super(Headers, self).get_all(name) + ] - @_always_byte_args def set_all(self, name, values): """ Explicitly set multiple headers for the given key. See: :py:meth:`get_all` """ - values = map(_always_bytes, values) # _always_byte_args does not fix lists - if name in self: - del self[name] - self.fields.extend( - [name, value] for value in values - ) - - def get_state(self): - return tuple(tuple(field) for field in self.fields) - - def set_state(self, state): - self.fields = [list(field) for field in state] + name = _always_bytes(name) + values = [_always_bytes(x) for x in values] + return super(Headers, self).set_all(name, values) - @classmethod - def from_state(cls, state): - return cls([list(field) for field in state]) + def insert(self, index, key, value): + key = _always_bytes(key) + value = _always_bytes(value) + super(Headers, self).insert(index, key, value) - @_always_byte_args def replace(self, pattern, repl, flags=0): """ Replaces a regular expression pattern with repl in each "name: value" @@ -211,6 +151,8 @@ class Headers(MutableMapping, Serializable): Returns: The number of replacements made. """ + pattern = _always_bytes(pattern) + repl = _always_bytes(repl) pattern = re.compile(pattern, flags) replacements = 0 diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 6e3a1b93..d30976bd 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -316,14 +316,14 @@ def _read_headers(rfile): if not ret: raise HttpSyntaxException("Invalid headers") # continued header - ret[-1][1] = ret[-1][1] + b'\r\n ' + line.strip() + ret[-1] = (ret[-1][0], ret[-1][1] + b'\r\n ' + line.strip()) else: try: name, value = line.split(b":", 1) value = value.strip() if not name: raise ValueError() - ret.append([name, value]) + ret.append((name, value)) except ValueError: raise HttpSyntaxException("Invalid headers") return Headers(ret) diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index f900b67c..6643b6b9 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -201,13 +201,13 @@ class HTTP2Protocol(object): headers = request.headers.copy() if ':authority' not in headers: - headers.fields.insert(0, (b':authority', authority.encode('ascii'))) + headers.insert(0, b':authority', authority.encode('ascii')) if ':scheme' not in headers: - headers.fields.insert(0, (b':scheme', request.scheme.encode('ascii'))) + headers.insert(0, b':scheme', request.scheme.encode('ascii')) if ':path' not in headers: - headers.fields.insert(0, (b':path', request.path.encode('ascii'))) + headers.insert(0, b':path', request.path.encode('ascii')) if ':method' not in headers: - headers.fields.insert(0, (b':method', request.method.encode('ascii'))) + headers.insert(0, b':method', request.method.encode('ascii')) if hasattr(request, 'stream_id'): stream_id = request.stream_id @@ -224,7 +224,7 @@ class HTTP2Protocol(object): headers = response.headers.copy() if ':status' not in headers: - headers.fields.insert(0, (b':status', str(response.status_code).encode('ascii'))) + headers.insert(0, b':status', str(response.status_code).encode('ascii')) if hasattr(response, 'stream_id'): stream_id = response.stream_id @@ -420,7 +420,7 @@ class HTTP2Protocol(object): self._handle_unexpected_frame(frm) headers = Headers( - [[k.encode('ascii'), v.encode('ascii')] for k, v in self.decoder.decode(header_blocks)] + (k.encode('ascii'), v.encode('ascii')) for k, v in self.decoder.decode(header_blocks) ) return stream_id, headers, body diff --git a/netlib/http/message.py b/netlib/http/message.py index da9681a0..262ef3e1 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -4,6 +4,7 @@ import warnings import six +from ..multidict import MultiDict from .headers import Headers from .. import encoding, utils @@ -235,3 +236,37 @@ class decoded(object): def __exit__(self, type, value, tb): if self.ce: self.message.encode(self.ce) + + +class MessageMultiDict(MultiDict): + """ + A MultiDict that provides a proxy view to the underlying message. + """ + + def __init__(self, attr, message): + if False: + # We do not want to call the parent constructor here as that + # would cause an unnecessary parse/unparse pass. + # This is here to silence linters. Message + super(MessageMultiDict, self).__init__(None) + self._attr = attr + self._message = message # type: Message + + @staticmethod + def _kconv(key): + # All request-attributes are case-sensitive. + return key + + @staticmethod + def _reduce_values(values): + # We just return the first element if + # multiple elements exist with the same key. + return values[0] + + @property + def fields(self): + return getattr(self._message, "_" + self._attr) + + @fields.setter + def fields(self, value): + setattr(self._message, self._attr, value) diff --git a/netlib/http/request.py b/netlib/http/request.py index a42150ff..26ec12cf 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -11,7 +11,7 @@ from netlib.http import cookies from netlib.odict import ODict from .. import encoding from .headers import Headers -from .message import Message, _native, _always_bytes, MessageData +from .message import Message, _native, _always_bytes, MessageData, MessageMultiDict # This regex extracts & splits the host header into host and port. # Handles the edge case of IPv6 addresses containing colons. @@ -224,45 +224,54 @@ class Request(Message): @property def query(self): + # type: () -> MessageMultiDict """ - The request query string as an :py:class:`ODict` object. - None, if there is no query. + The request query string as an :py:class:`MessageMultiDict` object. """ + return MessageMultiDict("query", self) + + @property + def _query(self): _, _, _, _, query, _ = urllib.parse.urlparse(self.url) - if query: - return ODict(utils.urldecode(query)) - return None + return tuple(utils.urldecode(query)) @query.setter - def query(self, odict): - query = utils.urlencode(odict.lst) + def query(self, value): + query = utils.urlencode(value) scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) _, _, _, self.path = utils.parse_url( urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment])) @property def cookies(self): + # type: () -> MessageMultiDict """ The request cookies. - An empty :py:class:`ODict` object if the cookie monster ate them all. + + An empty :py:class:`MessageMultiDict` object if the cookie monster ate them all. """ - ret = ODict() - for i in self.headers.get_all("Cookie"): - ret.extend(cookies.parse_cookie_header(i)) - return ret + return MessageMultiDict("cookies", self) + + @property + def _cookies(self): + h = self.headers.get_all("Cookie") + return tuple(cookies.parse_cookie_headers(h)) @cookies.setter - def cookies(self, odict): - self.headers["cookie"] = cookies.format_cookie_header(odict) + def cookies(self, value): + self.headers["cookie"] = cookies.format_cookie_header(value) @property def path_components(self): """ - The URL's path components as a list of strings. + The URL's path components as a tuple of strings. Components are unquoted. """ _, _, path, _, _, _ = urllib.parse.urlparse(self.url) - return [urllib.parse.unquote(i) for i in path.split("/") if i] + # This needs to be a tuple so that it's immutable. + # Otherwise, this would fail silently: + # request.path_components.append("foo") + return tuple(urllib.parse.unquote(i) for i in path.split("/") if i) @path_components.setter def path_components(self, components): @@ -309,34 +318,42 @@ class Request(Message): @property def urlencoded_form(self): """ - The URL-encoded form data as an :py:class:`ODict` object. - None if there is no data or the content-type indicates non-form data. + The URL-encoded form data as an :py:class:`MessageMultiDict` object. + None if the content-type indicates non-form data. """ is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() - if self.content and is_valid_content_type: - return ODict(utils.urldecode(self.content)) + if is_valid_content_type: + return MessageMultiDict("urlencoded_form", self) return None + @property + def _urlencoded_form(self): + return tuple(utils.urldecode(self.content)) + @urlencoded_form.setter - def urlencoded_form(self, odict): + def urlencoded_form(self, value): """ Sets the body to the URL-encoded form data, and adds the appropriate content-type header. This will overwrite the existing content if there is one. """ self.headers["content-type"] = "application/x-www-form-urlencoded" - self.content = utils.urlencode(odict.lst) + self.content = utils.urlencode(value) @property def multipart_form(self): """ - The multipart form data as an :py:class:`ODict` object. - None if there is no data or the content-type indicates non-form data. + The multipart form data as an :py:class:`MultipartFormDict` object. + None if the content-type indicates non-form data. """ is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() - if self.content and is_valid_content_type: - return ODict(utils.multipartdecode(self.headers,self.content)) + if is_valid_content_type: + return MessageMultiDict("multipart_form", self) return None + @property + def _multipart_form(self): + return utils.multipartdecode(self.headers, self.content) + @multipart_form.setter def multipart_form(self, value): raise NotImplementedError() diff --git a/netlib/http/response.py b/netlib/http/response.py index 2f06149e..20074dca 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -70,6 +70,7 @@ class Response(Message): def reason(self, reason): self.data.reason = _always_bytes(reason) + # FIXME @property def cookies(self): """ @@ -88,6 +89,7 @@ class Response(Message): ret.append([name, [value, attrs]]) return ODict(ret) + # FIXME @cookies.setter def cookies(self, odict): values = [] diff --git a/netlib/multidict.py b/netlib/multidict.py new file mode 100644 index 00000000..a7158bc5 --- /dev/null +++ b/netlib/multidict.py @@ -0,0 +1,163 @@ +from __future__ import absolute_import, print_function, division
+
+from abc import ABCMeta, abstractmethod
+
+from typing import Tuple
+
+try:
+ from collections.abc import MutableMapping
+except ImportError: # pragma: no cover
+ from collections import MutableMapping # Workaround for Python < 3.3
+
+import six
+
+from .utils import Serializable
+
+
+@six.add_metaclass(ABCMeta)
+class MultiDict(MutableMapping, Serializable):
+ def __init__(self, fields=None):
+
+ # it is important for us that .fields is immutable, so that we can easily
+ # detect changes to it.
+ self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...]
+
+ for key, value in self.fields:
+ if not isinstance(key, bytes) or not isinstance(value, bytes):
+ raise TypeError("MultiDict fields must be bytes.")
+
+ def __repr__(self):
+ fields = tuple(
+ repr(field)
+ for field in self.fields
+ )
+ return "{cls}[{fields}]".format(
+ cls=type(self).__name__,
+ fields=", ".join(fields)
+ )
+
+ @staticmethod
+ @abstractmethod
+ def _reduce_values(values):
+ pass
+
+ @staticmethod
+ @abstractmethod
+ def _kconv(v):
+ pass
+
+ def __getitem__(self, key):
+ values = self.get_all(key)
+ if not values:
+ raise KeyError(key)
+ return self._reduce_values(values)
+
+ def __setitem__(self, key, value):
+ self.set_all(key, [value])
+
+ def __delitem__(self, key):
+ if key not in self:
+ raise KeyError(key)
+ key = self._kconv(key)
+ self.fields = tuple(
+ field for field in self.fields
+ if key != self._kconv(field[0])
+ )
+
+ def __iter__(self):
+ seen = set()
+ for key, _ in self.fields:
+ key_kconv = self._kconv(key)
+ if key_kconv not in seen:
+ seen.add(key_kconv)
+ yield key
+
+ def __len__(self):
+ return len(set(self._kconv(key) for key, _ in self.fields))
+
+ def __eq__(self, other):
+ if isinstance(other, MultiDict):
+ return self.fields == other.fields
+ return False
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def get_all(self, key):
+ """
+ Return the list of items for a given key.
+ If that key is not in the MultiDict,
+ the return value will be an empty list.
+ """
+ key = self._kconv(key)
+ return [
+ value
+ for k, value in self.fields
+ if self._kconv(k) == key
+ ]
+
+ def set_all(self, key, values):
+ """
+ Remove the old values for a key and add new ones.
+ """
+ key_kconv = self._kconv(key)
+
+ new_fields = []
+ for field in self.fields:
+ if self._kconv(field[0]) == key_kconv:
+ if values:
+ new_fields.append(
+ (key, values.pop(0))
+ )
+ else:
+ new_fields.append(field)
+ while values:
+ new_fields.append(
+ (key, values.pop(0))
+ )
+ self.fields = tuple(new_fields)
+
+ def add(self, key, value):
+ self.insert(len(self.fields), key, value)
+
+ def insert(self, index, key, value):
+ item = (key, value)
+ self.fields = self.fields[:index] + (item,) + self.fields[index:]
+
+ def keys(self, multi=False):
+ return (
+ k
+ for k, _ in self.items(multi)
+ )
+
+ def values(self, multi=False):
+ return (
+ v
+ for _, v in self.items(multi)
+ )
+
+ def items(self, multi=False):
+ if multi:
+ return self.fields
+ else:
+ return super(MultiDict, self).items()
+
+ def to_dict(self):
+ d = {}
+ for key in self:
+ values = self.get_all(key)
+ if len(values) == 1:
+ d[key] = values[0]
+ else:
+ d[key] = values
+ return d
+
+ def get_state(self):
+ return self.fields
+
+ def set_state(self, state):
+ self.fields = tuple(tuple(x) for x in state)
+
+ @classmethod
+ def from_state(cls, state):
+ return cls(tuple(x) for x in state)
diff --git a/netlib/utils.py b/netlib/utils.py index be2701a0..7499f71f 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -51,17 +51,6 @@ def always_bytes(unicode_or_bytes, *encode_args): return unicode_or_bytes -def always_byte_args(*encode_args): - """Decorator that transparently encodes all arguments passed as unicode""" - def decorator(fun): - def _fun(*args, **kwargs): - args = [always_bytes(arg, *encode_args) for arg in args] - kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} - return fun(*args, **kwargs) - return _fun - return decorator - - def native(s, *encoding_opts): """ Convert :py:class:`bytes` or :py:class:`unicode` to the native diff --git a/test/mitmproxy/test_examples.py b/test/mitmproxy/test_examples.py index c401a6b9..d0a258e9 100644 --- a/test/mitmproxy/test_examples.py +++ b/test/mitmproxy/test_examples.py @@ -94,14 +94,22 @@ def test_modify_form(): flow = tutils.tflow(req=netutils.treq(headers=form_header)) with example("modify_form.py") as ex: ex.run("request", flow) - assert flow.request.urlencoded_form["mitmproxy"] == ["rocks"] + assert flow.request.urlencoded_form["mitmproxy"] == "rocks" + + flow.request.headers["content-type"] = "" + ex.run("request", flow) + assert list(flow.request.urlencoded_form.items()) == [("foo","bar")] def test_modify_querystring(): flow = tutils.tflow(req=netutils.treq(path="/search?q=term")) with example("modify_querystring.py") as ex: ex.run("request", flow) - assert flow.request.query["mitmproxy"] == ["rocks"] + assert flow.request.query["mitmproxy"] == "rocks" + + flow.request.path = "/" + ex.run("request", flow) + assert flow.request.query["mitmproxy"] == "rocks" def test_modify_response_body(): diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index b9c6a2f6..bf417423 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -1067,60 +1067,6 @@ class TestRequest: assert r.url == "https://address:22/path" assert r.pretty_url == "https://foo.com:22/path" - def test_path_components(self): - r = HTTPRequest.wrap(netlib.tutils.treq()) - r.path = "/" - assert r.get_path_components() == [] - r.path = "/foo/bar" - assert r.get_path_components() == ["foo", "bar"] - q = odict.ODict() - q["test"] = ["123"] - r.set_query(q) - assert r.get_path_components() == ["foo", "bar"] - - r.set_path_components([]) - assert r.get_path_components() == [] - r.set_path_components(["foo"]) - assert r.get_path_components() == ["foo"] - r.set_path_components(["/oo"]) - assert r.get_path_components() == ["/oo"] - assert "%2F" in r.path - - def test_getset_form_urlencoded(self): - d = odict.ODict([("one", "two"), ("three", "four")]) - r = HTTPRequest.wrap(netlib.tutils.treq(content=netlib.utils.urlencode(d.lst))) - r.headers["content-type"] = "application/x-www-form-urlencoded" - assert r.get_form_urlencoded() == d - - d = odict.ODict([("x", "y")]) - r.set_form_urlencoded(d) - assert r.get_form_urlencoded() == d - - r.headers["content-type"] = "foo" - assert not r.get_form_urlencoded() - - def test_getset_query(self): - r = HTTPRequest.wrap(netlib.tutils.treq()) - r.path = "/foo?x=y&a=b" - q = r.get_query() - assert q.lst == [("x", "y"), ("a", "b")] - - r.path = "/" - q = r.get_query() - assert not q - - r.path = "/?adsfa" - q = r.get_query() - assert q.lst == [("adsfa", "")] - - r.path = "/foo?x=y&a=b" - assert r.get_query() - r.set_query(odict.ODict([])) - assert not r.get_query() - qv = odict.ODict([("a", "b"), ("c", "d")]) - r.set_query(qv) - assert r.get_query() == qv - def test_anticache(self): r = HTTPRequest.wrap(netlib.tutils.treq()) r.headers = Headers() diff --git a/test/mitmproxy/test_flow_export.py b/test/mitmproxy/test_flow_export.py index 035f07b7..2b1f897c 100644 --- a/test/mitmproxy/test_flow_export.py +++ b/test/mitmproxy/test_flow_export.py @@ -21,7 +21,7 @@ def python_equals(testdata, text): assert clean_blanks(text).rstrip() == clean_blanks(d).rstrip() -req_get = lambda: netlib.tutils.treq(method='GET', content='') +req_get = lambda: netlib.tutils.treq(method='GET', content='', path=b"/") req_post = lambda: netlib.tutils.treq(method='POST', headers=None) diff --git a/test/netlib/http/http1/test_read.py b/test/netlib/http/http1/test_read.py index 90234070..d8106904 100644 --- a/test/netlib/http/http1/test_read.py +++ b/test/netlib/http/http1/test_read.py @@ -261,7 +261,7 @@ class TestReadHeaders(object): b"\r\n" ) headers = self._read(data) - assert headers.fields == [[b"Header", b"one"], [b"Header2", b"two"]] + assert headers.fields == ((b"Header", b"one"), (b"Header2", b"two")) def test_read_multi(self): data = ( @@ -270,7 +270,7 @@ class TestReadHeaders(object): b"\r\n" ) headers = self._read(data) - assert headers.fields == [[b"Header", b"one"], [b"Header", b"two"]] + assert headers.fields == ((b"Header", b"one"), (b"Header", b"two")) def test_read_continued(self): data = ( @@ -280,7 +280,7 @@ class TestReadHeaders(object): b"\r\n" ) headers = self._read(data) - assert headers.fields == [[b"Header", b"one\r\n two"], [b"Header2", b"three"]] + assert headers.fields == ((b"Header", b"one\r\n two"), (b"Header2", b"three")) def test_read_continued_err(self): data = b"\tfoo: bar\r\n" @@ -300,7 +300,7 @@ class TestReadHeaders(object): def test_read_empty_value(self): data = b"bar:" headers = self._read(data) - assert headers.fields == [[b"bar", b""]] + assert headers.fields == ((b"bar", b""),) def test_read_chunked(): req = treq(content=None) diff --git a/test/netlib/http/http2/test_connections.py b/test/netlib/http/http2/test_connections.py index 7b003067..7d240c0e 100644 --- a/test/netlib/http/http2/test_connections.py +++ b/test/netlib/http/http2/test_connections.py @@ -312,7 +312,7 @@ class TestReadRequest(tservers.ServerTestBase): req = protocol.read_request(NotImplemented) assert req.stream_id - assert req.headers.fields == [[b':method', b'GET'], [b':path', b'/'], [b':scheme', b'https']] + assert req.headers.fields == ((b':method', b'GET'), (b':path', b'/'), (b':scheme', b'https')) assert req.content == b'foobar' @@ -418,7 +418,7 @@ class TestReadResponse(tservers.ServerTestBase): assert resp.http_version == "HTTP/2.0" assert resp.status_code == 200 assert resp.reason == '' - assert resp.headers.fields == [[b':status', b'200'], [b'etag', b'foobar']] + assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar')) assert resp.content == b'foobar' assert resp.timestamp_end @@ -445,7 +445,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): assert resp.http_version == "HTTP/2.0" assert resp.status_code == 200 assert resp.reason == '' - assert resp.headers.fields == [[b':status', b'200'], [b'etag', b'foobar']] + assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar')) assert resp.content == b'' diff --git a/test/netlib/http/test_cookies.py b/test/netlib/http/test_cookies.py index da28850f..e2cee57f 100644 --- a/test/netlib/http/test_cookies.py +++ b/test/netlib/http/test_cookies.py @@ -128,10 +128,10 @@ def test_cookie_roundtrips(): ] for s, lst in pairs: ret = cookies.parse_cookie_header(s) - assert ret.lst == lst + assert ret == lst s2 = cookies.format_cookie_header(ret) ret = cookies.parse_cookie_header(s2) - assert ret.lst == lst + assert ret == lst def test_parse_set_cookie_pairs(): diff --git a/test/netlib/http/test_headers.py b/test/netlib/http/test_headers.py index 8c1db9dc..48d3b323 100644 --- a/test/netlib/http/test_headers.py +++ b/test/netlib/http/test_headers.py @@ -5,10 +5,10 @@ from netlib.tutils import raises class TestHeaders(object): def _2host(self): return Headers( - [ - [b"Host", b"example.com"], - [b"host", b"example.org"] - ] + ( + (b"Host", b"example.com"), + (b"host", b"example.org") + ) ) def test_init(self): @@ -38,7 +38,7 @@ class TestHeaders(object): assert headers["Host"] == "example.com" assert headers["Accept"] == "text/plain" - with raises(ValueError): + with raises(TypeError): Headers([[b"Host", u"not-bytes"]]) def test_getitem(self): diff --git a/test/netlib/http/test_request.py b/test/netlib/http/test_request.py index 7ed6bd0f..26593ee1 100644 --- a/test/netlib/http/test_request.py +++ b/test/netlib/http/test_request.py @@ -12,7 +12,7 @@ from .test_message import _test_decoded_attr, _test_passthrough_attr class TestRequestData(object): def test_init(self): - with raises(ValueError if six.PY2 else TypeError): + with raises(ValueError): treq(headers="foobar") assert isinstance(treq(headers=None).headers, Headers) @@ -158,16 +158,17 @@ class TestRequestUtils(object): def test_get_query(self): request = treq() - assert request.query is None + assert not request.query request.url = "http://localhost:80/foo?bar=42" - assert request.query.lst == [("bar", "42")] + assert dict(request.query) == {"bar": "42"} def test_set_query(self): - request = treq(host=b"foo", headers = Headers(host=b"bar")) - request.query = ODict([]) - assert request.host == "foo" - assert request.headers["host"] == "bar" + request = treq() + assert not request.query + request.query["foo"] = "bar" + assert request.query["foo"] == "bar" + assert request.path == "/path?foo=bar" def test_get_cookies_none(self): request = treq() @@ -177,47 +178,50 @@ class TestRequestUtils(object): def test_get_cookies_single(self): request = treq() request.headers = Headers(cookie="cookiename=cookievalue") - result = request.cookies - assert len(result) == 1 - assert result['cookiename'] == ['cookievalue'] + assert len(request.cookies) == 1 + assert request.cookies['cookiename'] == 'cookievalue' def test_get_cookies_double(self): request = treq() request.headers = Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue") result = request.cookies assert len(result) == 2 - assert result['cookiename'] == ['cookievalue'] - assert result['othercookiename'] == ['othercookievalue'] + assert result['cookiename'] == 'cookievalue' + assert result['othercookiename'] == 'othercookievalue' def test_get_cookies_withequalsign(self): request = treq() request.headers = Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue") result = request.cookies assert len(result) == 2 - assert result['cookiename'] == ['coo=kievalue'] - assert result['othercookiename'] == ['othercookievalue'] + assert result['cookiename'] == 'coo=kievalue' + assert result['othercookiename'] == 'othercookievalue' def test_set_cookies(self): request = treq() request.headers = Headers(cookie="cookiename=cookievalue") result = request.cookies - result["cookiename"] = ["foo"] - request.cookies = result - assert request.cookies["cookiename"] == ["foo"] + result["cookiename"] = "foo" + assert request.cookies["cookiename"] == "foo" def test_get_path_components(self): request = treq(path=b"/foo/bar") - assert request.path_components == ["foo", "bar"] + assert request.path_components == ("foo", "bar") def test_set_path_components(self): - request = treq(host=b"foo", headers = Headers(host=b"bar")) + request = treq() request.path_components = ["foo", "baz"] assert request.path == "/foo/baz" + request.path_components = [] assert request.path == "/" - request.query = ODict([]) - assert request.host == "foo" - assert request.headers["host"] == "bar" + + request.path_components = ["foo", "baz"] + request.query["hello"] = "hello" + assert request.path_components == ("foo", "baz") + + request.path_components = ["abc"] + assert request.path == "/abc?hello=hello" def test_anticache(self): request = treq() @@ -246,15 +250,15 @@ class TestRequestUtils(object): assert "gzip" in request.headers["Accept-Encoding"] def test_get_urlencoded_form(self): - request = treq(content="foobar") + request = treq(content="foobar=baz") assert request.urlencoded_form is None request.headers["Content-Type"] = "application/x-www-form-urlencoded" - assert request.urlencoded_form == ODict(utils.urldecode(request.content)) + assert list(request.urlencoded_form.items()) == [("foobar", "baz")] def test_set_urlencoded_form(self): request = treq() - request.urlencoded_form = ODict([('foo', 'bar'), ('rab', 'oof')]) + request.urlencoded_form = [('foo', 'bar'), ('rab', 'oof')] assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" assert request.content @@ -263,9 +267,4 @@ class TestRequestUtils(object): assert request.multipart_form is None request.headers["Content-Type"] = "multipart/form-data" - assert request.multipart_form == ODict( - utils.multipartdecode( - request.headers, - request.content - ) - ) + assert list(request.multipart_form.items()) == [] diff --git a/test/netlib/http/test_response.py b/test/netlib/http/test_response.py index 5440176c..37273541 100644 --- a/test/netlib/http/test_response.py +++ b/test/netlib/http/test_response.py @@ -13,7 +13,7 @@ from .test_message import _test_passthrough_attr, _test_decoded_attr class TestResponseData(object): def test_init(self): - with raises(ValueError if six.PY2 else TypeError): + with raises(ValueError): tresp(headers="foobar") assert isinstance(tresp(headers=None).headers, Headers) |