diff options
Diffstat (limited to 'netlib/http')
-rw-r--r-- | netlib/http/message.py | 121 | ||||
-rw-r--r-- | netlib/http/request.py | 26 | ||||
-rw-r--r-- | netlib/http/url.py | 41 |
3 files changed, 80 insertions, 108 deletions
diff --git a/netlib/http/message.py b/netlib/http/message.py index 34709f0a..be35b8d1 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -32,9 +32,6 @@ class MessageData(basetypes.Serializable): def __ne__(self, other): return not self.__eq__(other) - def __hash__(self): - return hash(frozenset(self.__dict__.items())) - def set_state(self, state): for k, v in state.items(): if k == "headers": @@ -52,23 +49,7 @@ class MessageData(basetypes.Serializable): return cls(**state) -class CachedDecode(object): - __slots__ = ["encoded", "encoding", "strict", "decoded"] - - def __init__(self, object, encoding, strict, decoded): - self.encoded = object - self.encoding = encoding - self.strict = strict - self.decoded = decoded - -no_cached_decode = CachedDecode(None, None, None, None) - - class Message(basetypes.Serializable): - def __init__(self): - self._content_cache = no_cached_decode # type: CachedDecode - self._text_cache = no_cached_decode # type: CachedDecode - def __eq__(self, other): if isinstance(other, Message): return self.data == other.data @@ -77,9 +58,6 @@ class Message(basetypes.Serializable): def __ne__(self, other): return not self.__eq__(other) - def __hash__(self): - return hash(self.data) ^ 1 - def get_state(self): return self.data.get_state() @@ -132,25 +110,15 @@ class Message(basetypes.Serializable): if self.raw_content is None: return None ce = self.headers.get("content-encoding") - cached = ( - self._content_cache.encoded == self.raw_content and - (self._content_cache.strict or not strict) and - self._content_cache.encoding == ce - ) - if not cached: - is_strict = True - if ce: - try: - decoded = encoding.decode(self.raw_content, ce) - except ValueError: - if strict: - raise - is_strict = False - decoded = self.raw_content - else: - decoded = self.raw_content - self._content_cache = CachedDecode(self.raw_content, ce, is_strict, decoded) - return self._content_cache.decoded + if ce: + try: + return encoding.decode(self.raw_content, ce) + except ValueError: + if strict: + raise + return self.raw_content + else: + return self.raw_content def set_content(self, value): if value is None: @@ -163,22 +131,13 @@ class Message(basetypes.Serializable): .format(type(value).__name__) ) ce = self.headers.get("content-encoding") - cached = ( - self._content_cache.decoded == value and - self._content_cache.encoding == ce and - self._content_cache.strict - ) - if not cached: - try: - encoded = encoding.encode(value, ce or "identity") - except ValueError: - # So we have an invalid content-encoding? - # Let's remove it! - del self.headers["content-encoding"] - ce = None - encoded = value - self._content_cache = CachedDecode(encoded, ce, True, value) - self.raw_content = self._content_cache.encoded + try: + self.raw_content = encoding.encode(value, ce or "identity") + except ValueError: + # So we have an invalid content-encoding? + # Let's remove it! + del self.headers["content-encoding"] + self.raw_content = value self.headers["content-length"] = str(len(self.raw_content)) content = property(get_content, set_content) @@ -250,22 +209,12 @@ class Message(basetypes.Serializable): enc = self._guess_encoding() content = self.get_content(strict) - cached = ( - self._text_cache.encoded == content and - (self._text_cache.strict or not strict) and - self._text_cache.encoding == enc - ) - if not cached: - is_strict = self._content_cache.strict - try: - decoded = encoding.decode(content, enc) - except ValueError: - if strict: - raise - is_strict = False - decoded = self.content.decode("utf8", "replace" if six.PY2 else "surrogateescape") - self._text_cache = CachedDecode(content, enc, is_strict, decoded) - return self._text_cache.decoded + try: + return encoding.decode(content, enc) + except ValueError: + if strict: + raise + return content.decode("utf8", "replace" if six.PY2 else "surrogateescape") def set_text(self, text): if text is None: @@ -273,23 +222,15 @@ class Message(basetypes.Serializable): return enc = self._guess_encoding() - cached = ( - self._text_cache.decoded == text and - self._text_cache.encoding == enc and - self._text_cache.strict - ) - if not cached: - try: - encoded = encoding.encode(text, enc) - except ValueError: - # Fall back to UTF-8 and update the content-type header. - ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {}) - ct[2]["charset"] = "utf-8" - self.headers["content-type"] = headers.assemble_content_type(*ct) - enc = "utf8" - encoded = text.encode(enc, "replace" if six.PY2 else "surrogateescape") - self._text_cache = CachedDecode(encoded, enc, True, text) - self.content = self._text_cache.encoded + try: + self.content = encoding.encode(text, enc) + except ValueError: + # Fall back to UTF-8 and update the content-type header. + ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {}) + ct[2]["charset"] = "utf-8" + self.headers["content-type"] = headers.assemble_content_type(*ct) + enc = "utf8" + self.content = text.encode(enc, "replace" if six.PY2 else "surrogateescape") text = property(get_text, set_text) diff --git a/netlib/http/request.py b/netlib/http/request.py index ecaa9b79..061217a3 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -253,14 +253,13 @@ class Request(message.Message): ) def _get_query(self): - _, _, _, _, query, _ = urllib.parse.urlparse(self.url) + query = urllib.parse.urlparse(self.url).query return tuple(netlib.http.url.decode(query)) - def _set_query(self, value): - query = netlib.http.url.encode(value) - scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) - _, _, _, self.path = netlib.http.url.parse( - urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment])) + def _set_query(self, query_data): + query = netlib.http.url.encode(query_data) + _, _, path, params, _, fragment = urllib.parse.urlparse(self.url) + self.path = urllib.parse.urlunparse(["", "", path, params, query, fragment]) @query.setter def query(self, value): @@ -296,19 +295,18 @@ class Request(message.Message): The URL's path components as a tuple of strings. Components are unquoted. """ - _, _, path, _, _, _ = urllib.parse.urlparse(self.url) + path = urllib.parse.urlparse(self.url).path # This needs to be a tuple so that it's immutable. # Otherwise, this would fail silently: # request.path_components.append("foo") - return tuple(urllib.parse.unquote(i) for i in path.split("/") if i) + return tuple(netlib.http.url.unquote(i) for i in path.split("/") if i) @path_components.setter def path_components(self, components): - components = map(lambda x: urllib.parse.quote(x, safe=""), components) + components = map(lambda x: netlib.http.url.quote(x, safe=""), components) path = "/" + "/".join(components) - scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) - _, _, _, self.path = netlib.http.url.parse( - urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment])) + _, _, _, params, query, fragment = urllib.parse.urlparse(self.url) + self.path = urllib.parse.urlunparse(["", "", path, params, query, fragment]) def anticache(self): """ @@ -365,13 +363,13 @@ class Request(message.Message): pass return () - def _set_urlencoded_form(self, value): + def _set_urlencoded_form(self, form_data): """ Sets the body to the URL-encoded form data, and adds the appropriate content-type header. This will overwrite the existing content if there is one. """ self.headers["content-type"] = "application/x-www-form-urlencoded" - self.content = netlib.http.url.encode(value).encode() + self.content = netlib.http.url.encode(form_data).encode() @urlencoded_form.setter def urlencoded_form(self, value): diff --git a/netlib/http/url.py b/netlib/http/url.py index 2fc6e7ee..076854b9 100644 --- a/netlib/http/url.py +++ b/netlib/http/url.py @@ -82,18 +82,51 @@ def unparse(scheme, host, port, path=""): def encode(s): + # type: Sequence[Tuple[str,str]] -> str """ Takes a list of (key, value) tuples and returns a urlencoded string. """ - s = [tuple(i) for i in s] - return urllib.parse.urlencode(s, False) + if six.PY2: + return urllib.parse.urlencode(s, False) + else: + return urllib.parse.urlencode(s, False, errors="surrogateescape") def decode(s): """ - Takes a urlencoded string and returns a list of (key, value) tuples. + Takes a urlencoded string and returns a list of surrogate-escaped (key, value) tuples. + """ + if six.PY2: + return urllib.parse.parse_qsl(s, keep_blank_values=True) + else: + return urllib.parse.parse_qsl(s, keep_blank_values=True, errors='surrogateescape') + + +def quote(b, safe="/"): + """ + Returns: + An ascii-encodable str. + """ + # type: (str) -> str + if six.PY2: + return urllib.parse.quote(b, safe=safe) + else: + return urllib.parse.quote(b, safe=safe, errors="surrogateescape") + + +def unquote(s): """ - return urllib.parse.parse_qsl(s, keep_blank_values=True) + Args: + s: A surrogate-escaped str + Returns: + A surrogate-escaped str + """ + # type: (str) -> str + + if six.PY2: + return urllib.parse.unquote(s) + else: + return urllib.parse.unquote(s, errors="surrogateescape") def hostport(scheme, host, port): |