about | summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
-rw-r--r-- netlib/encoding.py | 47
-rw-r--r-- netlib/http/message.py | 115
-rw-r--r-- test/netlib/http/test_message.py | 33
-rw-r--r-- test/netlib/test_encoding.py | 30
4 files changed, 100 insertions, 125 deletions
diff --git a/netlib/encoding.py b/netlib/encoding.py
index e3cf5f30..29e2a420 100644
--- a/netlib/encoding.py
+++ b/netlib/encoding.py
@@ -4,6 +4,7 @@ Utility functions for decoding response bodies.
from __future__ import absolute_import
import codecs
+import collections
from io import BytesIO
import gzip
import zlib
@@ -11,7 +12,15 @@ import zlib
from typing import Union # noqa
-def decode(obj, encoding, errors='strict'):
+# We have a shared single-element cache for encoding and decoding.
+# This is quite useful in practice, e.g.
+# flow.request.content = flow.request.content.replace(b"foo", b"bar")
+# does not require an .encode() call if content does not contain b"foo"
+CachedDecode = collections.namedtuple("CachedDecode", "encoded encoding errors decoded")
+_cache = CachedDecode(None, None, None, None)
+
+
+def decode(encoded, encoding, errors='strict'):
# type: (Union[str, bytes], str, str) -> Union[str, bytes]
"""
Decode the given input object
@@ -22,20 +31,31 @@ def decode(obj, encoding, errors='strict'):
Raises:
ValueError, if decoding fails.
"""
+ global _cache
+ cached = (
+ _cache.encoded == encoded and
+ _cache.encoding == encoding and
+ _cache.errors == errors
+ )
+ if cached:
+ return _cache.decoded
try:
try:
- return custom_decode[encoding](obj)
+ decoded = custom_decode[encoding](encoded)
except KeyError:
- return codecs.decode(obj, encoding, errors)
+ decoded = codecs.decode(encoded, encoding, errors)
+ if encoding in ("gzip", "deflate"):
+ _cache = CachedDecode(encoded, encoding, errors, decoded)
+ return decoded
except Exception as e:
raise ValueError("{} when decoding {} with {}".format(
type(e).__name__,
- repr(obj)[:10],
+ repr(encoded)[:10],
repr(encoding),
))
-def encode(obj, encoding, errors='strict'):
+def encode(decoded, encoding, errors='strict'):
# type: (Union[str, bytes], str, str) -> Union[str, bytes]
"""
Encode the given input object
@@ -46,15 +66,26 @@ def encode(obj, encoding, errors='strict'):
Raises:
ValueError, if encoding fails.
"""
+ global _cache
+ cached = (
+ _cache.decoded == decoded and
+ _cache.encoding == encoding and
+ _cache.errors == errors
+ )
+ if cached:
+ return _cache.encoded
try:
try:
- return custom_encode[encoding](obj)
+ encoded = custom_encode[encoding](decoded)
except KeyError:
- return codecs.encode(obj, encoding, errors)
+ encoded = codecs.encode(decoded, encoding, errors)
+ if encoding in ("gzip", "deflate"):
+ _cache = CachedDecode(encoded, encoding, errors, decoded)
+ return encoded
except Exception as e:
raise ValueError("{} when encoding {} with {}".format(
type(e).__name__,
- repr(obj)[:10],
+ repr(decoded)[:10],
repr(encoding),
))
diff --git a/netlib/http/message.py b/netlib/http/message.py
index a86e7489..be35b8d1 100644
--- a/netlib/http/message.py
+++ b/netlib/http/message.py
@@ -49,23 +49,7 @@ class MessageData(basetypes.Serializable):
return cls(**state)
-class CachedDecode(object):
- __slots__ = ["encoded", "encoding", "strict", "decoded"]
-
- def __init__(self, object, encoding, strict, decoded):
- self.encoded = object
- self.encoding = encoding
- self.strict = strict
- self.decoded = decoded
-
-no_cached_decode = CachedDecode(None, None, None, None)
-
-
class Message(basetypes.Serializable):
- def __init__(self):
- self._content_cache = no_cached_decode # type: CachedDecode
- self._text_cache = no_cached_decode # type: CachedDecode
-
def __eq__(self, other):
if isinstance(other, Message):
return self.data == other.data
@@ -126,25 +110,15 @@ class Message(basetypes.Serializable):
if self.raw_content is None:
return None
ce = self.headers.get("content-encoding")
- cached = (
- self._content_cache.encoded == self.raw_content and
- (self._content_cache.strict or not strict) and
- self._content_cache.encoding == ce
- )
- if not cached:
- is_strict = True
- if ce:
- try:
- decoded = encoding.decode(self.raw_content, ce)
- except ValueError:
- if strict:
- raise
- is_strict = False
- decoded = self.raw_content
- else:
- decoded = self.raw_content
- self._content_cache = CachedDecode(self.raw_content, ce, is_strict, decoded)
- return self._content_cache.decoded
+ if ce:
+ try:
+ return encoding.decode(self.raw_content, ce)
+ except ValueError:
+ if strict:
+ raise
+ return self.raw_content
+ else:
+ return self.raw_content
def set_content(self, value):
if value is None:
@@ -157,22 +131,13 @@ class Message(basetypes.Serializable):
.format(type(value).__name__)
)
ce = self.headers.get("content-encoding")
- cached = (
- self._content_cache.decoded == value and
- self._content_cache.encoding == ce and
- self._content_cache.strict
- )
- if not cached:
- try:
- encoded = encoding.encode(value, ce or "identity")
- except ValueError:
- # So we have an invalid content-encoding?
- # Let's remove it!
- del self.headers["content-encoding"]
- ce = None
- encoded = value
- self._content_cache = CachedDecode(encoded, ce, True, value)
- self.raw_content = self._content_cache.encoded
+ try:
+ self.raw_content = encoding.encode(value, ce or "identity")
+ except ValueError:
+ # So we have an invalid content-encoding?
+ # Let's remove it!
+ del self.headers["content-encoding"]
+ self.raw_content = value
self.headers["content-length"] = str(len(self.raw_content))
content = property(get_content, set_content)
@@ -244,22 +209,12 @@ class Message(basetypes.Serializable):
enc = self._guess_encoding()
content = self.get_content(strict)
- cached = (
- self._text_cache.encoded == content and
- (self._text_cache.strict or not strict) and
- self._text_cache.encoding == enc
- )
- if not cached:
- is_strict = self._content_cache.strict
- try:
- decoded = encoding.decode(content, enc)
- except ValueError:
- if strict:
- raise
- is_strict = False
- decoded = self.content.decode("utf8", "replace" if six.PY2 else "surrogateescape")
- self._text_cache = CachedDecode(content, enc, is_strict, decoded)
- return self._text_cache.decoded
+ try:
+ return encoding.decode(content, enc)
+ except ValueError:
+ if strict:
+ raise
+ return content.decode("utf8", "replace" if six.PY2 else "surrogateescape")
def set_text(self, text):
if text is None:
@@ -267,23 +222,15 @@ class Message(basetypes.Serializable):
return
enc = self._guess_encoding()
- cached = (
- self._text_cache.decoded == text and
- self._text_cache.encoding == enc and
- self._text_cache.strict
- )
- if not cached:
- try:
- encoded = encoding.encode(text, enc)
- except ValueError:
- # Fall back to UTF-8 and update the content-type header.
- ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {})
- ct[2]["charset"] = "utf-8"
- self.headers["content-type"] = headers.assemble_content_type(*ct)
- enc = "utf8"
- encoded = text.encode(enc, "replace" if six.PY2 else "surrogateescape")
- self._text_cache = CachedDecode(encoded, enc, True, text)
- self.content = self._text_cache.encoded
+ try:
+ self.content = encoding.encode(text, enc)
+ except ValueError:
+ # Fall back to UTF-8 and update the content-type header.
+ ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {})
+ ct[2]["charset"] = "utf-8"
+ self.headers["content-type"] = headers.assemble_content_type(*ct)
+ enc = "utf8"
+ self.content = text.encode(enc, "replace" if six.PY2 else "surrogateescape")
text = property(get_text, set_text)
diff --git a/test/netlib/http/test_message.py b/test/netlib/http/test_message.py
index 7f93830e..12e4706c 100644
--- a/test/netlib/http/test_message.py
+++ b/test/netlib/http/test_message.py
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function, division
-import mock
import six
from netlib.tutils import tresp
@@ -113,14 +112,6 @@ class TestMessageContentEncoding(object):
assert r.content == b"message"
assert r.raw_content != b"message"
- r.raw_content = b"foo"
- with mock.patch("netlib.encoding.decode") as e:
- assert r.content
- assert e.call_count == 1
- e.reset_mock()
- assert r.content
- assert e.call_count == 0
-
def test_modify(self):
r = tresp()
assert "content-encoding" not in r.headers
@@ -131,13 +122,6 @@ class TestMessageContentEncoding(object):
r.decode()
assert r.raw_content == b"foo"
- r.encode("identity")
- with mock.patch("netlib.encoding.encode") as e:
- r.content = b"foo"
- assert e.call_count == 0
- r.content = b"bar"
- assert e.call_count == 1
-
with tutils.raises(TypeError):
r.content = u"foo"
@@ -212,15 +196,6 @@ class TestMessageText(object):
r.headers["content-type"] = "text/html; charset=utf8"
assert r.text == u"ü"
- r.encode("identity")
- r.raw_content = b"foo"
- with mock.patch("netlib.encoding.decode") as e:
- assert r.text
- assert e.call_count == 2
- e.reset_mock()
- assert r.text
- assert e.call_count == 0
-
def test_guess_json(self):
r = tresp(content=b'"\xc3\xbc"')
r.headers["content-type"] = "application/json"
@@ -245,14 +220,6 @@ class TestMessageText(object):
assert r.raw_content == b"\xc3\xbc"
assert r.headers["content-length"] == "2"
- r.encode("identity")
- with mock.patch("netlib.encoding.encode") as e:
- e.return_value = b""
- r.text = u"ü"
- assert e.call_count == 0
- r.text = u"ä"
- assert e.call_count == 2
-
def test_unknown_ce(self):
r = tresp()
r.headers["content-type"] = "text/html; charset=wtf"
diff --git a/test/netlib/test_encoding.py b/test/netlib/test_encoding.py
index de10fc48..a5e81379 100644
--- a/test/netlib/test_encoding.py
+++ b/test/netlib/test_encoding.py
@@ -1,3 +1,4 @@
+import mock
from netlib import encoding, tutils
@@ -37,3 +38,32 @@ def test_deflate():
)
with tutils.raises(ValueError):
encoding.decode(b"bogus", "deflate")
+
+
+def test_cache():
+ decode_gzip = mock.MagicMock()
+ decode_gzip.return_value = b"decoded"
+ encode_gzip = mock.MagicMock()
+ encode_gzip.return_value = b"encoded"
+
+ with mock.patch.dict(encoding.custom_decode, gzip=decode_gzip):
+ with mock.patch.dict(encoding.custom_encode, gzip=encode_gzip):
+ assert encoding.decode(b"encoded", "gzip") == b"decoded"
+ assert decode_gzip.call_count == 1
+
+ # should be cached
+ assert encoding.decode(b"encoded", "gzip") == b"decoded"
+ assert decode_gzip.call_count == 1
+
+ # the other way around as well
+ assert encoding.encode(b"decoded", "gzip") == b"encoded"
+ assert encode_gzip.call_count == 0
+
+ # different encoding
+ decode_gzip.return_value = b"bar"
+ assert encoding.encode(b"decoded", "deflate") != b"decoded"
+ assert encode_gzip.call_count == 0
+
+ # This is not in the cache anymore
+ assert encoding.encode(b"decoded", "gzip") == b"encoded"
+ assert encode_gzip.call_count == 1