diff options
| author | Maximilian Hils <git@maximilianhils.com> | 2017-01-06 00:31:06 +0100 | 
|---|---|---|
| committer | Maximilian Hils <git@maximilianhils.com> | 2017-01-07 23:08:50 +0100 | 
| commit | 042261266f5b901b2b0745fd108c9a92525e9087 (patch) | |
| tree | 534f27bcad822a6e5ad5951875a4a1a2aba8bb93 | |
| parent | af194918cf862294216e67555b2a5e4ab9f93b08 (diff) | |
| download | mitmproxy-042261266f5b901b2b0745fd108c9a92525e9087.tar.gz mitmproxy-042261266f5b901b2b0745fd108c9a92525e9087.tar.bz2 mitmproxy-042261266f5b901b2b0745fd108c9a92525e9087.zip  | |
minor encoding fixes
- native() -> always_str()
  The old function name does not make sense now that the codebase is Python 3 only.
- Inline utility functions in message.py.
| -rw-r--r-- | mitmproxy/io_compat.py | 4 | ||||
| -rw-r--r-- | mitmproxy/net/http/message.py | 13 | ||||
| -rw-r--r-- | mitmproxy/net/http/request.py | 16 | ||||
| -rw-r--r-- | mitmproxy/net/tcp.py | 2 | ||||
| -rw-r--r-- | mitmproxy/net/wsgi.py | 18 | ||||
| -rw-r--r-- | mitmproxy/utils/strutils.py | 34 | ||||
| -rw-r--r-- | pathod/log.py | 2 | ||||
| -rw-r--r-- | pathod/pathoc.py | 18 | ||||
| -rw-r--r-- | test/mitmproxy/net/http/test_response.py | 15 | ||||
| -rw-r--r-- | test/mitmproxy/utils/test_strutils.py | 9 | 
10 files changed, 68 insertions, 63 deletions
diff --git a/mitmproxy/io_compat.py b/mitmproxy/io_compat.py index 8cdd0346..d0e33bce 100644 --- a/mitmproxy/io_compat.py +++ b/mitmproxy/io_compat.py @@ -93,7 +93,7 @@ def convert_100_200(data):  def _convert_dict_keys(o: Any) -> Any:      if isinstance(o, dict): -        return {strutils.native(k): _convert_dict_keys(v) for k, v in o.items()} +        return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()}      else:          return o @@ -103,7 +103,7 @@ def _convert_dict_vals(o: dict, values_to_convert: dict) -> dict:          if not o or k not in o:              continue          if v is True: -            o[k] = strutils.native(o[k]) +            o[k] = strutils.always_str(o[k])          else:              _convert_dict_vals(o[k], v)      return o diff --git a/mitmproxy/net/http/message.py b/mitmproxy/net/http/message.py index 166f919a..c0a78ea9 100644 --- a/mitmproxy/net/http/message.py +++ b/mitmproxy/net/http/message.py @@ -7,15 +7,6 @@ from mitmproxy.types import serializable  from mitmproxy.net.http import headers -# While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. -def _native(x): -    return x.decode("utf-8", "surrogateescape") - - -def _always_bytes(x): -    return strutils.always_bytes(x, "utf-8", "surrogateescape") - -  class MessageData(serializable.Serializable):      def __eq__(self, other):          if isinstance(other, MessageData): @@ -142,11 +133,11 @@ class Message(serializable.Serializable):          """          Version string, e.g. 
"HTTP/1.1"          """ -        return _native(self.data.http_version) +        return self.data.http_version.decode("utf-8", "surrogateescape")      @http_version.setter      def http_version(self, http_version): -        self.data.http_version = _always_bytes(http_version) +        self.data.http_version = strutils.always_bytes(http_version, "utf-8", "surrogateescape")      @property      def timestamp_start(self): diff --git a/mitmproxy/net/http/request.py b/mitmproxy/net/http/request.py index 7cc4def7..822f8229 100644 --- a/mitmproxy/net/http/request.py +++ b/mitmproxy/net/http/request.py @@ -115,24 +115,24 @@ class Request(message.Message):          """          HTTP request method, e.g. "GET".          """ -        return message._native(self.data.method).upper() +        return self.data.method.decode("utf-8", "surrogateescape").upper()      @method.setter      def method(self, method): -        self.data.method = message._always_bytes(method) +        self.data.method = strutils.always_bytes(method, "utf-8", "surrogateescape")      @property      def scheme(self):          """          HTTP request scheme, which should be "http" or "https".          
""" -        if not self.data.scheme: -            return self.data.scheme -        return message._native(self.data.scheme) +        if self.data.scheme is None: +            return None +        return self.data.scheme.decode("utf-8", "surrogateescape")      @scheme.setter      def scheme(self, scheme): -        self.data.scheme = message._always_bytes(scheme) +        self.data.scheme = strutils.always_bytes(scheme, "utf-8", "surrogateescape")      @property      def host(self): @@ -190,11 +190,11 @@ class Request(message.Message):          if self.data.path is None:              return None          else: -            return message._native(self.data.path) +            return self.data.path.decode("utf-8", "surrogateescape")      @path.setter      def path(self, path): -        self.data.path = message._always_bytes(path) +        self.data.path = strutils.always_bytes(path, "utf-8", "surrogateescape")      @property      def url(self): diff --git a/mitmproxy/net/tcp.py b/mitmproxy/net/tcp.py index 2dd32c9b..eabc8006 100644 --- a/mitmproxy/net/tcp.py +++ b/mitmproxy/net/tcp.py @@ -538,7 +538,7 @@ class _Connection:                      self.ssl_verification_error = exceptions.InvalidCertificateException(                          "Certificate Verification Error for {}: {} (errno: {}, depth: {})".format(                              sni, -                            strutils.native(SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(errno)), "utf8"), +                            strutils.always_str(SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(errno)), "utf8"),                              errno,                              err_depth                          ) diff --git a/mitmproxy/net/wsgi.py b/mitmproxy/net/wsgi.py index b2705ea1..8bc5bb89 100644 --- a/mitmproxy/net/wsgi.py +++ b/mitmproxy/net/wsgi.py @@ -57,38 +57,38 @@ class WSGIAdaptor:          Raises:              ValueError, if the content-encoding is invalid.          
""" -        path = strutils.native(flow.request.path, "latin-1") +        path = strutils.always_str(flow.request.path, "latin-1")          if '?' in path: -            path_info, query = strutils.native(path, "latin-1").split('?', 1) +            path_info, query = strutils.always_str(path, "latin-1").split('?', 1)          else:              path_info = path              query = ''          environ = {              'wsgi.version': (1, 0), -            'wsgi.url_scheme': strutils.native(flow.request.scheme, "latin-1"), +            'wsgi.url_scheme': strutils.always_str(flow.request.scheme, "latin-1"),              'wsgi.input': io.BytesIO(flow.request.content or b""),              'wsgi.errors': errsoc,              'wsgi.multithread': True,              'wsgi.multiprocess': False,              'wsgi.run_once': False,              'SERVER_SOFTWARE': self.sversion, -            'REQUEST_METHOD': strutils.native(flow.request.method, "latin-1"), +            'REQUEST_METHOD': strutils.always_str(flow.request.method, "latin-1"),              'SCRIPT_NAME': '',              'PATH_INFO': urllib.parse.unquote(path_info),              'QUERY_STRING': query, -            'CONTENT_TYPE': strutils.native(flow.request.headers.get('Content-Type', ''), "latin-1"), -            'CONTENT_LENGTH': strutils.native(flow.request.headers.get('Content-Length', ''), "latin-1"), +            'CONTENT_TYPE': strutils.always_str(flow.request.headers.get('Content-Type', ''), "latin-1"), +            'CONTENT_LENGTH': strutils.always_str(flow.request.headers.get('Content-Length', ''), "latin-1"),              'SERVER_NAME': self.domain,              'SERVER_PORT': str(self.port), -            'SERVER_PROTOCOL': strutils.native(flow.request.http_version, "latin-1"), +            'SERVER_PROTOCOL': strutils.always_str(flow.request.http_version, "latin-1"),          }          environ.update(extra)          if flow.client_conn.address: -            environ["REMOTE_ADDR"] = 
strutils.native(flow.client_conn.address.host, "latin-1") +            environ["REMOTE_ADDR"] = strutils.always_str(flow.client_conn.address.host, "latin-1")              environ["REMOTE_PORT"] = flow.client_conn.address.port          for key, value in flow.request.headers.items(): -            key = 'HTTP_' + strutils.native(key, "latin-1").upper().replace('-', '_') +            key = 'HTTP_' + strutils.always_str(key, "latin-1").upper().replace('-', '_')              if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'):                  environ[key] = value          return environ diff --git a/mitmproxy/utils/strutils.py b/mitmproxy/utils/strutils.py index 57cfbc79..29465615 100644 --- a/mitmproxy/utils/strutils.py +++ b/mitmproxy/utils/strutils.py @@ -1,28 +1,28 @@  import re  import codecs +from typing import AnyStr, Optional -def always_bytes(unicode_or_bytes, *encode_args): -    if isinstance(unicode_or_bytes, str): -        return unicode_or_bytes.encode(*encode_args) -    elif isinstance(unicode_or_bytes, bytes) or unicode_or_bytes is None: -        return unicode_or_bytes +def always_bytes(str_or_bytes: Optional[AnyStr], *encode_args) -> Optional[bytes]: +    if isinstance(str_or_bytes, bytes) or str_or_bytes is None: +        return str_or_bytes +    elif isinstance(str_or_bytes, str): +        return str_or_bytes.encode(*encode_args)      else: -        raise TypeError("Expected str or bytes, but got {}.".format(type(unicode_or_bytes).__name__)) +        raise TypeError("Expected str or bytes, but got {}.".format(type(str_or_bytes).__name__)) -def native(s, *encoding_opts): +def always_str(str_or_bytes: Optional[AnyStr], *decode_args) -> Optional[str]:      """ -    Convert :py:class:`bytes` or :py:class:`unicode` to the native -    :py:class:`str` type, using latin1 encoding if conversion is necessary. 
- -    https://www.python.org/dev/peps/pep-3333/#a-note-on-string-types +    Returns, +        str_or_bytes unmodified, if      """ -    if not isinstance(s, (bytes, str)): -        raise TypeError("%r is neither bytes nor unicode" % s) -    if isinstance(s, bytes): -        return s.decode(*encoding_opts) -    return s +    if isinstance(str_or_bytes, str) or str_or_bytes is None: +        return str_or_bytes +    elif isinstance(str_or_bytes, bytes): +        return str_or_bytes.decode(*decode_args) +    else: +        raise TypeError("Expected str or bytes, but got {}.".format(type(str_or_bytes).__name__))  # Translate control characters to "safe" characters. This implementation initially @@ -135,7 +135,7 @@ def hexdump(s):          part = s[i:i + 16]          x = " ".join("{:0=2x}".format(i) for i in part)          x = x.ljust(47)  # 16*2 + 15 -        part_repr = native(escape_control_characters( +        part_repr = always_str(escape_control_characters(              part.decode("ascii", "replace").replace(u"\ufffd", u"."),              False          )) diff --git a/pathod/log.py b/pathod/log.py index 4e5f355f..f7a7fc98 100644 --- a/pathod/log.py +++ b/pathod/log.py @@ -61,7 +61,7 @@ class LogCtx:              for line in strutils.hexdump(data):                  self("\t%s %s %s" % line)          else: -            data = strutils.native( +            data = strutils.always_str(                  strutils.escape_control_characters(                      data                          .decode("ascii", "replace") diff --git a/pathod/pathoc.py b/pathod/pathoc.py index 066c330c..3e804b63 100644 --- a/pathod/pathoc.py +++ b/pathod/pathoc.py @@ -44,7 +44,7 @@ class SSLInfo:      def __str__(self):          parts = [ -            "Application Layer Protocol: %s" % strutils.native(self.alp, "utf8"), +            "Application Layer Protocol: %s" % strutils.always_str(self.alp, "utf8"),              "Cipher: %s, %s bit, %s" % self.cipher,              "SSL certificate 
chain:"          ] @@ -53,24 +53,24 @@ class SSLInfo:              parts.append("\tSubject: ")              for cn in i.get_subject().get_components():                  parts.append("\t\t%s=%s" % ( -                    strutils.native(cn[0], "utf8"), -                    strutils.native(cn[1], "utf8")) +                    strutils.always_str(cn[0], "utf8"), +                    strutils.always_str(cn[1], "utf8"))                  )              parts.append("\tIssuer: ")              for cn in i.get_issuer().get_components():                  parts.append("\t\t%s=%s" % ( -                    strutils.native(cn[0], "utf8"), -                    strutils.native(cn[1], "utf8")) +                    strutils.always_str(cn[0], "utf8"), +                    strutils.always_str(cn[1], "utf8"))                  )              parts.extend(                  [                      "\tVersion: %s" % i.get_version(),                      "\tValidity: %s - %s" % ( -                        strutils.native(i.get_notBefore(), "utf8"), -                        strutils.native(i.get_notAfter(), "utf8") +                        strutils.always_str(i.get_notBefore(), "utf8"), +                        strutils.always_str(i.get_notAfter(), "utf8")                      ),                      "\tSerial: %s" % i.get_serial_number(), -                    "\tAlgorithm: %s" % strutils.native(i.get_signature_algorithm(), "utf8") +                    "\tAlgorithm: %s" % strutils.always_str(i.get_signature_algorithm(), "utf8")                  ]              )              pk = i.get_pubkey() @@ -82,7 +82,7 @@ class SSLInfo:              parts.append("\tPubkey: %s bit %s" % (pk.bits(), t))              s = certs.SSLCert(i)              if s.altnames: -                parts.append("\tSANs: %s" % " ".join(strutils.native(n, "utf8") for n in s.altnames)) +                parts.append("\tSANs: %s" % " ".join(strutils.always_str(n, "utf8") for n in s.altnames))          return "\n".join(parts) diff 
--git a/test/mitmproxy/net/http/test_response.py b/test/mitmproxy/net/http/test_response.py index 239fb6ef..eae957a8 100644 --- a/test/mitmproxy/net/http/test_response.py +++ b/test/mitmproxy/net/http/test_response.py @@ -55,7 +55,20 @@ class TestResponseCore:          _test_passthrough_attr(tresp(), "status_code")      def test_reason(self): -        _test_decoded_attr(tresp(), "reason") +        resp = tresp() +        assert resp.reason == "OK" + +        resp.reason = "ABC" +        assert resp.data.reason == b"ABC" + +        resp.reason = b"DEF" +        assert resp.data.reason == b"DEF" + +        resp.reason = None +        assert resp.data.reason is None + +        resp.data.reason = b'cr\xe9e' +        assert resp.reason == "crée"  class TestResponseUtils: diff --git a/test/mitmproxy/utils/test_strutils.py b/test/mitmproxy/utils/test_strutils.py index 84281c6b..1372d31f 100644 --- a/test/mitmproxy/utils/test_strutils.py +++ b/test/mitmproxy/utils/test_strutils.py @@ -11,11 +11,12 @@ def test_always_bytes():          strutils.always_bytes(42, "ascii") -def test_native(): +def test_always_str():      with tutils.raises(TypeError): -        strutils.native(42) -    assert strutils.native(u"foo") == u"foo" -    assert strutils.native(b"foo") == u"foo" +        strutils.always_str(42) +    assert strutils.always_str("foo") == "foo" +    assert strutils.always_str(b"foo") == "foo" +    assert strutils.always_str(None) is None  def test_escape_control_characters():  | 
