diff options
Diffstat (limited to 'netlib')
| -rw-r--r-- | netlib/encoding.py | 15 | ||||
| -rw-r--r-- | netlib/http/message.py | 2 | ||||
| -rw-r--r-- | netlib/http/request.py | 2 | ||||
| -rw-r--r-- | netlib/strutils.py | 4 | ||||
| -rw-r--r-- | netlib/tcp.py | 46 | 
5 files changed, 45 insertions, 24 deletions
| diff --git a/netlib/encoding.py b/netlib/encoding.py index da282194..9c8acff7 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -8,6 +8,7 @@ import collections  from io import BytesIO  import gzip  import zlib +import brotli  from typing import Union  # noqa @@ -45,7 +46,7 @@ def decode(encoded, encoding, errors='strict'):              decoded = custom_decode[encoding](encoded)          except KeyError:              decoded = codecs.decode(encoded, encoding, errors) -        if encoding in ("gzip", "deflate"): +        if encoding in ("gzip", "deflate", "br"):              _cache = CachedDecode(encoded, encoding, errors, decoded)          return decoded      except Exception as e: @@ -81,7 +82,7 @@ def encode(decoded, encoding, errors='strict'):              encoded = custom_encode[encoding](decoded)          except KeyError:              encoded = codecs.encode(decoded, encoding, errors) -        if encoding in ("gzip", "deflate"): +        if encoding in ("gzip", "deflate", "br"):              _cache = CachedDecode(encoded, encoding, errors, decoded)          return encoded      except Exception as e: @@ -113,6 +114,14 @@ def encode_gzip(content):      return s.getvalue() +def decode_brotli(content): +    return brotli.decompress(content) + + +def encode_brotli(content): +    return brotli.compress(content) + +  def decode_deflate(content):      """          Returns decompressed data for DEFLATE. Some servers may respond with @@ -139,11 +148,13 @@ custom_decode = {      "identity": identity,      "gzip": decode_gzip,      "deflate": decode_deflate, +    "br": decode_brotli,  }  custom_encode = {      "identity": identity,      "gzip": encode_gzip,      "deflate": encode_deflate, +    "br": encode_brotli,  }  __all__ = ["encode", "decode"] diff --git a/netlib/http/message.py b/netlib/http/message.py index be35b8d1..ce92bab1 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -248,7 +248,7 @@ class Message(basetypes.Serializable):      def encode(self, e):          """ -        Encodes body with the encoding e, where e is "gzip", "deflate" or "identity". +        Encodes body with the encoding e, where e is "gzip", "deflate", "identity", or "br".          Any existing content-encodings are overwritten,          the content is not decoded beforehand. diff --git a/netlib/http/request.py b/netlib/http/request.py index 061217a3..d59fead4 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -337,7 +337,7 @@ class Request(message.Message):              self.headers["accept-encoding"] = (                  ', '.join(                      e -                    for e in {"gzip", "identity", "deflate"} +                    for e in {"gzip", "identity", "deflate", "br"}                      if e in accept_encoding                  )              ) diff --git a/netlib/strutils.py b/netlib/strutils.py index 8f27ebb7..4a46b6b1 100644 --- a/netlib/strutils.py +++ b/netlib/strutils.py @@ -69,7 +69,7 @@ def escape_control_characters(text, keep_spacing=True):      return text.translate(trans) -def bytes_to_escaped_str(data, keep_spacing=False): +def bytes_to_escaped_str(data, keep_spacing=False, escape_single_quotes=False):      """      Take bytes and return a safe string that can be displayed to the user. @@ -86,6 +86,8 @@ def bytes_to_escaped_str(data, keep_spacing=False):      # We always insert a double-quote here so that we get a single-quoted string back      # https://stackoverflow.com/questions/29019340/why-does-python-use-different-quotes-for-representing-strings-depending-on-their      ret = repr(b'"' + data).lstrip("b")[2:-1] +    if not escape_single_quotes: +        ret = re.sub(r"(?<!\\)(\\\\)*\\'", lambda m: (m.group(1) or "") + "'", ret)      if keep_spacing:          ret = re.sub(              r"(?<!\\)(\\\\)*\\([nrt])", diff --git a/netlib/tcp.py b/netlib/tcp.py index cf099edd..e5c84165 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -8,6 +8,10 @@ import time  import traceback  import binascii + +from typing import Optional  # noqa + +from netlib import strutils  from six.moves import range  import certifi @@ -35,7 +39,7 @@ EINTR = 4  if os.environ.get("NO_ALPN"):      HAS_ALPN = False  else: -    HAS_ALPN = OpenSSL._util.lib.Cryptography_HAS_ALPN +    HAS_ALPN = SSL._lib.Cryptography_HAS_ALPN  # To enable all SSL methods use: SSLv23  # then add options to disable certain methods @@ -287,16 +291,7 @@ class Reader(_FileLike):                  raise exceptions.TcpException(repr(e))          elif isinstance(self.o, SSL.Connection):              try: -                if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15): -                    return self.o.recv(length, socket.MSG_PEEK) -                else: -                    # TODO: remove once a new version is released -                    # Polyfill for pyOpenSSL <= 0.15.1 -                    # Taken from https://github.com/pyca/pyopenssl/commit/1d95dea7fea03c7c0df345a5ea30c12d8a0378d2 -                    buf = SSL._ffi.new("char[]", length) -                    result = SSL._lib.SSL_peek(self.o._ssl, buf, length) -                    self.o._raise_ssl_error(self.o._ssl, result) -                    return SSL._ffi.buffer(buf, result)[:] +                return self.o.recv(length, socket.MSG_PEEK)              except SSL.Error as e:                  six.reraise(exceptions.TlsException, exceptions.TlsException(str(e)), sys.exc_info()[2])          else: @@ -511,6 +506,7 @@ class _Connection(object):                              alpn_protos=None,                              alpn_select=None,                              alpn_select_callback=None, +                            sni=None,                              ):          """          Creates an SSL Context. @@ -532,8 +528,14 @@ class _Connection(object):          if verify_options is not None:              def verify_cert(conn, x509, errno, err_depth, is_cert_verified):                  if not is_cert_verified: -                    self.ssl_verification_error = dict(errno=errno, -                                                       depth=err_depth) +                    self.ssl_verification_error = exceptions.InvalidCertificateException( +                        "Certificate Verification Error for {}: {} (errno: {}, depth: {})".format( +                            sni, +                            strutils.native(SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(errno)), "utf8"), +                            errno, +                            err_depth +                        ) +                    )                  return is_cert_verified              context.set_verify(verify_options, verify_cert) @@ -609,7 +611,7 @@ class TCPClient(_Connection):          self.source_address = source_address          self.cert = None          self.server_certs = [] -        self.ssl_verification_error = None +        self.ssl_verification_error = None  # type: Optional[exceptions.InvalidCertificateException]          self.sni = None      @property @@ -671,6 +673,7 @@ class TCPClient(_Connection):          context = self.create_ssl_context(              alpn_protos=alpn_protos, +            sni=sni,              **sslctx_kwargs          )          self.connection = SSL.Connection(context, self.connection) @@ -682,14 +685,14 @@ class TCPClient(_Connection):              self.connection.do_handshake()          except SSL.Error as v:              if self.ssl_verification_error: -                raise exceptions.InvalidCertificateException("SSL handshake error: %s" % repr(v)) +                raise self.ssl_verification_error              else:                  raise exceptions.TlsException("SSL handshake error: %s" % repr(v))          else:              # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on              # certificate validation failure -            if verification_mode == SSL.VERIFY_PEER and self.ssl_verification_error is not None: -                raise exceptions.InvalidCertificateException("SSL handshake error: certificate verify failed") +            if verification_mode == SSL.VERIFY_PEER and self.ssl_verification_error: +                raise self.ssl_verification_error          self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) @@ -710,9 +713,14 @@ class TCPClient(_Connection):                  hostname = "no-hostname"              ssl_match_hostname.match_hostname(crt, hostname)          except (ValueError, ssl_match_hostname.CertificateError) as e: -            self.ssl_verification_error = dict(depth=0, errno="Invalid Hostname") +            self.ssl_verification_error = exceptions.InvalidCertificateException( +                "Certificate Verification Error for {}: {}".format( +                    sni or repr(self.address), +                    str(e) +                ) +            )              if verification_mode == SSL.VERIFY_PEER: -                raise exceptions.InvalidCertificateException("Presented certificate for {} is not valid: {}".format(sni, str(e))) +                raise self.ssl_verification_error          self.ssl_established = True          self.rfile.set_descriptor(self.connection) | 
