aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMaximilian Hils <git@maximilianhils.com>2016-05-18 22:50:19 -0700
committerMaximilian Hils <git@maximilianhils.com>2016-05-18 22:50:19 -0700
commit6f8db2d7eb32684a8328e0ae8bdd73eceb861707 (patch)
tree254d964e9f8b95393b82683f66b9c2f77fb060de
parent8e39b7bf38e7becd1116dfcded380327fd0228d0 (diff)
downloadmitmproxy-6f8db2d7eb32684a8328e0ae8bdd73eceb861707.tar.gz
mitmproxy-6f8db2d7eb32684a8328e0ae8bdd73eceb861707.tar.bz2
mitmproxy-6f8db2d7eb32684a8328e0ae8bdd73eceb861707.zip
improve MultiDict, add ImmutableMultiDict, adjust response.cookies
-rw-r--r--examples/modify_form.py2
-rw-r--r--mitmproxy/console/flowview.py50
-rw-r--r--mitmproxy/console/grideditor.py12
-rw-r--r--mitmproxy/flow.py28
-rw-r--r--netlib/http/__init__.py8
-rw-r--r--netlib/http/cookies.py43
-rw-r--r--netlib/http/headers.py4
-rw-r--r--netlib/http/message.py41
-rw-r--r--netlib/http/request.py69
-rw-r--r--netlib/http/response.py45
-rw-r--r--netlib/multidict.py403
-rw-r--r--test/mitmproxy/test_examples.py2
-rw-r--r--test/netlib/http/test_cookies.py14
-rw-r--r--test/netlib/http/test_request.py4
-rw-r--r--test/netlib/http/test_response.py32
15 files changed, 433 insertions, 324 deletions
diff --git a/examples/modify_form.py b/examples/modify_form.py
index c4edb2cd..3fe0cf96 100644
--- a/examples/modify_form.py
+++ b/examples/modify_form.py
@@ -1,5 +1,5 @@
def request(context, flow):
- if flow.request.urlencoded_form is not None:
+ if flow.request.urlencoded_form:
flow.request.urlencoded_form["mitmproxy"] = "rocks"
else:
# This sets the proper content type and overrides the body.
diff --git a/mitmproxy/console/flowview.py b/mitmproxy/console/flowview.py
index b2ebe49e..3538c4ad 100644
--- a/mitmproxy/console/flowview.py
+++ b/mitmproxy/console/flowview.py
@@ -6,8 +6,7 @@ import sys
import math
import urwid
-from netlib import odict
-from netlib.http import Headers
+from netlib.http import Headers, status_codes
from . import common, grideditor, signals, searchable, tabs
from . import flowdetailview
from .. import utils, controller, contentviews
@@ -316,21 +315,18 @@ class FlowView(tabs.Tabs):
return "Invalid URL."
signals.flow_change.send(self, flow = self.flow)
- def set_resp_code(self, code):
- response = self.flow.response
+ def set_resp_status_code(self, status_code):
try:
- response.status_code = int(code)
+ status_code = int(status_code)
except ValueError:
return None
- import BaseHTTPServer
- if int(code) in BaseHTTPServer.BaseHTTPRequestHandler.responses:
- response.msg = BaseHTTPServer.BaseHTTPRequestHandler.responses[
- int(code)][0]
+ self.flow.response.status_code = status_code
+ if status_code in status_codes.RESPONSES:
+ self.flow.response.reason = status_codes.RESPONSES[status_code]
signals.flow_change.send(self, flow = self.flow)
- def set_resp_msg(self, msg):
- response = self.flow.response
- response.msg = msg
+ def set_resp_reason(self, reason):
+ self.flow.response.reason = reason
signals.flow_change.send(self, flow = self.flow)
def set_headers(self, fields, conn):
@@ -338,22 +334,22 @@ class FlowView(tabs.Tabs):
signals.flow_change.send(self, flow = self.flow)
def set_query(self, lst, conn):
- conn.set_query(odict.ODict(lst))
+ conn.query = lst
signals.flow_change.send(self, flow = self.flow)
def set_path_components(self, lst, conn):
- conn.set_path_components(lst)
+ conn.path_components = lst
signals.flow_change.send(self, flow = self.flow)
def set_form(self, lst, conn):
- conn.set_form_urlencoded(odict.ODict(lst))
+ conn.urlencoded_form = lst
signals.flow_change.send(self, flow = self.flow)
def edit_form(self, conn):
self.master.view_grideditor(
grideditor.URLEncodedFormEditor(
self.master,
- conn.get_form_urlencoded().lst,
+ conn.urlencoded_form.items(multi=True),
self.set_form,
conn
)
@@ -364,7 +360,7 @@ class FlowView(tabs.Tabs):
self.edit_form(conn)
def set_cookies(self, lst, conn):
- conn.cookies = odict.ODict(lst)
+ conn.cookies = lst
signals.flow_change.send(self, flow = self.flow)
def set_setcookies(self, data, conn):
@@ -388,7 +384,7 @@ class FlowView(tabs.Tabs):
self.master.view_grideditor(
grideditor.CookieEditor(
self.master,
- message.cookies.lst,
+ message.cookies.items(multi=True),
self.set_cookies,
message
)
@@ -397,7 +393,7 @@ class FlowView(tabs.Tabs):
self.master.view_grideditor(
grideditor.SetCookieEditor(
self.master,
- message.cookies,
+ message.cookies.items(multi=True),
self.set_setcookies,
message
)
@@ -413,7 +409,7 @@ class FlowView(tabs.Tabs):
c = self.master.spawn_editor(message.content or "")
message.content = c.rstrip("\n")
elif part == "f":
- if not message.get_form_urlencoded() and message.content:
+ if not message.urlencoded_form and message.content:
signals.status_prompt_onekey.send(
prompt = "Existing body is not a URL-encoded form. Clear and edit?",
keys = [
@@ -435,7 +431,7 @@ class FlowView(tabs.Tabs):
)
)
elif part == "p":
- p = message.get_path_components()
+ p = message.path_components
self.master.view_grideditor(
grideditor.PathEditor(
self.master,
@@ -448,7 +444,7 @@ class FlowView(tabs.Tabs):
self.master.view_grideditor(
grideditor.QueryEditor(
self.master,
- message.get_query().lst,
+ message.query.items(multi=True),
self.set_query, message
)
)
@@ -458,7 +454,7 @@ class FlowView(tabs.Tabs):
text = message.url,
callback = self.set_url
)
- elif part == "m":
+ elif part == "m" and message == self.flow.request:
signals.status_prompt_onekey.send(
prompt = "Method",
keys = common.METHOD_OPTIONS,
@@ -468,13 +464,13 @@ class FlowView(tabs.Tabs):
signals.status_prompt.send(
prompt = "Code",
text = str(message.status_code),
- callback = self.set_resp_code
+ callback = self.set_resp_status_code
)
- elif part == "m":
+ elif part == "m" and message == self.flow.response:
signals.status_prompt.send(
prompt = "Message",
- text = message.msg,
- callback = self.set_resp_msg
+ text = message.reason,
+ callback = self.set_resp_reason
)
signals.flow_change.send(self, flow = self.flow)
diff --git a/mitmproxy/console/grideditor.py b/mitmproxy/console/grideditor.py
index 46ff348e..11ce7d02 100644
--- a/mitmproxy/console/grideditor.py
+++ b/mitmproxy/console/grideditor.py
@@ -700,17 +700,17 @@ class SetCookieEditor(GridEditor):
def data_in(self, data):
flattened = []
- for k, v in data.items():
- flattened.append([k, v[0], v[1].lst])
+ for key, (value, attrs) in data:
+ flattened.append([key, value, attrs.items(multi=True)])
return flattened
def data_out(self, data):
vals = []
- for i in data:
+ for key, value, attrs in data:
vals.append(
[
- i[0],
- [i[1], odict.ODictCaseless(i[2])]
+ key,
+ (value, attrs)
]
)
- return odict.ODict(vals)
+ return vals
diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py
index 4663144d..647ebf68 100644
--- a/mitmproxy/flow.py
+++ b/mitmproxy/flow.py
@@ -319,10 +319,10 @@ class StickyCookieState:
"""
domain = f.request.host
path = "/"
- if attrs["domain"]:
- domain = attrs["domain"][-1]
- if attrs["path"]:
- path = attrs["path"][-1]
+ if "domain" in attrs:
+ domain = attrs["domain"]
+ if "path" in attrs:
+ path = attrs["path"]
return (domain, f.request.port, path)
def domain_match(self, a, b):
@@ -333,28 +333,26 @@ class StickyCookieState:
return False
def handle_response(self, f):
- for i in f.response.headers.get_all("set-cookie"):
+ for name, (value, attrs) in f.response.cookies.items(multi=True):
# FIXME: We now know that Cookie.py screws up some cookies with
# valid RFC 822/1123 datetime specifications for expiry. Sigh.
- name, value, attrs = cookies.parse_set_cookie_header(str(i))
a = self.ckey(attrs, f)
if self.domain_match(f.request.host, a[0]):
- b = attrs.lst
- b.insert(0, [name, value])
- self.jar[a][name] = odict.ODictCaseless(b)
+ b = attrs.with_insert(0, name, value)
+ self.jar[a][name] = b
def handle_request(self, f):
l = []
if f.match(self.flt):
- for i in self.jar.keys():
+ for domain, port, path in self.jar.keys():
match = [
- self.domain_match(f.request.host, i[0]),
- f.request.port == i[1],
- f.request.path.startswith(i[2])
+ self.domain_match(f.request.host, domain),
+ f.request.port == port,
+ f.request.path.startswith(path)
]
if all(match):
- c = self.jar[i]
- l.extend([cookies.format_cookie_header(c[name].lst) for name in c.keys()])
+ c = self.jar[(domain, port, path)]
+ l.extend([cookies.format_cookie_header(c[name].items(multi=True)) for name in c.keys()])
if l:
f.request.stickycookie = True
f.request.headers["cookie"] = "; ".join(l)
diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py
index 917080f7..9fafa28f 100644
--- a/netlib/http/__init__.py
+++ b/netlib/http/__init__.py
@@ -2,13 +2,13 @@ from __future__ import absolute_import, print_function, division
from .request import Request
from .response import Response
from .headers import Headers
-from .message import decoded
-from . import http1, http2
+from .message import MultiDictView, decoded
+from . import http1, http2, status_codes
__all__ = [
"Request",
"Response",
"Headers",
- "decoded",
- "http1", "http2",
+ "MultiDictView", "decoded",
+ "http1", "http2", "status_codes",
]
diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py
index fd531146..c5ac4591 100644
--- a/netlib/http/cookies.py
+++ b/netlib/http/cookies.py
@@ -1,6 +1,8 @@
+import collections
import re
from email.utils import parsedate_tz, formatdate, mktime_tz
+from netlib.multidict import ImmutableMultiDict
from .. import odict
"""
@@ -155,25 +157,52 @@ def _parse_set_cookie_pairs(s):
return pairs
+def parse_set_cookie_headers(headers):
+ ret = []
+ for header in headers:
+ v = parse_set_cookie_header(header)
+ if v:
+ name, value, attrs = v
+ ret.append((name, SetCookie(value, attrs)))
+ return ret
+
+
+class CookieAttrs(ImmutableMultiDict):
+ @staticmethod
+ def _kconv(v):
+ return v.lower()
+
+ @staticmethod
+ def _reduce_values(values):
+ # See the StickyCookieTest for a weird cookie that only makes sense
+ # if we take the last part.
+ return values[-1]
+
+
+SetCookie = collections.namedtuple("SetCookie", ["value", "attrs"])
+
+
def parse_set_cookie_header(line):
"""
Parse a Set-Cookie header value
Returns a (name, value, attrs) tuple, or None, where attrs is an
- ODictCaseless set of attributes. No attempt is made to parse attribute
+ CookieAttrs dict of attributes. No attempt is made to parse attribute
values - they are treated purely as strings.
"""
pairs = _parse_set_cookie_pairs(line)
if pairs:
- return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:])
+ return pairs[0][0], pairs[0][1], CookieAttrs(tuple(x) for x in pairs[1:])
def format_set_cookie_header(name, value, attrs):
"""
Formats a Set-Cookie header value.
"""
- pairs = [[name, value]]
- pairs.extend(attrs.lst)
+ pairs = [(name, value)]
+ pairs.extend(
+ attrs.fields if hasattr(attrs, "fields") else attrs
+ )
return _format_set_cookie_pairs(pairs)
@@ -214,10 +243,10 @@ def refresh_set_cookie_header(c, delta):
raise ValueError("Invalid Cookie")
if "expires" in attrs:
- e = parsedate_tz(attrs["expires"][-1])
+ e = parsedate_tz(attrs["expires"])
if e:
f = mktime_tz(e) + delta
- attrs["expires"] = [formatdate(f)]
+ attrs = attrs.with_set_all("expires", [formatdate(f)])
else:
# This can happen when the expires tag is invalid.
# reddit.com sends a an expires tag like this: "Thu, 31 Dec
@@ -225,7 +254,7 @@ def refresh_set_cookie_header(c, delta):
# strictly correct according to the cookie spec. Browsers
# appear to parse this tolerantly - maybe we should too.
# For now, we just ignore this.
- del attrs["expires"]
+ attrs = attrs.with_delitem("expires")
ret = format_set_cookie_header(name, value, attrs)
if not ret:
diff --git a/netlib/http/headers.py b/netlib/http/headers.py
index 7e39c371..8959394c 100644
--- a/netlib/http/headers.py
+++ b/netlib/http/headers.py
@@ -83,6 +83,10 @@ class Headers(MultiDict):
"""
super(Headers, self).__init__(fields)
+ for key, value in self.fields:
+ if not isinstance(key, bytes) or not isinstance(value, bytes):
+ raise TypeError("Header fields must be bytes.")
+
# content_type -> content-type
headers = {
_always_bytes(name).replace(b"_", b"-"): _always_bytes(value)
diff --git a/netlib/http/message.py b/netlib/http/message.py
index 262ef3e1..3c731ea6 100644
--- a/netlib/http/message.py
+++ b/netlib/http/message.py
@@ -238,9 +238,44 @@ class decoded(object):
self.message.encode(self.ce)
-class MessageMultiDict(MultiDict):
+class MultiDictView(MultiDict):
"""
- A MultiDict that provides a proxy view to the underlying message.
+ Some parts in HTTP (Cookies, URL query strings, ...) require a specific data structure: A MultiDict.
+ It behaves mostly like an ordered dict but it can have several values for the same key.
+
+ The MultiDictView provides a MultiDict *view* on an :py:class:`Request` or :py:class:`Response`.
+ That is, it represents a part of the request as a MultiDict, but doesn't contain state/data themselves.
+
+ For example, ``request.cookies`` provides a view on the ``Cookie: ...`` header.
+ Any change to ``request.cookies`` will also modify the ``Cookie`` header.
+ Any change to the ``Cookie`` header will also modify ``request.cookies``.
+
+ Example:
+
+ .. code-block:: python
+
+ # Cookies are represented as a MultiDict.
+ >>> request.cookies
+ MultiDictView[("name", "value"), ("a", "false"), ("a", "42")]
+
+ # MultiDicts mostly behave like a normal dict.
+ >>> request.cookies["name"]
+ "value"
+
+ # If there is more than one value, only the first value is returned.
+ >>> request.cookies["a"]
+ "false"
+
+ # `.get_all(key)` returns a list of all values.
+ >>> request.cookies.get_all("a")
+ ["false", "42"]
+
+ # Changes to the headers are immediately reflected in the cookies.
+ >>> request.cookies
+ MultiDictView[("name", "value"), ...]
+ >>> del request.headers["Cookie"]
+ >>> request.cookies
+ MultiDictView[] # empty now
"""
def __init__(self, attr, message):
@@ -248,7 +283,7 @@ class MessageMultiDict(MultiDict):
# We do not want to call the parent constructor here as that
# would cause an unnecessary parse/unparse pass.
# This is here to silence linters. Message
- super(MessageMultiDict, self).__init__(None)
+ super(MultiDictView, self).__init__(None)
self._attr = attr
self._message = message # type: Message
diff --git a/netlib/http/request.py b/netlib/http/request.py
index 26ec12cf..ae28084b 100644
--- a/netlib/http/request.py
+++ b/netlib/http/request.py
@@ -11,7 +11,7 @@ from netlib.http import cookies
from netlib.odict import ODict
from .. import encoding
from .headers import Headers
-from .message import Message, _native, _always_bytes, MessageData, MessageMultiDict
+from .message import Message, _native, _always_bytes, MessageData, MultiDictView
# This regex extracts & splits the host header into host and port.
# Handles the edge case of IPv6 addresses containing colons.
@@ -224,11 +224,11 @@ class Request(Message):
@property
def query(self):
- # type: () -> MessageMultiDict
+ # type: () -> MultiDictView
"""
- The request query string as an :py:class:`MessageMultiDict` object.
+ The request query string as an :py:class:`MultiDictView` object.
"""
- return MessageMultiDict("query", self)
+ return MultiDictView("query", self)
@property
def _query(self):
@@ -244,13 +244,13 @@ class Request(Message):
@property
def cookies(self):
- # type: () -> MessageMultiDict
+ # type: () -> MultiDictView
"""
The request cookies.
- An empty :py:class:`MessageMultiDict` object if the cookie monster ate them all.
+ An empty :py:class:`MultiDictView` object if the cookie monster ate them all.
"""
- return MessageMultiDict("cookies", self)
+ return MultiDictView("cookies", self)
@property
def _cookies(self):
@@ -318,17 +318,18 @@ class Request(Message):
@property
def urlencoded_form(self):
"""
- The URL-encoded form data as an :py:class:`MessageMultiDict` object.
- None if the content-type indicates non-form data.
+ The URL-encoded form data as an :py:class:`MultiDictView` object.
+ An empty MultiDictView if the content-type indicates non-form data
+ or the content could not be parsed.
"""
- is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower()
- if is_valid_content_type:
- return MessageMultiDict("urlencoded_form", self)
- return None
+ return MultiDictView("urlencoded_form", self)
@property
def _urlencoded_form(self):
- return tuple(utils.urldecode(self.content))
+ is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower()
+ if is_valid_content_type:
+ return tuple(utils.urldecode(self.content))
+ return ()
@urlencoded_form.setter
def urlencoded_form(self, value):
@@ -345,45 +346,15 @@ class Request(Message):
The multipart form data as an :py:class:`MultipartFormDict` object.
None if the content-type indicates non-form data.
"""
- is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower()
- if is_valid_content_type:
- return MessageMultiDict("multipart_form", self)
- return None
+ return MultiDictView("multipart_form", self)
@property
def _multipart_form(self):
- return utils.multipartdecode(self.headers, self.content)
+ is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower()
+ if is_valid_content_type:
+ return utils.multipartdecode(self.headers, self.content)
+ return ()
@multipart_form.setter
def multipart_form(self, value):
raise NotImplementedError()
-
- # Legacy
-
- def get_query(self): # pragma: no cover
- warnings.warn(".get_query is deprecated, use .query instead.", DeprecationWarning)
- return self.query or ODict([])
-
- def set_query(self, odict): # pragma: no cover
- warnings.warn(".set_query is deprecated, use .query instead.", DeprecationWarning)
- self.query = odict
-
- def get_path_components(self): # pragma: no cover
- warnings.warn(".get_path_components is deprecated, use .path_components instead.", DeprecationWarning)
- return self.path_components
-
- def set_path_components(self, lst): # pragma: no cover
- warnings.warn(".set_path_components is deprecated, use .path_components instead.", DeprecationWarning)
- self.path_components = lst
-
- def get_form_urlencoded(self): # pragma: no cover
- warnings.warn(".get_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning)
- return self.urlencoded_form or ODict([])
-
- def set_form_urlencoded(self, odict): # pragma: no cover
- warnings.warn(".set_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning)
- self.urlencoded_form = odict
-
- def get_form_multipart(self): # pragma: no cover
- warnings.warn(".get_form_multipart is deprecated, use .multipart_form instead.", DeprecationWarning)
- return self.multipart_form or ODict([])
diff --git a/netlib/http/response.py b/netlib/http/response.py
index 20074dca..6d56fc1f 100644
--- a/netlib/http/response.py
+++ b/netlib/http/response.py
@@ -1,14 +1,12 @@
from __future__ import absolute_import, print_function, division
-import warnings
from email.utils import parsedate_tz, formatdate, mktime_tz
import time
from . import cookies
from .headers import Headers
-from .message import Message, _native, _always_bytes, MessageData
+from .message import Message, _native, _always_bytes, MessageData, MultiDictView
from .. import utils
-from ..odict import ODict
class ResponseData(MessageData):
@@ -70,33 +68,32 @@ class Response(Message):
def reason(self, reason):
self.data.reason = _always_bytes(reason)
- # FIXME
@property
def cookies(self):
+ # type: () -> MultiDictView
"""
- Get the contents of all Set-Cookie headers.
+ The response cookies. A possibly empty :py:class:`MultiDictView`, where the keys are
+ cookie name strings, and values are (value, attr) tuples. Value is a string, and attr is
+ an ODictCaseless containing cookie attributes. Within attrs, unary attributes (e.g. HTTPOnly)
+ are indicated by a Null value.
- A possibly empty :py:class:`ODict`, where keys are cookie name strings,
- and values are [value, attr] lists. Value is a string, and attr is
- an ODictCaseless containing cookie attributes. Within attrs, unary
- attributes (e.g. HTTPOnly) are indicated by a Null value.
+ Caveats:
+ Updating the attr
"""
- ret = []
- for header in self.headers.get_all("set-cookie"):
- v = cookies.parse_set_cookie_header(header)
- if v:
- name, value, attrs = v
- ret.append([name, [value, attrs]])
- return ODict(ret)
-
- # FIXME
+ return MultiDictView("cookies", self)
+
+ @property
+ def _cookies(self):
+ h = self.headers.get_all("set-cookie")
+ return tuple(cookies.parse_set_cookie_headers(h))
+
@cookies.setter
- def cookies(self, odict):
- values = []
- for i in odict.lst:
- header = cookies.format_set_cookie_header(i[0], i[1][0], i[1][1])
- values.append(header)
- self.headers.set_all("set-cookie", values)
+ def cookies(self, all_cookies):
+ cookie_headers = []
+ for k, v in all_cookies:
+ header = cookies.format_set_cookie_header(k, v[0], v[1])
+ cookie_headers.append(header)
+ self.headers.set_all("set-cookie", cookie_headers)
def refresh(self, now=None):
"""
diff --git a/netlib/multidict.py b/netlib/multidict.py
index a7158bc5..32d5bfc2 100644
--- a/netlib/multidict.py
+++ b/netlib/multidict.py
@@ -1,163 +1,240 @@
-from __future__ import absolute_import, print_function, division
-
-from abc import ABCMeta, abstractmethod
-
-from typing import Tuple
-
-try:
- from collections.abc import MutableMapping
-except ImportError: # pragma: no cover
- from collections import MutableMapping # Workaround for Python < 3.3
-
-import six
-
-from .utils import Serializable
-
-
-@six.add_metaclass(ABCMeta)
-class MultiDict(MutableMapping, Serializable):
- def __init__(self, fields=None):
-
- # it is important for us that .fields is immutable, so that we can easily
- # detect changes to it.
- self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...]
-
- for key, value in self.fields:
- if not isinstance(key, bytes) or not isinstance(value, bytes):
- raise TypeError("MultiDict fields must be bytes.")
-
- def __repr__(self):
- fields = tuple(
- repr(field)
- for field in self.fields
- )
- return "{cls}[{fields}]".format(
- cls=type(self).__name__,
- fields=", ".join(fields)
- )
-
- @staticmethod
- @abstractmethod
- def _reduce_values(values):
- pass
-
- @staticmethod
- @abstractmethod
- def _kconv(v):
- pass
-
- def __getitem__(self, key):
- values = self.get_all(key)
- if not values:
- raise KeyError(key)
- return self._reduce_values(values)
-
- def __setitem__(self, key, value):
- self.set_all(key, [value])
-
- def __delitem__(self, key):
- if key not in self:
- raise KeyError(key)
- key = self._kconv(key)
- self.fields = tuple(
- field for field in self.fields
- if key != self._kconv(field[0])
- )
-
- def __iter__(self):
- seen = set()
- for key, _ in self.fields:
- key_kconv = self._kconv(key)
- if key_kconv not in seen:
- seen.add(key_kconv)
- yield key
-
- def __len__(self):
- return len(set(self._kconv(key) for key, _ in self.fields))
-
- def __eq__(self, other):
- if isinstance(other, MultiDict):
- return self.fields == other.fields
- return False
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def get_all(self, key):
- """
- Return the list of items for a given key.
- If that key is not in the MultiDict,
- the return value will be an empty list.
- """
- key = self._kconv(key)
- return [
- value
- for k, value in self.fields
- if self._kconv(k) == key
- ]
-
- def set_all(self, key, values):
- """
- Remove the old values for a key and add new ones.
- """
- key_kconv = self._kconv(key)
-
- new_fields = []
- for field in self.fields:
- if self._kconv(field[0]) == key_kconv:
- if values:
- new_fields.append(
- (key, values.pop(0))
- )
- else:
- new_fields.append(field)
- while values:
- new_fields.append(
- (key, values.pop(0))
- )
- self.fields = tuple(new_fields)
-
- def add(self, key, value):
- self.insert(len(self.fields), key, value)
-
- def insert(self, index, key, value):
- item = (key, value)
- self.fields = self.fields[:index] + (item,) + self.fields[index:]
-
- def keys(self, multi=False):
- return (
- k
- for k, _ in self.items(multi)
- )
-
- def values(self, multi=False):
- return (
- v
- for _, v in self.items(multi)
- )
-
- def items(self, multi=False):
- if multi:
- return self.fields
- else:
- return super(MultiDict, self).items()
-
- def to_dict(self):
- d = {}
- for key in self:
- values = self.get_all(key)
- if len(values) == 1:
- d[key] = values[0]
- else:
- d[key] = values
- return d
-
- def get_state(self):
- return self.fields
-
- def set_state(self, state):
- self.fields = tuple(tuple(x) for x in state)
-
- @classmethod
- def from_state(cls, state):
- return cls(tuple(x) for x in state)
+from __future__ import absolute_import, print_function, division
+
+from abc import ABCMeta, abstractmethod
+
+from typing import Tuple, TypeVar
+
+try:
+ from collections.abc import MutableMapping
+except ImportError: # pragma: no cover
+ from collections import MutableMapping # Workaround for Python < 3.3
+
+import six
+
+from .utils import Serializable
+
+
+@six.add_metaclass(ABCMeta)
+class MultiDict(MutableMapping, Serializable):
+ def __init__(self, fields=None):
+
+ # it is important for us that .fields is immutable, so that we can easily
+ # detect changes to it.
+ self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...]
+
+ def __repr__(self):
+ fields = tuple(
+ repr(field)
+ for field in self.fields
+ )
+ return "{cls}[{fields}]".format(
+ cls=type(self).__name__,
+ fields=", ".join(fields)
+ )
+
+ @staticmethod
+ @abstractmethod
+ def _reduce_values(values):
+ pass
+
+ @staticmethod
+ @abstractmethod
+ def _kconv(v):
+ pass
+
+ def __getitem__(self, key):
+ values = self.get_all(key)
+ if not values:
+ raise KeyError(key)
+ return self._reduce_values(values)
+
+ def __setitem__(self, key, value):
+ self.set_all(key, [value])
+
+ def __delitem__(self, key):
+ if key not in self:
+ raise KeyError(key)
+ key = self._kconv(key)
+ self.fields = tuple(
+ field for field in self.fields
+ if key != self._kconv(field[0])
+ )
+
+ def __iter__(self):
+ seen = set()
+ for key, _ in self.fields:
+ key_kconv = self._kconv(key)
+ if key_kconv not in seen:
+ seen.add(key_kconv)
+ yield key
+
+ def __len__(self):
+ return len(set(self._kconv(key) for key, _ in self.fields))
+
+ def __eq__(self, other):
+ if isinstance(other, MultiDict):
+ return self.fields == other.fields
+ return False
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def get_all(self, key):
+ """
+ Return the list of all values for a given key.
+ If that key is not in the MultiDict, the return value will be an empty list.
+ """
+ key = self._kconv(key)
+ return [
+ value
+ for k, value in self.fields
+ if self._kconv(k) == key
+ ]
+
+ def set_all(self, key, values):
+ """
+ Remove the old values for a key and add new ones.
+ """
+ key_kconv = self._kconv(key)
+
+ new_fields = []
+ for field in self.fields:
+ if self._kconv(field[0]) == key_kconv:
+ if values:
+ new_fields.append(
+ (key, values.pop(0))
+ )
+ else:
+ new_fields.append(field)
+ while values:
+ new_fields.append(
+ (key, values.pop(0))
+ )
+ self.fields = tuple(new_fields)
+
+ def add(self, key, value):
+ """
+ Add an additional value for the given key at the bottom.
+ """
+ self.insert(len(self.fields), key, value)
+
+ def insert(self, index, key, value):
+ """
+ Insert an additional value for the given key at the specified position.
+ """
+ item = (key, value)
+ self.fields = self.fields[:index] + (item,) + self.fields[index:]
+
+ def keys(self, multi=False):
+ """
+ Get all keys.
+
+ Args:
+ multi(bool):
+ If True, one key per value will be returned.
+ If False, duplicate keys will only be returned once.
+ """
+ return (
+ k
+ for k, _ in self.items(multi)
+ )
+
+ def values(self, multi=False):
+ """
+ Get all values.
+
+ Args:
+ multi(bool):
+ If True, all values will be returned.
+ If False, only the first value per key will be returned.
+ """
+ return (
+ v
+ for _, v in self.items(multi)
+ )
+
+ def items(self, multi=False):
+ """
+ Get all (key, value) tuples.
+
+ Args:
+ multi(bool):
+ If True, all (key, value) pairs will be returned
+ If False, only the first (key, value) pair per unique key will be returned.
+ """
+ if multi:
+ return self.fields
+ else:
+ return super(MultiDict, self).items()
+
+ def to_dict(self):
+ """
+ Get the MultiDict as a plain Python dict.
+ Keys with multiple values are returned as lists.
+
+ Example:
+
+ .. code-block:: python
+
+ # Simple dict with duplicate values.
+ >>> d
+ MultiDictView[("name", "value"), ("a", "false"), ("a", "42")]
+ >>> d.to_dict()
+ {
+ "name": "value",
+ "a": ["false", "42"]
+ }
+ """
+ d = {}
+ for key in self:
+ values = self.get_all(key)
+ if len(values) == 1:
+ d[key] = values[0]
+ else:
+ d[key] = values
+ return d
+
+ def get_state(self):
+ return self.fields
+
+ def set_state(self, state):
+ self.fields = tuple(tuple(x) for x in state)
+
+ @classmethod
+ def from_state(cls, state):
+ return cls(tuple(x) for x in state)
+
+
+@six.add_metaclass(ABCMeta)
+class ImmutableMultiDict(MultiDict):
+ def _immutable(self, *_):
+ raise TypeError('{} objects are immutable'.format(self.__class__.__name__))
+
+ __delitem__ = set_all = insert = _immutable
+
+ def with_delitem(self, key):
+ """
+ Returns:
+ An updated ImmutableMultiDict. The original object will not be modified.
+ """
+ ret = self.copy()
+ super(ImmutableMultiDict, ret).__delitem__(key)
+ return ret
+
+ def with_set_all(self, key, values):
+ """
+ Returns:
+ An updated ImmutableMultiDict. The original object will not be modified.
+ """
+ ret = self.copy()
+ super(ImmutableMultiDict, ret).set_all(key, values)
+ return ret
+
+ def with_insert(self, index, key, value):
+ """
+ Returns:
+ An updated ImmutableMultiDict. The original object will not be modified.
+ """
+ ret = self.copy()
+ super(ImmutableMultiDict, ret).insert(index, key, value)
+ return ret
diff --git a/test/mitmproxy/test_examples.py b/test/mitmproxy/test_examples.py
index d0a258e9..ac79b093 100644
--- a/test/mitmproxy/test_examples.py
+++ b/test/mitmproxy/test_examples.py
@@ -98,7 +98,7 @@ def test_modify_form():
flow.request.headers["content-type"] = ""
ex.run("request", flow)
- assert list(flow.request.urlencoded_form.items()) == [("foo","bar")]
+ assert list(flow.request.urlencoded_form.items()) == [("foo", "bar")]
def test_modify_querystring():
diff --git a/test/netlib/http/test_cookies.py b/test/netlib/http/test_cookies.py
index e2cee57f..6f84c4ce 100644
--- a/test/netlib/http/test_cookies.py
+++ b/test/netlib/http/test_cookies.py
@@ -197,24 +197,28 @@ def test_parse_set_cookie_header():
],
[
"one=uno",
- ("one", "uno", [])
+ ("one", "uno", ())
],
[
"one=uno; foo=bar",
- ("one", "uno", [["foo", "bar"]])
- ]
+ ("one", "uno", (("foo", "bar"),))
+ ],
+ [
+ "one=uno; foo=bar; foo=baz",
+ ("one", "uno", (("foo", "bar"), ("foo", "baz")))
+ ],
]
for s, expected in vals:
ret = cookies.parse_set_cookie_header(s)
if expected:
assert ret[0] == expected[0]
assert ret[1] == expected[1]
- assert ret[2].lst == expected[2]
+ assert ret[2].items(multi=True) == expected[2]
s2 = cookies.format_set_cookie_header(*ret)
ret2 = cookies.parse_set_cookie_header(s2)
assert ret2[0] == expected[0]
assert ret2[1] == expected[1]
- assert ret2[2].lst == expected[2]
+ assert ret2[2].items(multi=True) == expected[2]
else:
assert ret is None
diff --git a/test/netlib/http/test_request.py b/test/netlib/http/test_request.py
index 26593ee1..eefdc091 100644
--- a/test/netlib/http/test_request.py
+++ b/test/netlib/http/test_request.py
@@ -251,7 +251,7 @@ class TestRequestUtils(object):
def test_get_urlencoded_form(self):
request = treq(content="foobar=baz")
- assert request.urlencoded_form is None
+ assert not request.urlencoded_form
request.headers["Content-Type"] = "application/x-www-form-urlencoded"
assert list(request.urlencoded_form.items()) == [("foobar", "baz")]
@@ -264,7 +264,7 @@ class TestRequestUtils(object):
def test_get_multipart_form(self):
request = treq(content="foobar")
- assert request.multipart_form is None
+ assert not request.multipart_form
request.headers["Content-Type"] = "multipart/form-data"
assert list(request.multipart_form.items()) == []
diff --git a/test/netlib/http/test_response.py b/test/netlib/http/test_response.py
index 37273541..cfd093d4 100644
--- a/test/netlib/http/test_response.py
+++ b/test/netlib/http/test_response.py
@@ -6,6 +6,7 @@ import six
import time
from netlib.http import Headers
+from netlib.http.cookies import CookieAttrs
from netlib.odict import ODict, ODictCaseless
from netlib.tutils import raises, tresp
from .test_message import _test_passthrough_attr, _test_decoded_attr
@@ -56,7 +57,7 @@ class TestResponseUtils(object):
result = resp.cookies
assert len(result) == 1
assert "cookiename" in result
- assert result["cookiename"][0] == ["cookievalue", ODict()]
+ assert result["cookiename"] == ("cookievalue", CookieAttrs())
def test_get_cookies_with_parameters(self):
resp = tresp()
@@ -64,13 +65,13 @@ class TestResponseUtils(object):
result = resp.cookies
assert len(result) == 1
assert "cookiename" in result
- assert result["cookiename"][0][0] == "cookievalue"
- attrs = result["cookiename"][0][1]
+ assert result["cookiename"][0] == "cookievalue"
+ attrs = result["cookiename"][1]
assert len(attrs) == 4
- assert attrs["domain"] == ["example.com"]
- assert attrs["expires"] == ["Wed Oct 21 16:29:41 2015"]
- assert attrs["path"] == ["/"]
- assert attrs["httponly"] == [None]
+ assert attrs["domain"] == "example.com"
+ assert attrs["expires"] == "Wed Oct 21 16:29:41 2015"
+ assert attrs["path"] == "/"
+ assert attrs["httponly"] is None
def test_get_cookies_no_value(self):
resp = tresp()
@@ -78,8 +79,8 @@ class TestResponseUtils(object):
result = resp.cookies
assert len(result) == 1
assert "cookiename" in result
- assert result["cookiename"][0][0] == ""
- assert len(result["cookiename"][0][1]) == 2
+ assert result["cookiename"][0] == ""
+ assert len(result["cookiename"][1]) == 2
def test_get_cookies_twocookies(self):
resp = tresp()
@@ -90,19 +91,16 @@ class TestResponseUtils(object):
result = resp.cookies
assert len(result) == 2
assert "cookiename" in result
- assert result["cookiename"][0] == ["cookievalue", ODict()]
+ assert result["cookiename"] == ("cookievalue", CookieAttrs())
assert "othercookie" in result
- assert result["othercookie"][0] == ["othervalue", ODict()]
+ assert result["othercookie"] == ("othervalue", CookieAttrs())
def test_set_cookies(self):
resp = tresp()
- v = resp.cookies
- v.add("foo", ["bar", ODictCaseless()])
- resp.cookies = v
+ resp.cookies["foo"] = ("bar", {})
- v = resp.cookies
- assert len(v) == 1
- assert v["foo"] == [["bar", ODictCaseless()]]
+ assert len(resp.cookies) == 1
+ assert resp.cookies["foo"] == ("bar", CookieAttrs())
def test_refresh(self):
r = tresp()