diff options
author | Maximilian Hils <git@maximilianhils.com> | 2016-05-18 22:50:19 -0700 |
---|---|---|
committer | Maximilian Hils <git@maximilianhils.com> | 2016-05-18 22:50:19 -0700 |
commit | 6f8db2d7eb32684a8328e0ae8bdd73eceb861707 (patch) | |
tree | 254d964e9f8b95393b82683f66b9c2f77fb060de | |
parent | 8e39b7bf38e7becd1116dfcded380327fd0228d0 (diff) | |
download | mitmproxy-6f8db2d7eb32684a8328e0ae8bdd73eceb861707.tar.gz mitmproxy-6f8db2d7eb32684a8328e0ae8bdd73eceb861707.tar.bz2 mitmproxy-6f8db2d7eb32684a8328e0ae8bdd73eceb861707.zip |
improve MultiDict, add ImmutableMultiDict, adjust response.cookies
-rw-r--r-- | examples/modify_form.py | 2 | ||||
-rw-r--r-- | mitmproxy/console/flowview.py | 50 | ||||
-rw-r--r-- | mitmproxy/console/grideditor.py | 12 | ||||
-rw-r--r-- | mitmproxy/flow.py | 28 | ||||
-rw-r--r-- | netlib/http/__init__.py | 8 | ||||
-rw-r--r-- | netlib/http/cookies.py | 43 | ||||
-rw-r--r-- | netlib/http/headers.py | 4 | ||||
-rw-r--r-- | netlib/http/message.py | 41 | ||||
-rw-r--r-- | netlib/http/request.py | 69 | ||||
-rw-r--r-- | netlib/http/response.py | 45 | ||||
-rw-r--r-- | netlib/multidict.py | 403 | ||||
-rw-r--r-- | test/mitmproxy/test_examples.py | 2 | ||||
-rw-r--r-- | test/netlib/http/test_cookies.py | 14 | ||||
-rw-r--r-- | test/netlib/http/test_request.py | 4 | ||||
-rw-r--r-- | test/netlib/http/test_response.py | 32 |
15 files changed, 433 insertions, 324 deletions
diff --git a/examples/modify_form.py b/examples/modify_form.py index c4edb2cd..3fe0cf96 100644 --- a/examples/modify_form.py +++ b/examples/modify_form.py @@ -1,5 +1,5 @@ def request(context, flow): - if flow.request.urlencoded_form is not None: + if flow.request.urlencoded_form: flow.request.urlencoded_form["mitmproxy"] = "rocks" else: # This sets the proper content type and overrides the body. diff --git a/mitmproxy/console/flowview.py b/mitmproxy/console/flowview.py index b2ebe49e..3538c4ad 100644 --- a/mitmproxy/console/flowview.py +++ b/mitmproxy/console/flowview.py @@ -6,8 +6,7 @@ import sys import math import urwid -from netlib import odict -from netlib.http import Headers +from netlib.http import Headers, status_codes from . import common, grideditor, signals, searchable, tabs from . import flowdetailview from .. import utils, controller, contentviews @@ -316,21 +315,18 @@ class FlowView(tabs.Tabs): return "Invalid URL." signals.flow_change.send(self, flow = self.flow) - def set_resp_code(self, code): - response = self.flow.response + def set_resp_status_code(self, status_code): try: - response.status_code = int(code) + status_code = int(status_code) except ValueError: return None - import BaseHTTPServer - if int(code) in BaseHTTPServer.BaseHTTPRequestHandler.responses: - response.msg = BaseHTTPServer.BaseHTTPRequestHandler.responses[ - int(code)][0] + self.flow.response.status_code = status_code + if status_code in status_codes.RESPONSES: + self.flow.response.reason = status_codes.RESPONSES[status_code] signals.flow_change.send(self, flow = self.flow) - def set_resp_msg(self, msg): - response = self.flow.response - response.msg = msg + def set_resp_reason(self, reason): + self.flow.response.reason = reason signals.flow_change.send(self, flow = self.flow) def set_headers(self, fields, conn): @@ -338,22 +334,22 @@ class FlowView(tabs.Tabs): signals.flow_change.send(self, flow = self.flow) def set_query(self, lst, conn): - conn.set_query(odict.ODict(lst)) + conn.query = lst signals.flow_change.send(self, flow = self.flow) def set_path_components(self, lst, conn): - conn.set_path_components(lst) + conn.path_components = lst signals.flow_change.send(self, flow = self.flow) def set_form(self, lst, conn): - conn.set_form_urlencoded(odict.ODict(lst)) + conn.urlencoded_form = lst signals.flow_change.send(self, flow = self.flow) def edit_form(self, conn): self.master.view_grideditor( grideditor.URLEncodedFormEditor( self.master, - conn.get_form_urlencoded().lst, + conn.urlencoded_form.items(multi=True), self.set_form, conn ) @@ -364,7 +360,7 @@ class FlowView(tabs.Tabs): self.edit_form(conn) def set_cookies(self, lst, conn): - conn.cookies = odict.ODict(lst) + conn.cookies = lst signals.flow_change.send(self, flow = self.flow) def set_setcookies(self, data, conn): @@ -388,7 +384,7 @@ class FlowView(tabs.Tabs): self.master.view_grideditor( grideditor.CookieEditor( self.master, - message.cookies.lst, + message.cookies.items(multi=True), self.set_cookies, message ) @@ -397,7 +393,7 @@ class FlowView(tabs.Tabs): self.master.view_grideditor( grideditor.SetCookieEditor( self.master, - message.cookies, + message.cookies.items(multi=True), self.set_setcookies, message ) @@ -413,7 +409,7 @@ class FlowView(tabs.Tabs): c = self.master.spawn_editor(message.content or "") message.content = c.rstrip("\n") elif part == "f": - if not message.get_form_urlencoded() and message.content: + if not message.urlencoded_form and message.content: signals.status_prompt_onekey.send( prompt = "Existing body is not a URL-encoded form. Clear and edit?", keys = [ @@ -435,7 +431,7 @@ class FlowView(tabs.Tabs): ) ) elif part == "p": - p = message.get_path_components() + p = message.path_components self.master.view_grideditor( grideditor.PathEditor( self.master, @@ -448,7 +444,7 @@ class FlowView(tabs.Tabs): self.master.view_grideditor( grideditor.QueryEditor( self.master, - message.get_query().lst, + message.query.items(multi=True), self.set_query, message ) ) @@ -458,7 +454,7 @@ class FlowView(tabs.Tabs): text = message.url, callback = self.set_url ) - elif part == "m": + elif part == "m" and message == self.flow.request: signals.status_prompt_onekey.send( prompt = "Method", keys = common.METHOD_OPTIONS, @@ -468,13 +464,13 @@ class FlowView(tabs.Tabs): signals.status_prompt.send( prompt = "Code", text = str(message.status_code), - callback = self.set_resp_code + callback = self.set_resp_status_code ) - elif part == "m": + elif part == "m" and message == self.flow.response: signals.status_prompt.send( prompt = "Message", - text = message.msg, - callback = self.set_resp_msg + text = message.reason, + callback = self.set_resp_reason ) signals.flow_change.send(self, flow = self.flow) diff --git a/mitmproxy/console/grideditor.py b/mitmproxy/console/grideditor.py index 46ff348e..11ce7d02 100644 --- a/mitmproxy/console/grideditor.py +++ b/mitmproxy/console/grideditor.py @@ -700,17 +700,17 @@ class SetCookieEditor(GridEditor): def data_in(self, data): flattened = [] - for k, v in data.items(): - flattened.append([k, v[0], v[1].lst]) + for key, (value, attrs) in data: + flattened.append([key, value, attrs.items(multi=True)]) return flattened def data_out(self, data): vals = [] - for i in data: + for key, value, attrs in data: vals.append( [ - i[0], - [i[1], odict.ODictCaseless(i[2])] + key, + (value, attrs) ] ) - return odict.ODict(vals) + return vals diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py index 4663144d..647ebf68 100644 --- a/mitmproxy/flow.py +++ b/mitmproxy/flow.py @@ -319,10 +319,10 @@ class StickyCookieState: """ domain = f.request.host path = "/" - if attrs["domain"]: - domain = attrs["domain"][-1] - if attrs["path"]: - path = attrs["path"][-1] + if "domain" in attrs: + domain = attrs["domain"] + if "path" in attrs: + path = attrs["path"] return (domain, f.request.port, path) def domain_match(self, a, b): @@ -333,28 +333,26 @@ class StickyCookieState: return False def handle_response(self, f): - for i in f.response.headers.get_all("set-cookie"): + for name, (value, attrs) in f.response.cookies.items(multi=True): # FIXME: We now know that Cookie.py screws up some cookies with # valid RFC 822/1123 datetime specifications for expiry. Sigh. - name, value, attrs = cookies.parse_set_cookie_header(str(i)) a = self.ckey(attrs, f) if self.domain_match(f.request.host, a[0]): - b = attrs.lst - b.insert(0, [name, value]) - self.jar[a][name] = odict.ODictCaseless(b) + b = attrs.with_insert(0, name, value) + self.jar[a][name] = b def handle_request(self, f): l = [] if f.match(self.flt): - for i in self.jar.keys(): + for domain, port, path in self.jar.keys(): match = [ - self.domain_match(f.request.host, i[0]), - f.request.port == i[1], - f.request.path.startswith(i[2]) + self.domain_match(f.request.host, domain), + f.request.port == port, + f.request.path.startswith(path) ] if all(match): - c = self.jar[i] - l.extend([cookies.format_cookie_header(c[name].lst) for name in c.keys()]) + c = self.jar[(domain, port, path)] + l.extend([cookies.format_cookie_header(c[name].items(multi=True)) for name in c.keys()]) if l: f.request.stickycookie = True f.request.headers["cookie"] = "; ".join(l) diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 917080f7..9fafa28f 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -2,13 +2,13 @@ from __future__ import absolute_import, print_function, division from .request import Request from .response import Response from .headers import Headers -from .message import decoded -from . import http1, http2 +from .message import MultiDictView, decoded +from . import http1, http2, status_codes __all__ = [ "Request", "Response", "Headers", - "decoded", - "http1", "http2", + "MultiDictView", "decoded", + "http1", "http2", "status_codes", ] diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py index fd531146..c5ac4591 100644 --- a/netlib/http/cookies.py +++ b/netlib/http/cookies.py @@ -1,6 +1,8 @@ +import collections import re from email.utils import parsedate_tz, formatdate, mktime_tz +from netlib.multidict import ImmutableMultiDict from .. import odict """ @@ -155,25 +157,52 @@ def _parse_set_cookie_pairs(s): return pairs +def parse_set_cookie_headers(headers): + ret = [] + for header in headers: + v = parse_set_cookie_header(header) + if v: + name, value, attrs = v + ret.append((name, SetCookie(value, attrs))) + return ret + + +class CookieAttrs(ImmutableMultiDict): + @staticmethod + def _kconv(v): + return v.lower() + + @staticmethod + def _reduce_values(values): + # See the StickyCookieTest for a weird cookie that only makes sense + # if we take the last part. + return values[-1] + + +SetCookie = collections.namedtuple("SetCookie", ["value", "attrs"]) + + def parse_set_cookie_header(line): """ Parse a Set-Cookie header value Returns a (name, value, attrs) tuple, or None, where attrs is an - ODictCaseless set of attributes. No attempt is made to parse attribute + CookieAttrs dict of attributes. No attempt is made to parse attribute values - they are treated purely as strings. """ pairs = _parse_set_cookie_pairs(line) if pairs: - return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) + return pairs[0][0], pairs[0][1], CookieAttrs(tuple(x) for x in pairs[1:]) def format_set_cookie_header(name, value, attrs): """ Formats a Set-Cookie header value. """ - pairs = [[name, value]] - pairs.extend(attrs.lst) + pairs = [(name, value)] + pairs.extend( + attrs.fields if hasattr(attrs, "fields") else attrs + ) return _format_set_cookie_pairs(pairs) @@ -214,10 +243,10 @@ def refresh_set_cookie_header(c, delta): raise ValueError("Invalid Cookie") if "expires" in attrs: - e = parsedate_tz(attrs["expires"][-1]) + e = parsedate_tz(attrs["expires"]) if e: f = mktime_tz(e) + delta - attrs["expires"] = [formatdate(f)] + attrs = attrs.with_set_all("expires", [formatdate(f)]) else: # This can happen when the expires tag is invalid. # reddit.com sends a an expires tag like this: "Thu, 31 Dec @@ -225,7 +254,7 @@ def refresh_set_cookie_header(c, delta): # strictly correct according to the cookie spec. Browsers # appear to parse this tolerantly - maybe we should too. # For now, we just ignore this. - del attrs["expires"] + attrs = attrs.with_delitem("expires") ret = format_set_cookie_header(name, value, attrs) if not ret: diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 7e39c371..8959394c 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -83,6 +83,10 @@ class Headers(MultiDict): """ super(Headers, self).__init__(fields) + for key, value in self.fields: + if not isinstance(key, bytes) or not isinstance(value, bytes): + raise TypeError("Header fields must be bytes.") + # content_type -> content-type headers = { _always_bytes(name).replace(b"_", b"-"): _always_bytes(value) diff --git a/netlib/http/message.py b/netlib/http/message.py index 262ef3e1..3c731ea6 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -238,9 +238,44 @@ class decoded(object): self.message.encode(self.ce) -class MessageMultiDict(MultiDict): +class MultiDictView(MultiDict): """ - A MultiDict that provides a proxy view to the underlying message. + Some parts in HTTP (Cookies, URL query strings, ...) require a specific data structure: A MultiDict. + It behaves mostly like an ordered dict but it can have several values for the same key. + + The MultiDictView provides a MultiDict *view* on an :py:class:`Request` or :py:class:`Response`. + That is, it represents a part of the request as a MultiDict, but doesn't contain state/data themselves. + + For example, ``request.cookies`` provides a view on the ``Cookie: ...`` header. + Any change to ``request.cookies`` will also modify the ``Cookie`` header. + Any change to the ``Cookie`` header will also modify ``request.cookies``. + + Example: + + .. code-block:: python + + # Cookies are represented as a MultiDict. + >>> request.cookies + MultiDictView[("name", "value"), ("a", "false"), ("a", "42")] + + # MultiDicts mostly behave like a normal dict. + >>> request.cookies["name"] + "value" + + # If there is more than one value, only the first value is returned. + >>> request.cookies["a"] + "false" + + # `.get_all(key)` returns a list of all values. + >>> request.cookies.get_all("a") + ["false", "42"] + + # Changes to the headers are immediately reflected in the cookies. + >>> request.cookies + MultiDictView[("name", "value"), ...] + >>> del request.headers["Cookie"] + >>> request.cookies + MultiDictView[] # empty now """ def __init__(self, attr, message): @@ -248,7 +283,7 @@ class MessageMultiDict(MultiDict): # We do not want to call the parent constructor here as that # would cause an unnecessary parse/unparse pass. # This is here to silence linters. Message - super(MessageMultiDict, self).__init__(None) + super(MultiDictView, self).__init__(None) self._attr = attr self._message = message # type: Message diff --git a/netlib/http/request.py b/netlib/http/request.py index 26ec12cf..ae28084b 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -11,7 +11,7 @@ from netlib.http import cookies from netlib.odict import ODict from .. import encoding from .headers import Headers -from .message import Message, _native, _always_bytes, MessageData, MessageMultiDict +from .message import Message, _native, _always_bytes, MessageData, MultiDictView # This regex extracts & splits the host header into host and port. # Handles the edge case of IPv6 addresses containing colons. @@ -224,11 +224,11 @@ class Request(Message): @property def query(self): - # type: () -> MessageMultiDict + # type: () -> MultiDictView """ - The request query string as an :py:class:`MessageMultiDict` object. + The request query string as an :py:class:`MultiDictView` object. """ - return MessageMultiDict("query", self) + return MultiDictView("query", self) @property def _query(self): @@ -244,13 +244,13 @@ class Request(Message): @property def cookies(self): - # type: () -> MessageMultiDict + # type: () -> MultiDictView """ The request cookies. - An empty :py:class:`MessageMultiDict` object if the cookie monster ate them all. + An empty :py:class:`MultiDictView` object if the cookie monster ate them all. """ - return MessageMultiDict("cookies", self) + return MultiDictView("cookies", self) @property def _cookies(self): @@ -318,17 +318,18 @@ class Request(Message): @property def urlencoded_form(self): """ - The URL-encoded form data as an :py:class:`MessageMultiDict` object. - None if the content-type indicates non-form data. + The URL-encoded form data as an :py:class:`MultiDictView` object. + An empty MultiDictView if the content-type indicates non-form data + or the content could not be parsed. """ - is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() - if is_valid_content_type: - return MessageMultiDict("urlencoded_form", self) - return None + return MultiDictView("urlencoded_form", self) @property def _urlencoded_form(self): - return tuple(utils.urldecode(self.content)) + is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() + if is_valid_content_type: + return tuple(utils.urldecode(self.content)) + return () @urlencoded_form.setter def urlencoded_form(self, value): @@ -345,45 +346,15 @@ class Request(Message): The multipart form data as an :py:class:`MultipartFormDict` object. None if the content-type indicates non-form data. """ - is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() - if is_valid_content_type: - return MessageMultiDict("multipart_form", self) - return None + return MultiDictView("multipart_form", self) @property def _multipart_form(self): - return utils.multipartdecode(self.headers, self.content) + is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() + if is_valid_content_type: + return utils.multipartdecode(self.headers, self.content) + return () @multipart_form.setter def multipart_form(self, value): raise NotImplementedError() - - # Legacy - - def get_query(self): # pragma: no cover - warnings.warn(".get_query is deprecated, use .query instead.", DeprecationWarning) - return self.query or ODict([]) - - def set_query(self, odict): # pragma: no cover - warnings.warn(".set_query is deprecated, use .query instead.", DeprecationWarning) - self.query = odict - - def get_path_components(self): # pragma: no cover - warnings.warn(".get_path_components is deprecated, use .path_components instead.", DeprecationWarning) - return self.path_components - - def set_path_components(self, lst): # pragma: no cover - warnings.warn(".set_path_components is deprecated, use .path_components instead.", DeprecationWarning) - self.path_components = lst - - def get_form_urlencoded(self): # pragma: no cover - warnings.warn(".get_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) - return self.urlencoded_form or ODict([]) - - def set_form_urlencoded(self, odict): # pragma: no cover - warnings.warn(".set_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) - self.urlencoded_form = odict - - def get_form_multipart(self): # pragma: no cover - warnings.warn(".get_form_multipart is deprecated, use .multipart_form instead.", DeprecationWarning) - return self.multipart_form or ODict([]) diff --git a/netlib/http/response.py b/netlib/http/response.py index 20074dca..6d56fc1f 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -1,14 +1,12 @@ from __future__ import absolute_import, print_function, division -import warnings from email.utils import parsedate_tz, formatdate, mktime_tz import time from . import cookies from .headers import Headers -from .message import Message, _native, _always_bytes, MessageData +from .message import Message, _native, _always_bytes, MessageData, MultiDictView from .. import utils -from ..odict import ODict class ResponseData(MessageData): @@ -70,33 +68,32 @@ class Response(Message): def reason(self, reason): self.data.reason = _always_bytes(reason) - # FIXME @property def cookies(self): + # type: () -> MultiDictView """ - Get the contents of all Set-Cookie headers. + The response cookies. A possibly empty :py:class:`MultiDictView`, where the keys are + cookie name strings, and values are (value, attr) tuples. Value is a string, and attr is + an ODictCaseless containing cookie attributes. Within attrs, unary attributes (e.g. HTTPOnly) + are indicated by a Null value. - A possibly empty :py:class:`ODict`, where keys are cookie name strings, - and values are [value, attr] lists. Value is a string, and attr is - an ODictCaseless containing cookie attributes. Within attrs, unary - attributes (e.g. HTTPOnly) are indicated by a Null value. + Caveats: + Updating the attr """ - ret = [] - for header in self.headers.get_all("set-cookie"): - v = cookies.parse_set_cookie_header(header) - if v: - name, value, attrs = v - ret.append([name, [value, attrs]]) - return ODict(ret) - - # FIXME + return MultiDictView("cookies", self) + + @property + def _cookies(self): + h = self.headers.get_all("set-cookie") + return tuple(cookies.parse_set_cookie_headers(h)) + @cookies.setter - def cookies(self, odict): - values = [] - for i in odict.lst: - header = cookies.format_set_cookie_header(i[0], i[1][0], i[1][1]) - values.append(header) - self.headers.set_all("set-cookie", values) + def cookies(self, all_cookies): + cookie_headers = [] + for k, v in all_cookies: + header = cookies.format_set_cookie_header(k, v[0], v[1]) + cookie_headers.append(header) + self.headers.set_all("set-cookie", cookie_headers) def refresh(self, now=None): """ diff --git a/netlib/multidict.py b/netlib/multidict.py index a7158bc5..32d5bfc2 100644 --- a/netlib/multidict.py +++ b/netlib/multidict.py @@ -1,163 +1,240 @@ -from __future__ import absolute_import, print_function, division
-
-from abc import ABCMeta, abstractmethod
-
-from typing import Tuple
-
-try:
- from collections.abc import MutableMapping
-except ImportError: # pragma: no cover
- from collections import MutableMapping # Workaround for Python < 3.3
-
-import six
-
-from .utils import Serializable
-
-
-@six.add_metaclass(ABCMeta)
-class MultiDict(MutableMapping, Serializable):
- def __init__(self, fields=None):
-
- # it is important for us that .fields is immutable, so that we can easily
- # detect changes to it.
- self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...]
-
- for key, value in self.fields:
- if not isinstance(key, bytes) or not isinstance(value, bytes):
- raise TypeError("MultiDict fields must be bytes.")
-
- def __repr__(self):
- fields = tuple(
- repr(field)
- for field in self.fields
- )
- return "{cls}[{fields}]".format(
- cls=type(self).__name__,
- fields=", ".join(fields)
- )
-
- @staticmethod
- @abstractmethod
- def _reduce_values(values):
- pass
-
- @staticmethod
- @abstractmethod
- def _kconv(v):
- pass
-
- def __getitem__(self, key):
- values = self.get_all(key)
- if not values:
- raise KeyError(key)
- return self._reduce_values(values)
-
- def __setitem__(self, key, value):
- self.set_all(key, [value])
-
- def __delitem__(self, key):
- if key not in self:
- raise KeyError(key)
- key = self._kconv(key)
- self.fields = tuple(
- field for field in self.fields
- if key != self._kconv(field[0])
- )
-
- def __iter__(self):
- seen = set()
- for key, _ in self.fields:
- key_kconv = self._kconv(key)
- if key_kconv not in seen:
- seen.add(key_kconv)
- yield key
-
- def __len__(self):
- return len(set(self._kconv(key) for key, _ in self.fields))
-
- def __eq__(self, other):
- if isinstance(other, MultiDict):
- return self.fields == other.fields
- return False
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def get_all(self, key):
- """
- Return the list of items for a given key.
- If that key is not in the MultiDict,
- the return value will be an empty list.
- """
- key = self._kconv(key)
- return [
- value
- for k, value in self.fields
- if self._kconv(k) == key
- ]
-
- def set_all(self, key, values):
- """
- Remove the old values for a key and add new ones.
- """
- key_kconv = self._kconv(key)
-
- new_fields = []
- for field in self.fields:
- if self._kconv(field[0]) == key_kconv:
- if values:
- new_fields.append(
- (key, values.pop(0))
- )
- else:
- new_fields.append(field)
- while values:
- new_fields.append(
- (key, values.pop(0))
- )
- self.fields = tuple(new_fields)
-
- def add(self, key, value):
- self.insert(len(self.fields), key, value)
-
- def insert(self, index, key, value):
- item = (key, value)
- self.fields = self.fields[:index] + (item,) + self.fields[index:]
-
- def keys(self, multi=False):
- return (
- k
- for k, _ in self.items(multi)
- )
-
- def values(self, multi=False):
- return (
- v
- for _, v in self.items(multi)
- )
-
- def items(self, multi=False):
- if multi:
- return self.fields
- else:
- return super(MultiDict, self).items()
-
- def to_dict(self):
- d = {}
- for key in self:
- values = self.get_all(key)
- if len(values) == 1:
- d[key] = values[0]
- else:
- d[key] = values
- return d
-
- def get_state(self):
- return self.fields
-
- def set_state(self, state):
- self.fields = tuple(tuple(x) for x in state)
-
- @classmethod
- def from_state(cls, state):
- return cls(tuple(x) for x in state)
+from __future__ import absolute_import, print_function, division + +from abc import ABCMeta, abstractmethod + +from typing import Tuple, TypeVar + +try: + from collections.abc import MutableMapping +except ImportError: # pragma: no cover + from collections import MutableMapping # Workaround for Python < 3.3 + +import six + +from .utils import Serializable + + +@six.add_metaclass(ABCMeta) +class MultiDict(MutableMapping, Serializable): + def __init__(self, fields=None): + + # it is important for us that .fields is immutable, so that we can easily + # detect changes to it. + self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...] + + def __repr__(self): + fields = tuple( + repr(field) + for field in self.fields + ) + return "{cls}[{fields}]".format( + cls=type(self).__name__, + fields=", ".join(fields) + ) + + @staticmethod + @abstractmethod + def _reduce_values(values): + pass + + @staticmethod + @abstractmethod + def _kconv(v): + pass + + def __getitem__(self, key): + values = self.get_all(key) + if not values: + raise KeyError(key) + return self._reduce_values(values) + + def __setitem__(self, key, value): + self.set_all(key, [value]) + + def __delitem__(self, key): + if key not in self: + raise KeyError(key) + key = self._kconv(key) + self.fields = tuple( + field for field in self.fields + if key != self._kconv(field[0]) + ) + + def __iter__(self): + seen = set() + for key, _ in self.fields: + key_kconv = self._kconv(key) + if key_kconv not in seen: + seen.add(key_kconv) + yield key + + def __len__(self): + return len(set(self._kconv(key) for key, _ in self.fields)) + + def __eq__(self, other): + if isinstance(other, MultiDict): + return self.fields == other.fields + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def get_all(self, key): + """ + Return the list of all values for a given key. + If that key is not in the MultiDict, the return value will be an empty list. + """ + key = self._kconv(key) + return [ + value + for k, value in self.fields + if self._kconv(k) == key + ] + + def set_all(self, key, values): + """ + Remove the old values for a key and add new ones. + """ + key_kconv = self._kconv(key) + + new_fields = [] + for field in self.fields: + if self._kconv(field[0]) == key_kconv: + if values: + new_fields.append( + (key, values.pop(0)) + ) + else: + new_fields.append(field) + while values: + new_fields.append( + (key, values.pop(0)) + ) + self.fields = tuple(new_fields) + + def add(self, key, value): + """ + Add an additional value for the given key at the bottom. + """ + self.insert(len(self.fields), key, value) + + def insert(self, index, key, value): + """ + Insert an additional value for the given key at the specified position. + """ + item = (key, value) + self.fields = self.fields[:index] + (item,) + self.fields[index:] + + def keys(self, multi=False): + """ + Get all keys. + + Args: + multi(bool): + If True, one key per value will be returned. + If False, duplicate keys will only be returned once. + """ + return ( + k + for k, _ in self.items(multi) + ) + + def values(self, multi=False): + """ + Get all values. + + Args: + multi(bool): + If True, all values will be returned. + If False, only the first value per key will be returned. + """ + return ( + v + for _, v in self.items(multi) + ) + + def items(self, multi=False): + """ + Get all (key, value) tuples. + + Args: + multi(bool): + If True, all (key, value) pairs will be returned + If False, only the first (key, value) pair per unique key will be returned. + """ + if multi: + return self.fields + else: + return super(MultiDict, self).items() + + def to_dict(self): + """ + Get the MultiDict as a plain Python dict. + Keys with multiple values are returned as lists. + + Example: + + .. code-block:: python + + # Simple dict with duplicate values. + >>> d + MultiDictView[("name", "value"), ("a", "false"), ("a", "42")] + >>> d.to_dict() + { + "name": "value", + "a": ["false", "42"] + } + """ + d = {} + for key in self: + values = self.get_all(key) + if len(values) == 1: + d[key] = values[0] + else: + d[key] = values + return d + + def get_state(self): + return self.fields + + def set_state(self, state): + self.fields = tuple(tuple(x) for x in state) + + @classmethod + def from_state(cls, state): + return cls(tuple(x) for x in state) + + +@six.add_metaclass(ABCMeta) +class ImmutableMultiDict(MultiDict): + def _immutable(self, *_): + raise TypeError('{} objects are immutable'.format(self.__class__.__name__)) + + __delitem__ = set_all = insert = _immutable + + def with_delitem(self, key): + """ + Returns: + An updated ImmutableMultiDict. The original object will not be modified. + """ + ret = self.copy() + super(ImmutableMultiDict, ret).__delitem__(key) + return ret + + def with_set_all(self, key, values): + """ + Returns: + An updated ImmutableMultiDict. The original object will not be modified. + """ + ret = self.copy() + super(ImmutableMultiDict, ret).set_all(key, values) + return ret + + def with_insert(self, index, key, value): + """ + Returns: + An updated ImmutableMultiDict. The original object will not be modified. + """ + ret = self.copy() + super(ImmutableMultiDict, ret).insert(index, key, value) + return ret diff --git a/test/mitmproxy/test_examples.py b/test/mitmproxy/test_examples.py index d0a258e9..ac79b093 100644 --- a/test/mitmproxy/test_examples.py +++ b/test/mitmproxy/test_examples.py @@ -98,7 +98,7 @@ def test_modify_form(): flow.request.headers["content-type"] = "" ex.run("request", flow) - assert list(flow.request.urlencoded_form.items()) == [("foo","bar")] + assert list(flow.request.urlencoded_form.items()) == [("foo", "bar")] def test_modify_querystring(): diff --git a/test/netlib/http/test_cookies.py b/test/netlib/http/test_cookies.py index e2cee57f..6f84c4ce 100644 --- a/test/netlib/http/test_cookies.py +++ b/test/netlib/http/test_cookies.py @@ -197,24 +197,28 @@ def test_parse_set_cookie_header(): ], [ "one=uno", - ("one", "uno", []) + ("one", "uno", ()) ], [ "one=uno; foo=bar", - ("one", "uno", [["foo", "bar"]]) - ] + ("one", "uno", (("foo", "bar"),)) + ], + [ + "one=uno; foo=bar; foo=baz", + ("one", "uno", (("foo", "bar"), ("foo", "baz"))) + ], ] for s, expected in vals: ret = cookies.parse_set_cookie_header(s) if expected: assert ret[0] == expected[0] assert ret[1] == expected[1] - assert ret[2].lst == expected[2] + assert ret[2].items(multi=True) == expected[2] s2 = cookies.format_set_cookie_header(*ret) ret2 = cookies.parse_set_cookie_header(s2) assert ret2[0] == expected[0] assert ret2[1] == expected[1] - assert ret2[2].lst == expected[2] + assert ret2[2].items(multi=True) == expected[2] else: assert ret is None diff --git a/test/netlib/http/test_request.py b/test/netlib/http/test_request.py index 26593ee1..eefdc091 100644 --- a/test/netlib/http/test_request.py +++ b/test/netlib/http/test_request.py @@ -251,7 +251,7 @@ class TestRequestUtils(object): def test_get_urlencoded_form(self): request = treq(content="foobar=baz") - assert request.urlencoded_form is None + assert not request.urlencoded_form request.headers["Content-Type"] = "application/x-www-form-urlencoded" assert list(request.urlencoded_form.items()) == [("foobar", "baz")] @@ -264,7 +264,7 @@ class TestRequestUtils(object): def test_get_multipart_form(self): request = treq(content="foobar") - assert request.multipart_form is None + assert not request.multipart_form request.headers["Content-Type"] = "multipart/form-data" assert list(request.multipart_form.items()) == [] diff --git a/test/netlib/http/test_response.py b/test/netlib/http/test_response.py index 37273541..cfd093d4 100644 --- a/test/netlib/http/test_response.py +++ b/test/netlib/http/test_response.py @@ -6,6 +6,7 @@ import six import time from netlib.http import Headers +from netlib.http.cookies import CookieAttrs from netlib.odict import ODict, ODictCaseless from netlib.tutils import raises, tresp from .test_message import _test_passthrough_attr, _test_decoded_attr @@ -56,7 +57,7 @@ class TestResponseUtils(object): result = resp.cookies assert len(result) == 1 assert "cookiename" in result - assert result["cookiename"][0] == ["cookievalue", ODict()] + assert result["cookiename"] == ("cookievalue", CookieAttrs()) def test_get_cookies_with_parameters(self): resp = tresp() @@ -64,13 +65,13 @@ class TestResponseUtils(object): result = resp.cookies assert len(result) == 1 assert "cookiename" in result - assert result["cookiename"][0][0] == "cookievalue" - attrs = result["cookiename"][0][1] + assert result["cookiename"][0] == "cookievalue" + attrs = result["cookiename"][1] assert len(attrs) == 4 - assert attrs["domain"] == ["example.com"] - assert attrs["expires"] == ["Wed Oct 21 16:29:41 2015"] - assert attrs["path"] == ["/"] - assert attrs["httponly"] == [None] + assert attrs["domain"] == "example.com" + assert attrs["expires"] == "Wed Oct 21 16:29:41 2015" + assert attrs["path"] == "/" + assert attrs["httponly"] is None def test_get_cookies_no_value(self): resp = tresp() @@ -78,8 +79,8 @@ class TestResponseUtils(object): result = resp.cookies assert len(result) == 1 assert "cookiename" in result - assert result["cookiename"][0][0] == "" - assert len(result["cookiename"][0][1]) == 2 + assert result["cookiename"][0] == "" + assert len(result["cookiename"][1]) == 2 def test_get_cookies_twocookies(self): resp = tresp() @@ -90,19 +91,16 @@ class TestResponseUtils(object): result = resp.cookies assert len(result) == 2 assert "cookiename" in result - assert result["cookiename"][0] == ["cookievalue", ODict()] + assert result["cookiename"] == ("cookievalue", CookieAttrs()) assert "othercookie" in result - assert result["othercookie"][0] == ["othervalue", ODict()] + assert result["othercookie"] == ("othervalue", CookieAttrs()) def test_set_cookies(self): resp = tresp() - v = resp.cookies - v.add("foo", ["bar", ODictCaseless()]) - resp.cookies = v + resp.cookies["foo"] = ("bar", {}) - v = resp.cookies - assert len(v) == 1 - assert v["foo"] == [["bar", ODictCaseless()]] + assert len(resp.cookies) == 1 + assert resp.cookies["foo"] == ("bar", CookieAttrs()) def test_refresh(self): r = tresp() |