about | summary | refs | log | tree | commit | diff | stats
path: root/netlib
diff options
context:
space:
mode:
author Maximilian Hils <git@maximilianhils.com> 2016-07-25 15:16:16 -0700
committer Maximilian Hils <git@maximilianhils.com> 2016-07-25 15:16:16 -0700
commit 79ebcb046e8669f80357a6c3046ec76c6adf49be (patch)
tree 441981a16f1be1e620584e4a47f41767ce5585b2 /netlib
parent 3254595584e1d711e7ae292ad34753a52f7a0fc1 (diff)
parent 56796aeda25dda66621ce78af227ff46049ef811 (diff)
download mitmproxy-79ebcb046e8669f80357a6c3046ec76c6adf49be.tar.gz
mitmproxy-79ebcb046e8669f80357a6c3046ec76c6adf49be.tar.bz2
mitmproxy-79ebcb046e8669f80357a6c3046ec76c6adf49be.zip
Merge remote-tracking branch 'origin/master' into flow_editing_v2
Diffstat (limited to 'netlib')
-rw-r--r-- netlib/encoding.py 49
-rw-r--r-- netlib/http/message.py 121
-rw-r--r-- netlib/http/request.py 26
-rw-r--r-- netlib/http/url.py 41
-rw-r--r-- netlib/multidict.py 6
-rw-r--r-- netlib/strutils.py 6
6 files changed, 128 insertions, 121 deletions
diff --git a/netlib/encoding.py b/netlib/encoding.py
index e3cf5f30..da282194 100644
--- a/netlib/encoding.py
+++ b/netlib/encoding.py
@@ -4,6 +4,7 @@ Utility functions for decoding response bodies.
from __future__ import absolute_import
import codecs
+import collections
from io import BytesIO
import gzip
import zlib
@@ -11,7 +12,15 @@ import zlib
from typing import Union # noqa
-def decode(obj, encoding, errors='strict'):
+# We have a shared single-element cache for encoding and decoding.
+# This is quite useful in practice, e.g.
+# flow.request.content = flow.request.content.replace(b"foo", b"bar")
+# does not require an .encode() call if content does not contain b"foo"
+CachedDecode = collections.namedtuple("CachedDecode", "encoded encoding errors decoded")
+_cache = CachedDecode(None, None, None, None)
+
+
+def decode(encoded, encoding, errors='strict'):
# type: (Union[str, bytes], str, str) -> Union[str, bytes]
"""
Decode the given input object
@@ -22,20 +31,32 @@ def decode(obj, encoding, errors='strict'):
Raises:
ValueError, if decoding fails.
"""
+ global _cache
+ cached = (
+ isinstance(encoded, bytes) and
+ _cache.encoded == encoded and
+ _cache.encoding == encoding and
+ _cache.errors == errors
+ )
+ if cached:
+ return _cache.decoded
try:
try:
- return custom_decode[encoding](obj)
+ decoded = custom_decode[encoding](encoded)
except KeyError:
- return codecs.decode(obj, encoding, errors)
+ decoded = codecs.decode(encoded, encoding, errors)
+ if encoding in ("gzip", "deflate"):
+ _cache = CachedDecode(encoded, encoding, errors, decoded)
+ return decoded
except Exception as e:
raise ValueError("{} when decoding {} with {}".format(
type(e).__name__,
- repr(obj)[:10],
+ repr(encoded)[:10],
repr(encoding),
))
-def encode(obj, encoding, errors='strict'):
+def encode(decoded, encoding, errors='strict'):
# type: (Union[str, bytes], str, str) -> Union[str, bytes]
"""
Encode the given input object
@@ -46,15 +67,27 @@ def encode(obj, encoding, errors='strict'):
Raises:
ValueError, if encoding fails.
"""
+ global _cache
+ cached = (
+ isinstance(decoded, bytes) and
+ _cache.decoded == decoded and
+ _cache.encoding == encoding and
+ _cache.errors == errors
+ )
+ if cached:
+ return _cache.encoded
try:
try:
- return custom_encode[encoding](obj)
+ encoded = custom_encode[encoding](decoded)
except KeyError:
- return codecs.encode(obj, encoding, errors)
+ encoded = codecs.encode(decoded, encoding, errors)
+ if encoding in ("gzip", "deflate"):
+ _cache = CachedDecode(encoded, encoding, errors, decoded)
+ return encoded
except Exception as e:
raise ValueError("{} when encoding {} with {}".format(
type(e).__name__,
- repr(obj)[:10],
+ repr(decoded)[:10],
repr(encoding),
))
diff --git a/netlib/http/message.py b/netlib/http/message.py
index 34709f0a..be35b8d1 100644
--- a/netlib/http/message.py
+++ b/netlib/http/message.py
@@ -32,9 +32,6 @@ class MessageData(basetypes.Serializable):
def __ne__(self, other):
return not self.__eq__(other)
- def __hash__(self):
- return hash(frozenset(self.__dict__.items()))
-
def set_state(self, state):
for k, v in state.items():
if k == "headers":
@@ -52,23 +49,7 @@ class MessageData(basetypes.Serializable):
return cls(**state)
-class CachedDecode(object):
- __slots__ = ["encoded", "encoding", "strict", "decoded"]
-
- def __init__(self, object, encoding, strict, decoded):
- self.encoded = object
- self.encoding = encoding
- self.strict = strict
- self.decoded = decoded
-
-no_cached_decode = CachedDecode(None, None, None, None)
-
-
class Message(basetypes.Serializable):
- def __init__(self):
- self._content_cache = no_cached_decode # type: CachedDecode
- self._text_cache = no_cached_decode # type: CachedDecode
-
def __eq__(self, other):
if isinstance(other, Message):
return self.data == other.data
@@ -77,9 +58,6 @@ class Message(basetypes.Serializable):
def __ne__(self, other):
return not self.__eq__(other)
- def __hash__(self):
- return hash(self.data) ^ 1
-
def get_state(self):
return self.data.get_state()
@@ -132,25 +110,15 @@ class Message(basetypes.Serializable):
if self.raw_content is None:
return None
ce = self.headers.get("content-encoding")
- cached = (
- self._content_cache.encoded == self.raw_content and
- (self._content_cache.strict or not strict) and
- self._content_cache.encoding == ce
- )
- if not cached:
- is_strict = True
- if ce:
- try:
- decoded = encoding.decode(self.raw_content, ce)
- except ValueError:
- if strict:
- raise
- is_strict = False
- decoded = self.raw_content
- else:
- decoded = self.raw_content
- self._content_cache = CachedDecode(self.raw_content, ce, is_strict, decoded)
- return self._content_cache.decoded
+ if ce:
+ try:
+ return encoding.decode(self.raw_content, ce)
+ except ValueError:
+ if strict:
+ raise
+ return self.raw_content
+ else:
+ return self.raw_content
def set_content(self, value):
if value is None:
@@ -163,22 +131,13 @@ class Message(basetypes.Serializable):
.format(type(value).__name__)
)
ce = self.headers.get("content-encoding")
- cached = (
- self._content_cache.decoded == value and
- self._content_cache.encoding == ce and
- self._content_cache.strict
- )
- if not cached:
- try:
- encoded = encoding.encode(value, ce or "identity")
- except ValueError:
- # So we have an invalid content-encoding?
- # Let's remove it!
- del self.headers["content-encoding"]
- ce = None
- encoded = value
- self._content_cache = CachedDecode(encoded, ce, True, value)
- self.raw_content = self._content_cache.encoded
+ try:
+ self.raw_content = encoding.encode(value, ce or "identity")
+ except ValueError:
+ # So we have an invalid content-encoding?
+ # Let's remove it!
+ del self.headers["content-encoding"]
+ self.raw_content = value
self.headers["content-length"] = str(len(self.raw_content))
content = property(get_content, set_content)
@@ -250,22 +209,12 @@ class Message(basetypes.Serializable):
enc = self._guess_encoding()
content = self.get_content(strict)
- cached = (
- self._text_cache.encoded == content and
- (self._text_cache.strict or not strict) and
- self._text_cache.encoding == enc
- )
- if not cached:
- is_strict = self._content_cache.strict
- try:
- decoded = encoding.decode(content, enc)
- except ValueError:
- if strict:
- raise
- is_strict = False
- decoded = self.content.decode("utf8", "replace" if six.PY2 else "surrogateescape")
- self._text_cache = CachedDecode(content, enc, is_strict, decoded)
- return self._text_cache.decoded
+ try:
+ return encoding.decode(content, enc)
+ except ValueError:
+ if strict:
+ raise
+ return content.decode("utf8", "replace" if six.PY2 else "surrogateescape")
def set_text(self, text):
if text is None:
@@ -273,23 +222,15 @@ class Message(basetypes.Serializable):
return
enc = self._guess_encoding()
- cached = (
- self._text_cache.decoded == text and
- self._text_cache.encoding == enc and
- self._text_cache.strict
- )
- if not cached:
- try:
- encoded = encoding.encode(text, enc)
- except ValueError:
- # Fall back to UTF-8 and update the content-type header.
- ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {})
- ct[2]["charset"] = "utf-8"
- self.headers["content-type"] = headers.assemble_content_type(*ct)
- enc = "utf8"
- encoded = text.encode(enc, "replace" if six.PY2 else "surrogateescape")
- self._text_cache = CachedDecode(encoded, enc, True, text)
- self.content = self._text_cache.encoded
+ try:
+ self.content = encoding.encode(text, enc)
+ except ValueError:
+ # Fall back to UTF-8 and update the content-type header.
+ ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {})
+ ct[2]["charset"] = "utf-8"
+ self.headers["content-type"] = headers.assemble_content_type(*ct)
+ enc = "utf8"
+ self.content = text.encode(enc, "replace" if six.PY2 else "surrogateescape")
text = property(get_text, set_text)
diff --git a/netlib/http/request.py b/netlib/http/request.py
index ecaa9b79..061217a3 100644
--- a/netlib/http/request.py
+++ b/netlib/http/request.py
@@ -253,14 +253,13 @@ class Request(message.Message):
)
def _get_query(self):
- _, _, _, _, query, _ = urllib.parse.urlparse(self.url)
+ query = urllib.parse.urlparse(self.url).query
return tuple(netlib.http.url.decode(query))
- def _set_query(self, value):
- query = netlib.http.url.encode(value)
- scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url)
- _, _, _, self.path = netlib.http.url.parse(
- urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]))
+ def _set_query(self, query_data):
+ query = netlib.http.url.encode(query_data)
+ _, _, path, params, _, fragment = urllib.parse.urlparse(self.url)
+ self.path = urllib.parse.urlunparse(["", "", path, params, query, fragment])
@query.setter
def query(self, value):
@@ -296,19 +295,18 @@ class Request(message.Message):
The URL's path components as a tuple of strings.
Components are unquoted.
"""
- _, _, path, _, _, _ = urllib.parse.urlparse(self.url)
+ path = urllib.parse.urlparse(self.url).path
# This needs to be a tuple so that it's immutable.
# Otherwise, this would fail silently:
# request.path_components.append("foo")
- return tuple(urllib.parse.unquote(i) for i in path.split("/") if i)
+ return tuple(netlib.http.url.unquote(i) for i in path.split("/") if i)
@path_components.setter
def path_components(self, components):
- components = map(lambda x: urllib.parse.quote(x, safe=""), components)
+ components = map(lambda x: netlib.http.url.quote(x, safe=""), components)
path = "/" + "/".join(components)
- scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url)
- _, _, _, self.path = netlib.http.url.parse(
- urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]))
+ _, _, _, params, query, fragment = urllib.parse.urlparse(self.url)
+ self.path = urllib.parse.urlunparse(["", "", path, params, query, fragment])
def anticache(self):
"""
@@ -365,13 +363,13 @@ class Request(message.Message):
pass
return ()
- def _set_urlencoded_form(self, value):
+ def _set_urlencoded_form(self, form_data):
"""
Sets the body to the URL-encoded form data, and adds the appropriate content-type header.
This will overwrite the existing content if there is one.
"""
self.headers["content-type"] = "application/x-www-form-urlencoded"
- self.content = netlib.http.url.encode(value).encode()
+ self.content = netlib.http.url.encode(form_data).encode()
@urlencoded_form.setter
def urlencoded_form(self, value):
diff --git a/netlib/http/url.py b/netlib/http/url.py
index 2fc6e7ee..076854b9 100644
--- a/netlib/http/url.py
+++ b/netlib/http/url.py
@@ -82,18 +82,51 @@ def unparse(scheme, host, port, path=""):
def encode(s):
+ # type: (Sequence[Tuple[str,str]]) -> str
"""
Takes a list of (key, value) tuples and returns a urlencoded string.
"""
- s = [tuple(i) for i in s]
- return urllib.parse.urlencode(s, False)
+ if six.PY2:
+ return urllib.parse.urlencode(s, False)
+ else:
+ return urllib.parse.urlencode(s, False, errors="surrogateescape")
def decode(s):
"""
- Takes a urlencoded string and returns a list of (key, value) tuples.
+ Takes a urlencoded string and returns a list of surrogate-escaped (key, value) tuples.
+ """
+ if six.PY2:
+ return urllib.parse.parse_qsl(s, keep_blank_values=True)
+ else:
+ return urllib.parse.parse_qsl(s, keep_blank_values=True, errors='surrogateescape')
+
+
+def quote(b, safe="/"):
+ """
+ Returns:
+ An ascii-encodable str.
+ """
+ # type: (str) -> str
+ if six.PY2:
+ return urllib.parse.quote(b, safe=safe)
+ else:
+ return urllib.parse.quote(b, safe=safe, errors="surrogateescape")
+
+
+def unquote(s):
"""
- return urllib.parse.parse_qsl(s, keep_blank_values=True)
+ Args:
+ s: A surrogate-escaped str
+ Returns:
+ A surrogate-escaped str
+ """
+ # type: (str) -> str
+
+ if six.PY2:
+ return urllib.parse.unquote(s)
+ else:
+ return urllib.parse.unquote(s, errors="surrogateescape")
def hostport(scheme, host, port):
diff --git a/netlib/multidict.py b/netlib/multidict.py
index 51053ff6..e9fec155 100644
--- a/netlib/multidict.py
+++ b/netlib/multidict.py
@@ -79,9 +79,6 @@ class _MultiDict(MutableMapping, basetypes.Serializable):
def __ne__(self, other):
return not self.__eq__(other)
- def __hash__(self):
- return hash(self.fields)
-
def get_all(self, key):
"""
Return the list of all values for a given key.
@@ -241,6 +238,9 @@ class ImmutableMultiDict(MultiDict):
__delitem__ = set_all = insert = _immutable
+ def __hash__(self):
+ return hash(self.fields)
+
def with_delitem(self, key):
"""
Returns:
diff --git a/netlib/strutils.py b/netlib/strutils.py
index 32e77927..8f27ebb7 100644
--- a/netlib/strutils.py
+++ b/netlib/strutils.py
@@ -51,8 +51,7 @@ else:
def escape_control_characters(text, keep_spacing=True):
"""
- Replace all unicode C1 control characters from the given text with their respective control pictures.
- For example, a null byte is replaced with the unicode character "\u2400".
+ Replace all unicode C1 control characters from the given text with a single "."
Args:
keep_spacing: If True, tabs and newlines will not be replaced.
@@ -99,6 +98,9 @@ def bytes_to_escaped_str(data, keep_spacing=False):
def escaped_str_to_bytes(data):
"""
Take an escaped string and return the unescaped bytes equivalent.
+
+ Raises:
+ ValueError, if the escape sequence is invalid.
"""
if not isinstance(data, six.string_types):
if six.PY2: