diff options
Diffstat (limited to 'netlib')
-rw-r--r-- | netlib/encoding.py | 49 | ||||
-rw-r--r-- | netlib/http/message.py | 121 | ||||
-rw-r--r-- | netlib/http/request.py | 26 | ||||
-rw-r--r-- | netlib/http/url.py | 41 | ||||
-rw-r--r-- | netlib/multidict.py | 6 | ||||
-rw-r--r-- | netlib/strutils.py | 6 |
6 files changed, 128 insertions, 121 deletions
diff --git a/netlib/encoding.py b/netlib/encoding.py index e3cf5f30..da282194 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -4,6 +4,7 @@ Utility functions for decoding response bodies. from __future__ import absolute_import import codecs +import collections from io import BytesIO import gzip import zlib @@ -11,7 +12,15 @@ import zlib from typing import Union # noqa -def decode(obj, encoding, errors='strict'): +# We have a shared single-element cache for encoding and decoding. +# This is quite useful in practice, e.g. +# flow.request.content = flow.request.content.replace(b"foo", b"bar") +# does not require an .encode() call if content does not contain b"foo" +CachedDecode = collections.namedtuple("CachedDecode", "encoded encoding errors decoded") +_cache = CachedDecode(None, None, None, None) + + +def decode(encoded, encoding, errors='strict'): # type: (Union[str, bytes], str, str) -> Union[str, bytes] """ Decode the given input object @@ -22,20 +31,32 @@ def decode(obj, encoding, errors='strict'): Raises: ValueError, if decoding fails. """ + global _cache + cached = ( + isinstance(encoded, bytes) and + _cache.encoded == encoded and + _cache.encoding == encoding and + _cache.errors == errors + ) + if cached: + return _cache.decoded try: try: - return custom_decode[encoding](obj) + decoded = custom_decode[encoding](encoded) except KeyError: - return codecs.decode(obj, encoding, errors) + decoded = codecs.decode(encoded, encoding, errors) + if encoding in ("gzip", "deflate"): + _cache = CachedDecode(encoded, encoding, errors, decoded) + return decoded except Exception as e: raise ValueError("{} when decoding {} with {}".format( type(e).__name__, - repr(obj)[:10], + repr(encoded)[:10], repr(encoding), )) -def encode(obj, encoding, errors='strict'): +def encode(decoded, encoding, errors='strict'): # type: (Union[str, bytes], str, str) -> Union[str, bytes] """ Encode the given input object @@ -46,15 +67,27 @@ def encode(obj, encoding, errors='strict'): Raises: ValueError, if encoding fails. """ + global _cache + cached = ( + isinstance(decoded, bytes) and + _cache.decoded == decoded and + _cache.encoding == encoding and + _cache.errors == errors + ) + if cached: + return _cache.encoded try: try: - return custom_encode[encoding](obj) + encoded = custom_encode[encoding](decoded) except KeyError: - return codecs.encode(obj, encoding, errors) + encoded = codecs.encode(decoded, encoding, errors) + if encoding in ("gzip", "deflate"): + _cache = CachedDecode(encoded, encoding, errors, decoded) + return encoded except Exception as e: raise ValueError("{} when encoding {} with {}".format( type(e).__name__, - repr(obj)[:10], + repr(decoded)[:10], repr(encoding), )) diff --git a/netlib/http/message.py b/netlib/http/message.py index 34709f0a..be35b8d1 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -32,9 +32,6 @@ class MessageData(basetypes.Serializable): def __ne__(self, other): return not self.__eq__(other) - def __hash__(self): - return hash(frozenset(self.__dict__.items())) - def set_state(self, state): for k, v in state.items(): if k == "headers": @@ -52,23 +49,7 @@ class MessageData(basetypes.Serializable): return cls(**state) -class CachedDecode(object): - __slots__ = ["encoded", "encoding", "strict", "decoded"] - - def __init__(self, object, encoding, strict, decoded): - self.encoded = object - self.encoding = encoding - self.strict = strict - self.decoded = decoded - -no_cached_decode = CachedDecode(None, None, None, None) - - class Message(basetypes.Serializable): - def __init__(self): - self._content_cache = no_cached_decode # type: CachedDecode - self._text_cache = no_cached_decode # type: CachedDecode - def __eq__(self, other): if isinstance(other, Message): return self.data == other.data @@ -77,9 +58,6 @@ class Message(basetypes.Serializable): def __ne__(self, other): return not self.__eq__(other) - def __hash__(self): - return hash(self.data) ^ 1 - def get_state(self): return self.data.get_state() @@ -132,25 +110,15 @@ class Message(basetypes.Serializable): if self.raw_content is None: return None ce = self.headers.get("content-encoding") - cached = ( - self._content_cache.encoded == self.raw_content and - (self._content_cache.strict or not strict) and - self._content_cache.encoding == ce - ) - if not cached: - is_strict = True - if ce: - try: - decoded = encoding.decode(self.raw_content, ce) - except ValueError: - if strict: - raise - is_strict = False - decoded = self.raw_content - else: - decoded = self.raw_content - self._content_cache = CachedDecode(self.raw_content, ce, is_strict, decoded) - return self._content_cache.decoded + if ce: + try: + return encoding.decode(self.raw_content, ce) + except ValueError: + if strict: + raise + return self.raw_content + else: + return self.raw_content def set_content(self, value): if value is None: @@ -163,22 +131,13 @@ class Message(basetypes.Serializable): .format(type(value).__name__) ) ce = self.headers.get("content-encoding") - cached = ( - self._content_cache.decoded == value and - self._content_cache.encoding == ce and - self._content_cache.strict - ) - if not cached: - try: - encoded = encoding.encode(value, ce or "identity") - except ValueError: - # So we have an invalid content-encoding? - # Let's remove it! - del self.headers["content-encoding"] - ce = None - encoded = value - self._content_cache = CachedDecode(encoded, ce, True, value) - self.raw_content = self._content_cache.encoded + try: + self.raw_content = encoding.encode(value, ce or "identity") + except ValueError: + # So we have an invalid content-encoding? + # Let's remove it! + del self.headers["content-encoding"] + self.raw_content = value self.headers["content-length"] = str(len(self.raw_content)) content = property(get_content, set_content) @@ -250,22 +209,12 @@ class Message(basetypes.Serializable): enc = self._guess_encoding() content = self.get_content(strict) - cached = ( - self._text_cache.encoded == content and - (self._text_cache.strict or not strict) and - self._text_cache.encoding == enc - ) - if not cached: - is_strict = self._content_cache.strict - try: - decoded = encoding.decode(content, enc) - except ValueError: - if strict: - raise - is_strict = False - decoded = self.content.decode("utf8", "replace" if six.PY2 else "surrogateescape") - self._text_cache = CachedDecode(content, enc, is_strict, decoded) - return self._text_cache.decoded + try: + return encoding.decode(content, enc) + except ValueError: + if strict: + raise + return content.decode("utf8", "replace" if six.PY2 else "surrogateescape") def set_text(self, text): if text is None: @@ -273,23 +222,15 @@ class Message(basetypes.Serializable): return enc = self._guess_encoding() - cached = ( - self._text_cache.decoded == text and - self._text_cache.encoding == enc and - self._text_cache.strict - ) - if not cached: - try: - encoded = encoding.encode(text, enc) - except ValueError: - # Fall back to UTF-8 and update the content-type header. - ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {}) - ct[2]["charset"] = "utf-8" - self.headers["content-type"] = headers.assemble_content_type(*ct) - enc = "utf8" - encoded = text.encode(enc, "replace" if six.PY2 else "surrogateescape") - self._text_cache = CachedDecode(encoded, enc, True, text) - self.content = self._text_cache.encoded + try: + self.content = encoding.encode(text, enc) + except ValueError: + # Fall back to UTF-8 and update the content-type header. + ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {}) + ct[2]["charset"] = "utf-8" + self.headers["content-type"] = headers.assemble_content_type(*ct) + enc = "utf8" + self.content = text.encode(enc, "replace" if six.PY2 else "surrogateescape") text = property(get_text, set_text) diff --git a/netlib/http/request.py b/netlib/http/request.py index ecaa9b79..061217a3 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -253,14 +253,13 @@ class Request(message.Message): ) def _get_query(self): - _, _, _, _, query, _ = urllib.parse.urlparse(self.url) + query = urllib.parse.urlparse(self.url).query return tuple(netlib.http.url.decode(query)) - def _set_query(self, value): - query = netlib.http.url.encode(value) - scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) - _, _, _, self.path = netlib.http.url.parse( - urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment])) + def _set_query(self, query_data): + query = netlib.http.url.encode(query_data) + _, _, path, params, _, fragment = urllib.parse.urlparse(self.url) + self.path = urllib.parse.urlunparse(["", "", path, params, query, fragment]) @query.setter def query(self, value): @@ -296,19 +295,18 @@ class Request(message.Message): The URL's path components as a tuple of strings. Components are unquoted. """ - _, _, path, _, _, _ = urllib.parse.urlparse(self.url) + path = urllib.parse.urlparse(self.url).path # This needs to be a tuple so that it's immutable. # Otherwise, this would fail silently: # request.path_components.append("foo") - return tuple(urllib.parse.unquote(i) for i in path.split("/") if i) + return tuple(netlib.http.url.unquote(i) for i in path.split("/") if i) @path_components.setter def path_components(self, components): - components = map(lambda x: urllib.parse.quote(x, safe=""), components) + components = map(lambda x: netlib.http.url.quote(x, safe=""), components) path = "/" + "/".join(components) - scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) - _, _, _, self.path = netlib.http.url.parse( - urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment])) + _, _, _, params, query, fragment = urllib.parse.urlparse(self.url) + self.path = urllib.parse.urlunparse(["", "", path, params, query, fragment]) def anticache(self): """ @@ -365,13 +363,13 @@ class Request(message.Message): pass return () - def _set_urlencoded_form(self, value): + def _set_urlencoded_form(self, form_data): """ Sets the body to the URL-encoded form data, and adds the appropriate content-type header. This will overwrite the existing content if there is one. """ self.headers["content-type"] = "application/x-www-form-urlencoded" - self.content = netlib.http.url.encode(value).encode() + self.content = netlib.http.url.encode(form_data).encode() @urlencoded_form.setter def urlencoded_form(self, value): diff --git a/netlib/http/url.py b/netlib/http/url.py index 2fc6e7ee..076854b9 100644 --- a/netlib/http/url.py +++ b/netlib/http/url.py @@ -82,18 +82,51 @@ def unparse(scheme, host, port, path=""): def encode(s): + # type: Sequence[Tuple[str,str]] -> str """ Takes a list of (key, value) tuples and returns a urlencoded string. """ - s = [tuple(i) for i in s] - return urllib.parse.urlencode(s, False) + if six.PY2: + return urllib.parse.urlencode(s, False) + else: + return urllib.parse.urlencode(s, False, errors="surrogateescape") def decode(s): """ - Takes a urlencoded string and returns a list of (key, value) tuples. + Takes a urlencoded string and returns a list of surrogate-escaped (key, value) tuples. + """ + if six.PY2: + return urllib.parse.parse_qsl(s, keep_blank_values=True) + else: + return urllib.parse.parse_qsl(s, keep_blank_values=True, errors='surrogateescape') + + +def quote(b, safe="/"): + """ + Returns: + An ascii-encodable str. + """ + # type: (str) -> str + if six.PY2: + return urllib.parse.quote(b, safe=safe) + else: + return urllib.parse.quote(b, safe=safe, errors="surrogateescape") + + +def unquote(s): """ - return urllib.parse.parse_qsl(s, keep_blank_values=True) + Args: + s: A surrogate-escaped str + Returns: + A surrogate-escaped str + """ + # type: (str) -> str + + if six.PY2: + return urllib.parse.unquote(s) + else: + return urllib.parse.unquote(s, errors="surrogateescape") def hostport(scheme, host, port): diff --git a/netlib/multidict.py b/netlib/multidict.py index 51053ff6..e9fec155 100644 --- a/netlib/multidict.py +++ b/netlib/multidict.py @@ -79,9 +79,6 @@ class _MultiDict(MutableMapping, basetypes.Serializable): def __ne__(self, other): return not self.__eq__(other) - def __hash__(self): - return hash(self.fields) - def get_all(self, key): """ Return the list of all values for a given key. @@ -241,6 +238,9 @@ class ImmutableMultiDict(MultiDict): __delitem__ = set_all = insert = _immutable + def __hash__(self): + return hash(self.fields) + def with_delitem(self, key): """ Returns: diff --git a/netlib/strutils.py b/netlib/strutils.py index 32e77927..8f27ebb7 100644 --- a/netlib/strutils.py +++ b/netlib/strutils.py @@ -51,8 +51,7 @@ else: def escape_control_characters(text, keep_spacing=True): """ - Replace all unicode C1 control characters from the given text with their respective control pictures. - For example, a null byte is replaced with the unicode character "\u2400". + Replace all unicode C1 control characters from the given text with a single "." Args: keep_spacing: If True, tabs and newlines will not be replaced. @@ -99,6 +98,9 @@ def bytes_to_escaped_str(data, keep_spacing=False): def escaped_str_to_bytes(data): """ Take an escaped string and return the unescaped bytes equivalent. + + Raises: + ValueError, if the escape sequence is invalid. """ if not isinstance(data, six.string_types): if six.PY2: |