Diffstat (limited to 'netlib')
92 files changed, 10285 insertions, 0 deletions
diff --git a/netlib/.appveyor.yml b/netlib/.appveyor.yml new file mode 100644 index 00000000..cd1354c2 --- /dev/null +++ b/netlib/.appveyor.yml @@ -0,0 +1,11 @@ +version: '{build}' +shallow_clone: true +environment: +  matrix: +    - PYTHON: "C:\\Python27" +install: +  - "%PYTHON%\\Scripts\\pip install --src . -r requirements.txt" +  - "%PYTHON%\\python -c \"from OpenSSL import SSL; print(SSL.SSLeay_version(SSL.SSLEAY_VERSION))\"" +build: off  # Not a C# project +test_script: +  - "%PYTHON%\\Scripts\\py.test -n 4 --timeout 10" diff --git a/netlib/.coveragerc b/netlib/.coveragerc new file mode 100644 index 00000000..ccbebf8c --- /dev/null +++ b/netlib/.coveragerc @@ -0,0 +1,11 @@ +[run] +branch = True + +[report] +show_missing = True +include = *netlib/netlib* +exclude_lines = +    pragma: nocover +    pragma: no cover +    raise NotImplementedError() +omit = *contrib* diff --git a/netlib/.env b/netlib/.env new file mode 100644 index 00000000..69ac3f05 --- /dev/null +++ b/netlib/.env @@ -0,0 +1,6 @@ +DIR="$( dirname "${BASH_SOURCE[0]}" )" +ACTIVATE_DIR="$(if [ -f "$DIR/../venv.mitmproxy/bin/activate" ]; then echo 'bin'; else echo 'Scripts'; fi;)" +if [ -z "$VIRTUAL_ENV" ] && [ -f "$DIR/../venv.mitmproxy/$ACTIVATE_DIR/activate" ]; then +    echo "Activating mitmproxy virtualenv..." +    source "$DIR/../venv.mitmproxy/$ACTIVATE_DIR/activate" +fi diff --git a/netlib/.gitignore b/netlib/.gitignore new file mode 100644 index 00000000..d8ffb588 --- /dev/null +++ b/netlib/.gitignore @@ -0,0 +1,16 @@ +MANIFEST +/build +/dist +/tmp +/doc +*.py[cdo] +*.swp +*.swo +.coverage +.idea/ +__pycache__ +_cffi__* +.eggs/ +netlib.egg-info/ +pathod/ +.cache/
\ No newline at end of file diff --git a/netlib/.landscape.yml b/netlib/.landscape.yml new file mode 100644 index 00000000..9a3b615f --- /dev/null +++ b/netlib/.landscape.yml @@ -0,0 +1,16 @@ +max-line-length: 120 +pylint: +  options: +    dummy-variables-rgx: _$|.+_$|dummy_.+ + +  disable: +    - missing-docstring +    - protected-access +    - too-few-public-methods +    - too-many-arguments +    - too-many-instance-attributes +    - too-many-locals +    - too-many-public-methods +    - too-many-return-statements +    - too-many-statements +    - unpacking-non-sequence
\ No newline at end of file diff --git a/netlib/.travis.yml b/netlib/.travis.yml new file mode 100644 index 00000000..651fdae8 --- /dev/null +++ b/netlib/.travis.yml @@ -0,0 +1,98 @@ +sudo: false +language: python + +matrix: +  fast_finish: true +  include: +    - python: 2.7 +    - python: 2.7 +      env: OPENSSL=1.0.2 +      addons: +        apt: +          sources: +            # Debian sid currently holds OpenSSL 1.0.2 +            # change this with future releases! +            - debian-sid +          packages: +            - libssl-dev +    - python: 3.5 +    - python: 3.5 +      env: OPENSSL=1.0.2 +      addons: +        apt: +          sources: +            # Debian sid currently holds OpenSSL 1.0.2 +            # change this with future releases! +            - debian-sid +          packages: +            - libssl-dev +    - python: pypy +    - python: pypy +      env: OPENSSL=1.0.2 +      addons: +        apt: +          sources: +            # Debian sid currently holds OpenSSL 1.0.2 +            # change this with future releases! +            - debian-sid +          packages: +            - libssl-dev + +install: +  - | +    if [[ $TRAVIS_OS_NAME == "osx" ]] +    then +      brew update || brew update # try again if it fails +      brew outdated openssl || brew upgrade openssl +      brew install python +    fi +  - | +      if [ "$TRAVIS_PYTHON_VERSION" = "pypy" ]; then +        export PYENV_ROOT="$HOME/.pyenv" +        if [ -f "$PYENV_ROOT/bin/pyenv" ]; then +          pushd "$PYENV_ROOT" && git pull && popd +        else +          rm -rf "$PYENV_ROOT" && git clone --depth 1 https://github.com/yyuu/pyenv.git "$PYENV_ROOT" +        fi +        export PYPY_VERSION="4.0.1" +        "$PYENV_ROOT/bin/pyenv" install --skip-existing "pypy-$PYPY_VERSION" +        virtualenv --python="$PYENV_ROOT/versions/pypy-$PYPY_VERSION/bin/python" "$HOME/virtualenvs/pypy-$PYPY_VERSION" +        source "$HOME/virtualenvs/pypy-$PYPY_VERSION/bin/activate" +      fi +  - "pip install -U pip setuptools" +  - "pip install --src . -r requirements.txt" + +before_script: +  - "openssl version -a" + +script: +  - "py.test -s --cov netlib --timeout 10" + +after_success: +  - coveralls + +notifications: +  irc: +    channels: +      - "irc.oftc.net#mitmproxy" +    on_success: change +    on_failure: always +  slack: +    rooms: +        - mitmproxy:YaDGC9Gt9TEM7o8zkC2OLNsu#ci +    on_success: always +    on_failure: always + +# exclude cryptography from cache +# it depends on libssl-dev version +# which needs to be compiled specifically to each version +before_cache: +  - pip uninstall -y cryptography + +cache: +  directories: +    - $HOME/.cache/pip +    - /home/travis/virtualenv/python2.7.9/lib/python2.7/site-packages +    - /home/travis/virtualenv/python2.7.9/bin +    - /home/travis/virtualenv/pypy-2.5.0/site-packages +    - /home/travis/virtualenv/pypy-2.5.0/bin diff --git a/netlib/CONTRIBUTORS b/netlib/CONTRIBUTORS new file mode 100644 index 00000000..3a4b9b46 --- /dev/null +++ b/netlib/CONTRIBUTORS @@ -0,0 +1,22 @@ +   253	Aldo Cortesi +   230	Maximilian Hils +   123	Thomas Kriechbaumer +     8	Chandler Abraham +     8	Kyle Morton +     5	Sam Cleveland +     3	Benjamin Lee +     3	Sandor Nemes +     2	Brad Peabody +     2	Israel Nir +     2	Matthias Urlichs +     2	Pedro Worcel +     2	Sean Coates +     1	Andrey Plotnikov +     1	Bradley Baetz +     1	Felix Yan +     1	M. 
Utku Altinkaya +     1	Paul +     1	Pritam Baral +     1	Rouli +     1	Tim Becker +     1	kronick diff --git a/netlib/LICENSE b/netlib/LICENSE new file mode 100644 index 00000000..c08a0186 --- /dev/null +++ b/netlib/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2013, Aldo Cortesi. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/netlib/MANIFEST.in b/netlib/MANIFEST.in new file mode 100644 index 00000000..db0e2ed6 --- /dev/null +++ b/netlib/MANIFEST.in @@ -0,0 +1,4 @@ +include LICENSE CONTRIBUTORS README.rst +graft test +prune test/tools +recursive-exclude * *.pyc *.pyo *.swo *.swp
\ No newline at end of file diff --git a/netlib/README.rst b/netlib/README.rst new file mode 100644 index 00000000..694e3ad9 --- /dev/null +++ b/netlib/README.rst @@ -0,0 +1,35 @@ +|travis| |coveralls| |downloads| |latest-release| |python-versions| + +Netlib is a collection of network utility classes, used by the pathod and +mitmproxy projects. It differs from other projects in some fundamental +respects, because both pathod and mitmproxy often need to violate standards. +This means that protocols are implemented as small, well-contained and flexible +functions, and are designed to allow misbehaviour when needed. + + +Hacking +------- + +If you'd like to work on netlib, check out the instructions in mitmproxy's README_. + +.. |travis| image:: https://img.shields.io/travis/mitmproxy/netlib/master.svg +    :target: https://travis-ci.org/mitmproxy/netlib +    :alt: Build Status + +.. |coveralls| image:: https://img.shields.io/coveralls/mitmproxy/netlib/master.svg +    :target: https://coveralls.io/r/mitmproxy/netlib +    :alt: Coverage Status + +.. |downloads| image:: https://img.shields.io/pypi/dm/netlib.svg?color=orange +    :target: https://pypi.python.org/pypi/netlib +    :alt: Downloads + +.. |latest-release| image:: https://img.shields.io/pypi/v/netlib.svg +    :target: https://pypi.python.org/pypi/netlib +    :alt: Latest Version + +.. |python-versions| image:: https://img.shields.io/pypi/pyversions/netlib.svg +    :target: https://pypi.python.org/pypi/netlib +    :alt: Supported Python versions + +.. _README: https://github.com/mitmproxy/mitmproxy#hacking
\ No newline at end of file diff --git a/netlib/check_coding_style.sh b/netlib/check_coding_style.sh new file mode 100755 index 00000000..a1c94e03 --- /dev/null +++ b/netlib/check_coding_style.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +autopep8 -i -r -a -a . +if [[ -n "$(git status -s)" ]]; then +  echo "autopep8 yielded the following changes:" +  git status -s +  git --no-pager diff +  exit 0 # don't be so strict about coding style errors +fi + +autoflake -i -r --remove-all-unused-imports --remove-unused-variables . +if [[ -n "$(git status -s)" ]]; then +  echo "autoflake yielded the following changes:" +  git status -s +  git --no-pager diff +  exit 0 # don't be so strict about coding style errors +fi + +echo "Coding style seems to be ok." +exit 0 diff --git a/netlib/netlib/__init__.py b/netlib/netlib/__init__.py new file mode 100644 index 00000000..9b4faa33 --- /dev/null +++ b/netlib/netlib/__init__.py @@ -0,0 +1 @@ +from __future__ import (absolute_import, print_function, division) diff --git a/netlib/netlib/certutils.py b/netlib/netlib/certutils.py new file mode 100644 index 00000000..616a778e --- /dev/null +++ b/netlib/netlib/certutils.py @@ -0,0 +1,472 @@ +from __future__ import (absolute_import, print_function, division) +import os +import ssl +import time +import datetime +from six.moves import filter +import ipaddress + +import sys +from pyasn1.type import univ, constraint, char, namedtype, tag +from pyasn1.codec.der.decoder import decode +from pyasn1.error import PyAsn1Error +import OpenSSL + +from .utils import Serializable + +# Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815 + +DEFAULT_EXP = 94608000  # = 24 * 60 * 60 * 365 * 3 +# Generated with "openssl dhparam". It's too slow to generate this on startup. 
+DEFAULT_DHPARAM = b""" +-----BEGIN DH PARAMETERS----- +MIICCAKCAgEAyT6LzpwVFS3gryIo29J5icvgxCnCebcdSe/NHMkD8dKJf8suFCg3 +O2+dguLakSVif/t6dhImxInJk230HmfC8q93hdcg/j8rLGJYDKu3ik6H//BAHKIv +j5O9yjU3rXCfmVJQic2Nne39sg3CreAepEts2TvYHhVv3TEAzEqCtOuTjgDv0ntJ +Gwpj+BJBRQGG9NvprX1YGJ7WOFBP/hWU7d6tgvE6Xa7T/u9QIKpYHMIkcN/l3ZFB +chZEqVlyrcngtSXCROTPcDOQ6Q8QzhaBJS+Z6rcsd7X+haiQqvoFcmaJ08Ks6LQC +ZIL2EtYJw8V8z7C0igVEBIADZBI6OTbuuhDwRw//zU1uq52Oc48CIZlGxTYG/Evq +o9EWAXUYVzWkDSTeBH1r4z/qLPE2cnhtMxbFxuvK53jGB0emy2y1Ei6IhKshJ5qX +IB/aE7SSHyQ3MDHHkCmQJCsOd4Mo26YX61NZ+n501XjqpCBQ2+DfZCBh8Va2wDyv +A2Ryg9SUz8j0AXViRNMJgJrr446yro/FuJZwnQcO3WQnXeqSBnURqKjmqkeFP+d8 +6mk2tqJaY507lRNqtGlLnj7f5RNoBFJDCLBNurVgfvq9TCVWKDIFD4vZRjCrnl6I +rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI= +-----END DH PARAMETERS----- +""" + + +def create_ca(o, cn, exp): +    key = OpenSSL.crypto.PKey() +    key.generate_key(OpenSSL.crypto.TYPE_RSA, 2048) +    cert = OpenSSL.crypto.X509() +    cert.set_serial_number(int(time.time() * 10000)) +    cert.set_version(2) +    cert.get_subject().CN = cn +    cert.get_subject().O = o +    cert.gmtime_adj_notBefore(-3600 * 48) +    cert.gmtime_adj_notAfter(exp) +    cert.set_issuer(cert.get_subject()) +    cert.set_pubkey(key) +    cert.add_extensions([ +        OpenSSL.crypto.X509Extension( +            b"basicConstraints", +            True, +            b"CA:TRUE" +        ), +        OpenSSL.crypto.X509Extension( +            b"nsCertType", +            False, +            b"sslCA" +        ), +        OpenSSL.crypto.X509Extension( +            b"extendedKeyUsage", +            False, +            b"serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" +        ), +        OpenSSL.crypto.X509Extension( +            b"keyUsage", +            True, +            b"keyCertSign, cRLSign" +        ), +        OpenSSL.crypto.X509Extension( +            b"subjectKeyIdentifier", +            False, +            b"hash", +            subject=cert +        ), +    ]) +    cert.sign(key, "sha256") +    return key, cert + + +def dummy_cert(privkey, cacert, commonname, sans): +    """ +        Generates a dummy certificate. + +        privkey: CA private key +        cacert: CA certificate +        commonname: Common name for the generated certificate. +        sans: A list of Subject Alternate Names. + +        Returns cert if operation succeeded, None if not. +    """ +    ss = [] +    for i in sans: +        try: +            ipaddress.ip_address(i.decode("ascii")) +        except ValueError: +            ss.append(b"DNS: %s" % i) +        else: +            ss.append(b"IP: %s" % i) +    ss = b", ".join(ss) + +    cert = OpenSSL.crypto.X509() +    cert.gmtime_adj_notBefore(-3600 * 48) +    cert.gmtime_adj_notAfter(DEFAULT_EXP) +    cert.set_issuer(cacert.get_subject()) +    if commonname is not None: +        cert.get_subject().CN = commonname +    cert.set_serial_number(int(time.time() * 10000)) +    if ss: +        cert.set_version(2) +        cert.add_extensions( +            [OpenSSL.crypto.X509Extension(b"subjectAltName", False, ss)]) +    cert.set_pubkey(cacert.get_pubkey()) +    cert.sign(privkey, "sha256") +    return SSLCert(cert) + + +# DNTree did not pass TestCertStore.test_sans_change and is temporarily replaced by a simple dict. 
+# +# class _Node(UserDict.UserDict): +#     def __init__(self): +#         UserDict.UserDict.__init__(self) +#         self.value = None +# +# +# class DNTree: +#     """ +#         Domain store that knows about wildcards. DNS wildcards are very +#         restricted - the only valid variety is an asterisk on the left-most +#         domain component, i.e.: +# +#             *.foo.com +#     """ +#     def __init__(self): +#         self.d = _Node() +# +#     def add(self, dn, cert): +#         parts = dn.split(".") +#         parts.reverse() +#         current = self.d +#         for i in parts: +#             current = current.setdefault(i, _Node()) +#         current.value = cert +# +#     def get(self, dn): +#         parts = dn.split(".") +#         current = self.d +#         for i in reversed(parts): +#             if i in current: +#                 current = current[i] +#             elif "*" in current: +#                 return current["*"].value +#             else: +#                 return None +#         return current.value + + +class CertStoreEntry(object): + +    def __init__(self, cert, privatekey, chain_file): +        self.cert = cert +        self.privatekey = privatekey +        self.chain_file = chain_file + + +class CertStore(object): + +    """ +        Implements an in-memory certificate store. +    """ + +    def __init__( +            self, +            default_privatekey, +            default_ca, +            default_chain_file, +            dhparams): +        self.default_privatekey = default_privatekey +        self.default_ca = default_ca +        self.default_chain_file = default_chain_file +        self.dhparams = dhparams +        self.certs = dict() + +    @staticmethod +    def load_dhparam(path): + +        # netlib<=0.10 doesn't generate a dhparam file. +        # Create it now if neccessary. 
+        if not os.path.exists(path): +            with open(path, "wb") as f: +                f.write(DEFAULT_DHPARAM) + +        bio = OpenSSL.SSL._lib.BIO_new_file(path.encode(sys.getfilesystemencoding()), b"r") +        if bio != OpenSSL.SSL._ffi.NULL: +            bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free) +            dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams( +                bio, +                OpenSSL.SSL._ffi.NULL, +                OpenSSL.SSL._ffi.NULL, +                OpenSSL.SSL._ffi.NULL) +            dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free) +            return dh + +    @classmethod +    def from_store(cls, path, basename): +        ca_path = os.path.join(path, basename + "-ca.pem") +        if not os.path.exists(ca_path): +            key, ca = cls.create_store(path, basename) +        else: +            with open(ca_path, "rb") as f: +                raw = f.read() +            ca = OpenSSL.crypto.load_certificate( +                OpenSSL.crypto.FILETYPE_PEM, +                raw) +            key = OpenSSL.crypto.load_privatekey( +                OpenSSL.crypto.FILETYPE_PEM, +                raw) +        dh_path = os.path.join(path, basename + "-dhparam.pem") +        dh = cls.load_dhparam(dh_path) +        return cls(key, ca, ca_path, dh) + +    @staticmethod +    def create_store(path, basename, o=None, cn=None, expiry=DEFAULT_EXP): +        if not os.path.exists(path): +            os.makedirs(path) + +        o = o or basename +        cn = cn or basename + +        key, ca = create_ca(o=o, cn=cn, exp=expiry) +        # Dump the CA plus private key +        with open(os.path.join(path, basename + "-ca.pem"), "wb") as f: +            f.write( +                OpenSSL.crypto.dump_privatekey( +                    OpenSSL.crypto.FILETYPE_PEM, +                    key)) +            f.write( +                OpenSSL.crypto.dump_certificate( +                    OpenSSL.crypto.FILETYPE_PEM, +                    ca)) + +        # Dump the certificate in PEM format +        with open(os.path.join(path, basename + "-ca-cert.pem"), "wb") as f: +            f.write( +                OpenSSL.crypto.dump_certificate( +                    OpenSSL.crypto.FILETYPE_PEM, +                    ca)) + +        # Create a .cer file with the same contents for Android +        with open(os.path.join(path, basename + "-ca-cert.cer"), "wb") as f: +            f.write( +                OpenSSL.crypto.dump_certificate( +                    OpenSSL.crypto.FILETYPE_PEM, +                    ca)) + +        # Dump the certificate in PKCS12 format for Windows devices +        with open(os.path.join(path, basename + "-ca-cert.p12"), "wb") as f: +            p12 = OpenSSL.crypto.PKCS12() +            p12.set_certificate(ca) +            p12.set_privatekey(key) +            f.write(p12.export()) + +        with open(os.path.join(path, basename + "-dhparam.pem"), "wb") as f: +            f.write(DEFAULT_DHPARAM) + +        return key, ca + +    def add_cert_file(self, spec, path): +        with open(path, "rb") as f: +            raw = f.read() +        cert = SSLCert( +            OpenSSL.crypto.load_certificate( +                OpenSSL.crypto.FILETYPE_PEM, +                raw)) +        try: +            privatekey = OpenSSL.crypto.load_privatekey( +                OpenSSL.crypto.FILETYPE_PEM, +                raw) +        except Exception: +            privatekey = self.default_privatekey +        self.add_cert( +            CertStoreEntry(cert, privatekey, 
path), +            spec +        ) + +    def add_cert(self, entry, *names): +        """ +            Adds a cert to the certstore. We register the CN in the cert plus +            any SANs, and also the list of names provided as an argument. +        """ +        if entry.cert.cn: +            self.certs[entry.cert.cn] = entry +        for i in entry.cert.altnames: +            self.certs[i] = entry +        for i in names: +            self.certs[i] = entry + +    @staticmethod +    def asterisk_forms(dn): +        if dn is None: +            return [] +        parts = dn.split(b".") +        parts.reverse() +        curr_dn = b"" +        dn_forms = [b"*"] +        for part in parts[:-1]: +            curr_dn = b"." + part + curr_dn  # .example.com +            dn_forms.append(b"*" + curr_dn)   # *.example.com +        if parts[-1] != b"*": +            dn_forms.append(parts[-1] + curr_dn) +        return dn_forms + +    def get_cert(self, commonname, sans): +        """ +            Returns an (cert, privkey, cert_chain) tuple. + +            commonname: Common name for the generated certificate. Must be a +            valid, plain-ASCII, IDNA-encoded domain name. + +            sans: A list of Subject Alternate Names. +        """ + +        potential_keys = self.asterisk_forms(commonname) +        for s in sans: +            potential_keys.extend(self.asterisk_forms(s)) +        potential_keys.append((commonname, tuple(sans))) + +        name = next( +            filter(lambda key: key in self.certs, potential_keys), +            None +        ) +        if name: +            entry = self.certs[name] +        else: +            entry = CertStoreEntry( +                cert=dummy_cert( +                    self.default_privatekey, +                    self.default_ca, +                    commonname, +                    sans), +                privatekey=self.default_privatekey, +                chain_file=self.default_chain_file) +            self.certs[(commonname, tuple(sans))] = entry + +        return entry.cert, entry.privatekey, entry.chain_file + + +class _GeneralName(univ.Choice): +    # We are only interested in dNSNames. We use a default handler to ignore +    # other types. +    # TODO: We should also handle iPAddresses. +    componentType = namedtype.NamedTypes( +        namedtype.NamedType('dNSName', char.IA5String().subtype( +            implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) +        ) +        ), +    ) + + +class _GeneralNames(univ.SequenceOf): +    componentType = _GeneralName() +    sizeSpec = univ.SequenceOf.sizeSpec + \ +        constraint.ValueSizeConstraint(1, 1024) + + +class SSLCert(Serializable): + +    def __init__(self, cert): +        """ +            Returns a (common name, [subject alternative names]) tuple. 
+        """ +        self.x509 = cert + +    def __eq__(self, other): +        return self.digest("sha256") == other.digest("sha256") + +    def __ne__(self, other): +        return not self.__eq__(other) + +    def get_state(self): +        return self.to_pem() + +    def set_state(self, state): +        self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state) + +    @classmethod +    def from_state(cls, state): +        cls.from_pem(state) + +    @classmethod +    def from_pem(cls, txt): +        x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) +        return cls(x509) + +    @classmethod +    def from_der(cls, der): +        pem = ssl.DER_cert_to_PEM_cert(der) +        return cls.from_pem(pem) + +    def to_pem(self): +        return OpenSSL.crypto.dump_certificate( +            OpenSSL.crypto.FILETYPE_PEM, +            self.x509) + +    def digest(self, name): +        return self.x509.digest(name) + +    @property +    def issuer(self): +        return self.x509.get_issuer().get_components() + +    @property +    def notbefore(self): +        t = self.x509.get_notBefore() +        return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ") + +    @property +    def notafter(self): +        t = self.x509.get_notAfter() +        return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ") + +    @property +    def has_expired(self): +        return self.x509.has_expired() + +    @property +    def subject(self): +        return self.x509.get_subject().get_components() + +    @property +    def serial(self): +        return self.x509.get_serial_number() + +    @property +    def keyinfo(self): +        pk = self.x509.get_pubkey() +        types = { +            OpenSSL.crypto.TYPE_RSA: "RSA", +            OpenSSL.crypto.TYPE_DSA: "DSA", +        } +        return ( +            types.get(pk.type(), "UNKNOWN"), +            pk.bits() +        ) + +    @property +    def cn(self): +        c = None +        for i in self.subject: +            if i[0] == b"CN": +                c = i[1] +        return c + +    @property +    def altnames(self): +        """ +        Returns: +            All DNS altnames. +        """ +        # tcp.TCPClient.convert_to_ssl assumes that this property only contains DNS altnames for hostname verification. +        altnames = [] +        for i in range(self.x509.get_extension_count()): +            ext = self.x509.get_extension(i) +            if ext.get_short_name() == b"subjectAltName": +                try: +                    dec = decode(ext.get_data(), asn1Spec=_GeneralNames()) +                except PyAsn1Error: +                    continue +                for i in dec[0]: +                    altnames.append(i[0].asOctets()) +        return altnames diff --git a/netlib/netlib/encoding.py b/netlib/netlib/encoding.py new file mode 100644 index 00000000..14479e00 --- /dev/null +++ b/netlib/netlib/encoding.py @@ -0,0 +1,88 @@ +""" +    Utility functions for decoding response bodies. 
+""" +from __future__ import absolute_import +from io import BytesIO +import gzip +import zlib +from .utils import always_byte_args + + +ENCODINGS = {"identity", "gzip", "deflate"} + + +def decode(e, content): +    if not isinstance(content, bytes): +        return None +    encoding_map = { +        "identity": identity, +        "gzip": decode_gzip, +        "deflate": decode_deflate, +    } +    if e not in encoding_map: +        return None +    return encoding_map[e](content) + + +def encode(e, content): +    if not isinstance(content, bytes): +        return None +    encoding_map = { +        "identity": identity, +        "gzip": encode_gzip, +        "deflate": encode_deflate, +    } +    if e not in encoding_map: +        return None +    return encoding_map[e](content) + + +def identity(content): +    """ +        Returns content unchanged. Identity is the default value of +        Accept-Encoding headers. +    """ +    return content + + +def decode_gzip(content): +    gfile = gzip.GzipFile(fileobj=BytesIO(content)) +    try: +        return gfile.read() +    except (IOError, EOFError): +        return None + + +def encode_gzip(content): +    s = BytesIO() +    gf = gzip.GzipFile(fileobj=s, mode='wb') +    gf.write(content) +    gf.close() +    return s.getvalue() + + +def decode_deflate(content): +    """ +        Returns decompressed data for DEFLATE. Some servers may respond with +        compressed data without a zlib header or checksum. An undocumented +        feature of zlib permits the lenient decompression of data missing both +        values. + +        http://bugs.python.org/issue5784 +    """ +    try: +        try: +            return zlib.decompress(content) +        except zlib.error: +            return zlib.decompress(content, -15) +    except zlib.error: +        return None + + +def encode_deflate(content): +    """ +        Returns compressed content, always including zlib header and checksum. +    """ +    return zlib.compress(content) + +__all__ = ["ENCODINGS", "encode", "decode"] diff --git a/netlib/netlib/exceptions.py b/netlib/netlib/exceptions.py new file mode 100644 index 00000000..05f1054b --- /dev/null +++ b/netlib/netlib/exceptions.py @@ -0,0 +1,56 @@ +""" +We try to be very hygienic regarding the exceptions we throw: +Every Exception netlib raises shall be a subclass of NetlibException. + + +See also: http://lucumr.pocoo.org/2014/10/16/on-error-handling/ +""" +from __future__ import absolute_import, print_function, division + + +class NetlibException(Exception): +    """ +    Base class for all exceptions thrown by netlib. 
+    """ +    def __init__(self, message=None): +        super(NetlibException, self).__init__(message) + + +class Disconnect(object): +    """Immediate EOF""" + + +class HttpException(NetlibException): +    pass + + +class HttpReadDisconnect(HttpException, Disconnect): +    pass + + +class HttpSyntaxException(HttpException): +    pass + + +class TcpException(NetlibException): +    pass + + +class TcpDisconnect(TcpException, Disconnect): +    pass + + +class TcpReadIncomplete(TcpException): +    pass + + +class TcpTimeout(TcpException): +    pass + + +class TlsException(NetlibException): +    pass + + +class InvalidCertificateException(TlsException): +    pass diff --git a/netlib/netlib/http/__init__.py b/netlib/netlib/http/__init__.py new file mode 100644 index 00000000..fd632cd5 --- /dev/null +++ b/netlib/netlib/http/__init__.py @@ -0,0 +1,14 @@ +from __future__ import absolute_import, print_function, division +from .request import Request +from .response import Response +from .headers import Headers +from .message import decoded, CONTENT_MISSING +from . import http1, http2 + +__all__ = [ +    "Request", +    "Response", +    "Headers", +    "decoded", "CONTENT_MISSING", +    "http1", "http2", +] diff --git a/netlib/netlib/http/authentication.py b/netlib/netlib/http/authentication.py new file mode 100644 index 00000000..d769abe5 --- /dev/null +++ b/netlib/netlib/http/authentication.py @@ -0,0 +1,167 @@ +from __future__ import (absolute_import, print_function, division) +from argparse import Action, ArgumentTypeError +import binascii + + +def parse_http_basic_auth(s): +    words = s.split() +    if len(words) != 2: +        return None +    scheme = words[0] +    try: +        user = binascii.a2b_base64(words[1]).decode("utf8", "replace") +    except binascii.Error: +        return None +    parts = user.split(':') +    if len(parts) != 2: +        return None +    return scheme, parts[0], parts[1] + + +def assemble_http_basic_auth(scheme, username, password): +    v = binascii.b2a_base64((username + ":" + password).encode("utf8")).decode("ascii") +    return scheme + " " + v + + +class NullProxyAuth(object): + +    """ +        No proxy auth at all (returns empty challange headers) +    """ + +    def __init__(self, password_manager): +        self.password_manager = password_manager + +    def clean(self, headers_): +        """ +            Clean up authentication headers, so they're not passed upstream. 
+        """ + +    def authenticate(self, headers_): +        """ +            Tests that the user is allowed to use the proxy +        """ +        return True + +    def auth_challenge_headers(self): +        """ +            Returns a dictionary containing the headers require to challenge the user +        """ +        return {} + + +class BasicProxyAuth(NullProxyAuth): +    CHALLENGE_HEADER = 'Proxy-Authenticate' +    AUTH_HEADER = 'Proxy-Authorization' + +    def __init__(self, password_manager, realm): +        NullProxyAuth.__init__(self, password_manager) +        self.realm = realm + +    def clean(self, headers): +        del headers[self.AUTH_HEADER] + +    def authenticate(self, headers): +        auth_value = headers.get(self.AUTH_HEADER) +        if not auth_value: +            return False +        parts = parse_http_basic_auth(auth_value) +        if not parts: +            return False +        scheme, username, password = parts +        if scheme.lower() != 'basic': +            return False +        if not self.password_manager.test(username, password): +            return False +        self.username = username +        return True + +    def auth_challenge_headers(self): +        return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm} + + +class PassMan(object): + +    def test(self, username_, password_token_): +        return False + + +class PassManNonAnon(PassMan): + +    """ +        Ensure the user specifies a username, accept any password. +    """ + +    def test(self, username, password_token_): +        if username: +            return True +        return False + + +class PassManHtpasswd(PassMan): + +    """ +        Read usernames and passwords from an htpasswd file +    """ + +    def __init__(self, path): +        """ +            Raises ValueError if htpasswd file is invalid. +        """ +        import passlib.apache +        self.htpasswd = passlib.apache.HtpasswdFile(path) + +    def test(self, username, password_token): +        return bool(self.htpasswd.check_password(username, password_token)) + + +class PassManSingleUser(PassMan): + +    def __init__(self, username, password): +        self.username, self.password = username, password + +    def test(self, username, password_token): +        return self.username == username and self.password == password_token + + +class AuthAction(Action): + +    """ +    Helper class to allow seamless integration int argparse. Example usage: +    parser.add_argument( +        "--nonanonymous", +        action=NonanonymousAuthAction, nargs=0, +        help="Allow access to any user long as a credentials are specified." +    ) +    """ + +    def __call__(self, parser, namespace, values, option_string=None): +        passman = self.getPasswordManager(values) +        authenticator = BasicProxyAuth(passman, "mitmproxy") +        setattr(namespace, self.dest, authenticator) + +    def getPasswordManager(self, s):  # pragma: nocover +        raise NotImplementedError() + + +class SingleuserAuthAction(AuthAction): + +    def getPasswordManager(self, s): +        if len(s.split(':')) != 2: +            raise ArgumentTypeError( +                "Invalid single-user specification. 
Please use the format username:password" +            ) +        username, password = s.split(':') +        return PassManSingleUser(username, password) + + +class NonanonymousAuthAction(AuthAction): + +    def getPasswordManager(self, s): +        return PassManNonAnon() + + +class HtpasswdAuthAction(AuthAction): + +    def getPasswordManager(self, s): +        return PassManHtpasswd(s) diff --git a/netlib/netlib/http/cookies.py b/netlib/netlib/http/cookies.py new file mode 100644 index 00000000..18544b5e --- /dev/null +++ b/netlib/netlib/http/cookies.py @@ -0,0 +1,193 @@ +import re + +from .. import odict + +""" +A flexible module for cookie parsing and manipulation. + +This module differs from usual standards-compliant cookie modules in a number +of ways. We try to be as permissive as possible, and to retain even mal-formed +information. Duplicate cookies are preserved in parsing, and can be set in +formatting. We do attempt to escape and quote values where needed, but will not +reject data that violate the specs. + +Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We do +not parse the comma-separated variant of Set-Cookie that allows multiple +cookies to be set in a single header. Technically this should be feasible, but +it turns out that violations of RFC6265 that makes the parsing problem +indeterminate are much more common than genuine occurences of the multi-cookie +variants. Serialization follows RFC6265. + +    http://tools.ietf.org/html/rfc6265 +    http://tools.ietf.org/html/rfc2109 +    http://tools.ietf.org/html/rfc2965 +""" + +# TODO: Disallow LHS-only Cookie values + + +def _read_until(s, start, term): +    """ +        Read until one of the characters in term is reached. +    """ +    if start == len(s): +        return "", start + 1 +    for i in range(start, len(s)): +        if s[i] in term: +            return s[start:i], i +    return s[start:i + 1], i + 1 + + +def _read_token(s, start): +    """ +        Read a token - the LHS of a token/value pair in a cookie. +    """ +    return _read_until(s, start, ";=") + + +def _read_quoted_string(s, start): +    """ +        start: offset to the first quote of the string to be read + +        A sort of loose super-set of the various quoted string specifications. + +        RFC6265 disallows backslashes or double quotes within quoted strings. +        Prior RFCs use backslashes to escape. This leaves us free to apply +        backslash escaping by default and be compatible with everything. +    """ +    escaping = False +    ret = [] +    # Skip the first quote +    i = start  # initialize in case the loop doesn't run. +    for i in range(start + 1, len(s)): +        if escaping: +            ret.append(s[i]) +            escaping = False +        elif s[i] == '"': +            break +        elif s[i] == "\\": +            escaping = True +        else: +            ret.append(s[i]) +    return "".join(ret), i + 1 + + +def _read_value(s, start, delims): +    """ +        Reads a value - the RHS of a token/value pair in a cookie. + +        special: If the value is special, commas are premitted. Else comma +        terminates. This helps us support old and new style values. +    """ +    if start >= len(s): +        return "", start +    elif s[start] == '"': +        return _read_quoted_string(s, start) +    else: +        return _read_until(s, start, delims) + + +def _read_pairs(s, off=0): +    """ +        Read pairs of lhs=rhs values. 
+ +        off: start offset +        specials: a lower-cased list of keys that may contain commas +    """ +    vals = [] +    while True: +        lhs, off = _read_token(s, off) +        lhs = lhs.lstrip() +        if lhs: +            rhs = None +            if off < len(s): +                if s[off] == "=": +                    rhs, off = _read_value(s, off + 1, ";") +            vals.append([lhs, rhs]) +        off += 1 +        if not off < len(s): +            break +    return vals, off + + +def _has_special(s): +    for i in s: +        if i in '",;\\': +            return True +        o = ord(i) +        if o < 0x21 or o > 0x7e: +            return True +    return False + + +ESCAPE = re.compile(r"([\"\\])") + + +def _format_pairs(lst, specials=(), sep="; "): +    """ +        specials: A lower-cased list of keys that will not be quoted. +    """ +    vals = [] +    for k, v in lst: +        if v is None: +            vals.append(k) +        else: +            if k.lower() not in specials and _has_special(v): +                v = ESCAPE.sub(r"\\\1", v) +                v = '"%s"' % v +            vals.append("%s=%s" % (k, v)) +    return sep.join(vals) + + +def _format_set_cookie_pairs(lst): +    return _format_pairs( +        lst, +        specials=("expires", "path") +    ) + + +def _parse_set_cookie_pairs(s): +    """ +        For Set-Cookie, we support multiple cookies as described in RFC2109. +        This function therefore returns a list of lists. +    """ +    pairs, off_ = _read_pairs(s) +    return pairs + + +def parse_set_cookie_header(line): +    """ +        Parse a Set-Cookie header value + +        Returns a (name, value, attrs) tuple, or None, where attrs is an +        ODictCaseless set of attributes. No attempt is made to parse attribute +        values - they are treated purely as strings. +    """ +    pairs = _parse_set_cookie_pairs(line) +    if pairs: +        return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) + + +def format_set_cookie_header(name, value, attrs): +    """ +        Formats a Set-Cookie header value. +    """ +    pairs = [[name, value]] +    pairs.extend(attrs.lst) +    return _format_set_cookie_pairs(pairs) + + +def parse_cookie_header(line): +    """ +        Parse a Cookie header value. +        Returns a (possibly empty) ODict object. +    """ +    pairs, off_ = _read_pairs(line) +    return odict.ODict(pairs) + + +def format_cookie_header(od): +    """ +        Formats a Cookie header value. +    """ +    return _format_pairs(od.lst) diff --git a/netlib/netlib/http/headers.py b/netlib/netlib/http/headers.py new file mode 100644 index 00000000..78404796 --- /dev/null +++ b/netlib/netlib/http/headers.py @@ -0,0 +1,204 @@ +""" + +Unicode Handling +---------------- +See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ +""" +from __future__ import absolute_import, print_function, division +import copy +try: +    from collections.abc import MutableMapping +except ImportError:  # pragma: nocover +    from collections import MutableMapping  # Workaround for Python < 3.3 + + +import six + +from netlib.utils import always_byte_args, always_bytes, Serializable + +if six.PY2:  # pragma: nocover +    _native = lambda x: x +    _always_bytes = lambda x: x +    _always_byte_args = lambda x: x +else: +    # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. 
+    _native = lambda x: x.decode("utf-8", "surrogateescape") +    _always_bytes = lambda x: always_bytes(x, "utf-8", "surrogateescape") +    _always_byte_args = always_byte_args("utf-8", "surrogateescape") + + +class Headers(MutableMapping, Serializable): +    """ +    Header class which allows both convenient access to individual headers as well as +    direct access to the underlying raw data. Provides a full dictionary interface. + +    Example: + +    .. code-block:: python + +        # Create headers with keyword arguments +        >>> h = Headers(host="example.com", content_type="application/xml") + +        # Headers mostly behave like a normal dict. +        >>> h["Host"] +        "example.com" + +        # HTTP Headers are case insensitive +        >>> h["host"] +        "example.com" + +        # Headers can also be creatd from a list of raw (header_name, header_value) byte tuples +        >>> h = Headers([ +            [b"Host",b"example.com"], +            [b"Accept",b"text/html"], +            [b"accept",b"application/xml"] +        ]) + +        # Multiple headers are folded into a single header as per RFC7230 +        >>> h["Accept"] +        "text/html, application/xml" + +        # Setting a header removes all existing headers with the same name. +        >>> h["Accept"] = "application/text" +        >>> h["Accept"] +        "application/text" + +        # bytes(h) returns a HTTP1 header block. +        >>> print(bytes(h)) +        Host: example.com +        Accept: application/text + +        # For full control, the raw header fields can be accessed +        >>> h.fields + +    Caveats: +        For use with the "Set-Cookie" header, see :py:meth:`get_all`. +    """ + +    @_always_byte_args +    def __init__(self, fields=None, **headers): +        """ +        Args: +            fields: (optional) list of ``(name, value)`` header byte tuples, +                e.g. ``[(b"Host", b"example.com")]``. All names and values must be bytes. +            **headers: Additional headers to set. Will overwrite existing values from `fields`. +                For convenience, underscores in header names will be transformed to dashes - +                this behaviour does not extend to other methods. +                If ``**headers`` contains multiple keys that have equal ``.lower()`` s, +                the behavior is undefined. +        """ +        self.fields = fields or [] + +        for name, value in self.fields: +            if not isinstance(name, bytes) or not isinstance(value, bytes): +                raise ValueError("Headers passed as fields must be bytes.") + +        # content_type -> content-type +        headers = { +            _always_bytes(name).replace(b"_", b"-"): value +            for name, value in six.iteritems(headers) +            } +        self.update(headers) + +    def __bytes__(self): +        if self.fields: +            return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" +        else: +            return b"" + +    if six.PY2:  # pragma: nocover +        __str__ = __bytes__ + +    @_always_byte_args +    def __getitem__(self, name): +        values = self.get_all(name) +        if not values: +            raise KeyError(name) +        return ", ".join(values) + +    @_always_byte_args +    def __setitem__(self, name, value): +        idx = self._index(name) + +        # To please the human eye, we insert at the same position the first existing header occured. 
+        if idx is not None: +            del self[name] +            self.fields.insert(idx, [name, value]) +        else: +            self.fields.append([name, value]) + +    @_always_byte_args +    def __delitem__(self, name): +        if name not in self: +            raise KeyError(name) +        name = name.lower() +        self.fields = [ +            field for field in self.fields +            if name != field[0].lower() +        ] + +    def __iter__(self): +        seen = set() +        for name, _ in self.fields: +            name_lower = name.lower() +            if name_lower not in seen: +                seen.add(name_lower) +                yield _native(name) + +    def __len__(self): +        return len(set(name.lower() for name, _ in self.fields)) + +    # __hash__ = object.__hash__ + +    def _index(self, name): +        name = name.lower() +        for i, field in enumerate(self.fields): +            if field[0].lower() == name: +                return i +        return None + +    def __eq__(self, other): +        if isinstance(other, Headers): +            return self.fields == other.fields +        return False + +    def __ne__(self, other): +        return not self.__eq__(other) + +    @_always_byte_args +    def get_all(self, name): +        """ +        Like :py:meth:`get`, but does not fold multiple headers into a single one. +        This is useful for Set-Cookie headers, which do not support folding. + +        See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 +        """ +        name_lower = name.lower() +        values = [_native(value) for n, value in self.fields if n.lower() == name_lower] +        return values + +    @_always_byte_args +    def set_all(self, name, values): +        """ +        Explicitly set multiple headers for the given key. +        See: :py:meth:`get_all` +        """ +        values = map(_always_bytes, values)  # _always_byte_args does not fix lists +        if name in self: +            del self[name] +        self.fields.extend( +            [name, value] for value in values +        ) + +    def copy(self): +        return Headers(copy.copy(self.fields)) + +    def get_state(self): +        return tuple(tuple(field) for field in self.fields) + +    def set_state(self, state): +        self.fields = [list(field) for field in state] + +    @classmethod +    def from_state(cls, state): +        return cls([list(field) for field in state])
\ No newline at end of file diff --git a/netlib/netlib/http/http1/__init__.py b/netlib/netlib/http/http1/__init__.py new file mode 100644 index 00000000..2aa7e26a --- /dev/null +++ b/netlib/netlib/http/http1/__init__.py @@ -0,0 +1,25 @@ +from __future__ import absolute_import, print_function, division +from .read import ( +    read_request, read_request_head, +    read_response, read_response_head, +    read_body, +    connection_close, +    expected_http_body_size, +) +from .assemble import ( +    assemble_request, assemble_request_head, +    assemble_response, assemble_response_head, +    assemble_body, +) + + +__all__ = [ +    "read_request", "read_request_head", +    "read_response", "read_response_head", +    "read_body", +    "connection_close", +    "expected_http_body_size", +    "assemble_request", "assemble_request_head", +    "assemble_response", "assemble_response_head", +    "assemble_body", +] diff --git a/netlib/netlib/http/http1/assemble.py b/netlib/netlib/http/http1/assemble.py new file mode 100644 index 00000000..785ee8d3 --- /dev/null +++ b/netlib/netlib/http/http1/assemble.py @@ -0,0 +1,104 @@ +from __future__ import absolute_import, print_function, division + +from ... import utils +import itertools +from ...exceptions import HttpException +from .. import CONTENT_MISSING + + +def assemble_request(request): +    if request.content == CONTENT_MISSING: +        raise HttpException("Cannot assemble flow with CONTENT_MISSING") +    head = assemble_request_head(request) +    body = b"".join(assemble_body(request.data.headers, [request.data.content])) +    return head + body + + +def assemble_request_head(request): +    first_line = _assemble_request_line(request.data) +    headers = _assemble_request_headers(request.data) +    return b"%s\r\n%s\r\n" % (first_line, headers) + + +def assemble_response(response): +    if response.content == CONTENT_MISSING: +        raise HttpException("Cannot assemble flow with CONTENT_MISSING") +    head = assemble_response_head(response) +    body = b"".join(assemble_body(response.data.headers, [response.data.content])) +    return head + body + + +def assemble_response_head(response): +    first_line = _assemble_response_line(response.data) +    headers = _assemble_response_headers(response.data) +    return b"%s\r\n%s\r\n" % (first_line, headers) + + +def assemble_body(headers, body_chunks): +    if "chunked" in headers.get("transfer-encoding", "").lower(): +        for chunk in body_chunks: +            if chunk: +                yield b"%x\r\n%s\r\n" % (len(chunk), chunk) +        yield b"0\r\n\r\n" +    else: +        for chunk in body_chunks: +            yield chunk + + +def _assemble_request_line(request_data): +    """ +    Args: +        request_data (netlib.http.request.RequestData) +    """ +    form = request_data.first_line_format +    if form == "relative": +        return b"%s %s %s" % ( +            request_data.method, +            request_data.path, +            request_data.http_version +        ) +    elif form == "authority": +        return b"%s %s:%d %s" % ( +            request_data.method, +            request_data.host, +            request_data.port, +            request_data.http_version +        ) +    elif form == "absolute": +        return b"%s %s://%s:%d%s %s" % ( +            request_data.method, +            request_data.scheme, +            request_data.host, +            request_data.port, +            request_data.path, +            request_data.http_version +        ) +    else: +        raise 
RuntimeError("Invalid request form") + + +def _assemble_request_headers(request_data): +    """ +    Args: +        request_data (netlib.http.request.RequestData) +    """ +    headers = request_data.headers.copy() +    if "host" not in headers and request_data.scheme and request_data.host and request_data.port: +        headers["host"] = utils.hostport( +            request_data.scheme, +            request_data.host, +            request_data.port +        ) +    return bytes(headers) + + +def _assemble_response_line(response_data): +    return b"%s %d %s" % ( +        response_data.http_version, +        response_data.status_code, +        response_data.reason, +    ) + + +def _assemble_response_headers(response): +    return bytes(response.headers) diff --git a/netlib/netlib/http/http1/read.py b/netlib/netlib/http/http1/read.py new file mode 100644 index 00000000..6e3a1b93 --- /dev/null +++ b/netlib/netlib/http/http1/read.py @@ -0,0 +1,362 @@ +from __future__ import absolute_import, print_function, division +import time +import sys +import re + +from ... import utils +from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException, TcpDisconnect +from .. import Request, Response, Headers + + +def read_request(rfile, body_size_limit=None): +    request = read_request_head(rfile) +    expected_body_size = expected_http_body_size(request) +    request.data.content = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) +    request.timestamp_end = time.time() +    return request + + +def read_request_head(rfile): +    """ +    Parse an HTTP request head (request line + headers) from an input stream + +    Args: +        rfile: The input stream + +    Returns: +        The HTTP request object (without body) + +    Raises: +        HttpReadDisconnect: No bytes can be read from rfile. +        HttpSyntaxException: The input is malformed HTTP. +        HttpException: Any other error occured. +    """ +    timestamp_start = time.time() +    if hasattr(rfile, "reset_timestamps"): +        rfile.reset_timestamps() + +    form, method, scheme, host, port, path, http_version = _read_request_line(rfile) +    headers = _read_headers(rfile) + +    if hasattr(rfile, "first_byte_timestamp"): +        # more accurate timestamp_start +        timestamp_start = rfile.first_byte_timestamp + +    return Request( +        form, method, scheme, host, port, path, http_version, headers, None, timestamp_start +    ) + + +def read_response(rfile, request, body_size_limit=None): +    response = read_response_head(rfile) +    expected_body_size = expected_http_body_size(request, response) +    response.data.content = b"".join(read_body(rfile, expected_body_size, body_size_limit)) +    response.timestamp_end = time.time() +    return response + + +def read_response_head(rfile): +    """ +    Parse an HTTP response head (response line + headers) from an input stream + +    Args: +        rfile: The input stream + +    Returns: +        The HTTP request object (without body) + +    Raises: +        HttpReadDisconnect: No bytes can be read from rfile. +        HttpSyntaxException: The input is malformed HTTP. +        HttpException: Any other error occured. 
+    """ + +    timestamp_start = time.time() +    if hasattr(rfile, "reset_timestamps"): +        rfile.reset_timestamps() + +    http_version, status_code, message = _read_response_line(rfile) +    headers = _read_headers(rfile) + +    if hasattr(rfile, "first_byte_timestamp"): +        # more accurate timestamp_start +        timestamp_start = rfile.first_byte_timestamp + +    return Response(http_version, status_code, message, headers, None, timestamp_start) + + +def read_body(rfile, expected_size, limit=None, max_chunk_size=4096): +    """ +        Read an HTTP message body + +        Args: +            rfile: The input stream +            expected_size: The expected body size (see :py:meth:`expected_body_size`) +            limit: Maximum body size +            max_chunk_size: Maximium chunk size that gets yielded + +        Returns: +            A generator that yields byte chunks of the content. + +        Raises: +            HttpException, if an error occurs + +        Caveats: +            max_chunk_size is not considered if the transfer encoding is chunked. +    """ +    if not limit or limit < 0: +        limit = sys.maxsize +    if not max_chunk_size: +        max_chunk_size = limit + +    if expected_size is None: +        for x in _read_chunked(rfile, limit): +            yield x +    elif expected_size >= 0: +        if limit is not None and expected_size > limit: +            raise HttpException( +                "HTTP Body too large. " +                "Limit is {}, content length was advertised as {}".format(limit, expected_size) +            ) +        bytes_left = expected_size +        while bytes_left: +            chunk_size = min(bytes_left, max_chunk_size) +            content = rfile.read(chunk_size) +            if len(content) < chunk_size: +                raise HttpException("Unexpected EOF") +            yield content +            bytes_left -= chunk_size +    else: +        bytes_left = limit +        while bytes_left: +            chunk_size = min(bytes_left, max_chunk_size) +            content = rfile.read(chunk_size) +            if not content: +                return +            yield content +            bytes_left -= chunk_size +        not_done = rfile.read(1) +        if not_done: +            raise HttpException("HTTP body too large. Limit is {}.".format(limit)) + + +def connection_close(http_version, headers): +    """ +        Checks the message to see if the client connection should be closed +        according to RFC 2616 Section 8.1. +    """ +    # At first, check if we have an explicit Connection header. +    if "connection" in headers: +        tokens = utils.get_header_tokens(headers, "connection") +        if "close" in tokens: +            return True +        elif "keep-alive" in tokens: +            return False + +    # If we don't have a Connection header, HTTP 1.1 connections are assumed to +    # be persistent +    return http_version != "HTTP/1.1" and http_version != b"HTTP/1.1"  # FIXME: Remove one case. + + +def expected_http_body_size(request, response=None): +    """ +        Returns: +            The expected body length: +            - a positive integer, if the size is known in advance +            - None, if the size in unknown in advance (chunked encoding) +            - -1, if all data should be read until end of stream. 
+ +        Raises: +            HttpSyntaxException, if the content length header is invalid +    """ +    # Determine response size according to +    # http://tools.ietf.org/html/rfc7230#section-3.3 +    if not response: +        headers = request.headers +        response_code = None +        is_request = True +    else: +        headers = response.headers +        response_code = response.status_code +        is_request = False + +    if is_request: +        if headers.get("expect", "").lower() == "100-continue": +            return 0 +    else: +        if request.method.upper() == "HEAD": +            return 0 +        if 100 <= response_code <= 199: +            return 0 +        if response_code == 200 and request.method.upper() == "CONNECT": +            return 0 +        if response_code in (204, 304): +            return 0 + +    if "chunked" in headers.get("transfer-encoding", "").lower(): +        return None +    if "content-length" in headers: +        try: +            size = int(headers["content-length"]) +            if size < 0: +                raise ValueError() +            return size +        except ValueError: +            raise HttpSyntaxException("Unparseable Content Length") +    if is_request: +        return 0 +    return -1 + + +def _get_first_line(rfile): +    try: +        line = rfile.readline() +        if line == b"\r\n" or line == b"\n": +            # Possible leftover from previous message +            line = rfile.readline() +    except TcpDisconnect: +        raise HttpReadDisconnect("Remote disconnected") +    if not line: +        raise HttpReadDisconnect("Remote disconnected") +    return line.strip() + + +def _read_request_line(rfile): +    try: +        line = _get_first_line(rfile) +    except HttpReadDisconnect: +        # We want to provide a better error message. +        raise HttpReadDisconnect("Client disconnected") + +    try: +        method, path, http_version = line.split(b" ") + +        if path == b"*" or path.startswith(b"/"): +            form = "relative" +            scheme, host, port = None, None, None +        elif method == b"CONNECT": +            form = "authority" +            host, port = _parse_authority_form(path) +            scheme, path = None, None +        else: +            form = "absolute" +            scheme, host, port, path = utils.parse_url(path) + +        _check_http_version(http_version) +    except ValueError: +        raise HttpSyntaxException("Bad HTTP request line: {}".format(line)) + +    return form, method, scheme, host, port, path, http_version + + +def _parse_authority_form(hostport): +    """ +        Returns (host, port) if hostport is a valid authority-form host specification. +        http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 + +        Raises: +            ValueError, if the input is malformed +    """ +    try: +        host, port = hostport.split(b":") +        port = int(port) +        if not utils.is_valid_host(host) or not utils.is_valid_port(port): +            raise ValueError() +    except ValueError: +        raise HttpSyntaxException("Invalid host specification: {}".format(hostport)) + +    return host, port + + +def _read_response_line(rfile): +    try: +        line = _get_first_line(rfile) +    except HttpReadDisconnect: +        # We want to provide a better error message. 
+        raise HttpReadDisconnect("Server disconnected") + +    try: + +        parts = line.split(b" ", 2) +        if len(parts) == 2:  # handle missing message gracefully +            parts.append(b"") + +        http_version, status_code, message = parts +        status_code = int(status_code) +        _check_http_version(http_version) + +    except ValueError: +        raise HttpSyntaxException("Bad HTTP response line: {}".format(line)) + +    return http_version, status_code, message + + +def _check_http_version(http_version): +    if not re.match(br"^HTTP/\d\.\d$", http_version): +        raise HttpSyntaxException("Unknown HTTP version: {}".format(http_version)) + + +def _read_headers(rfile): +    """ +        Read a set of headers. +        Stop once a blank line is reached. + +        Returns: +            A headers object + +        Raises: +            HttpSyntaxException +    """ +    ret = [] +    while True: +        line = rfile.readline() +        if not line or line == b"\r\n" or line == b"\n": +            break +        if line[0] in b" \t": +            if not ret: +                raise HttpSyntaxException("Invalid headers") +            # continued header +            ret[-1][1] = ret[-1][1] + b'\r\n ' + line.strip() +        else: +            try: +                name, value = line.split(b":", 1) +                value = value.strip() +                if not name: +                    raise ValueError() +                ret.append([name, value]) +            except ValueError: +                raise HttpSyntaxException("Invalid headers") +    return Headers(ret) + + +def _read_chunked(rfile, limit=sys.maxsize): +    """ +    Read a HTTP body with chunked transfer encoding. + +    Args: +        rfile: the input file +        limit: A positive integer +    """ +    total = 0 +    while True: +        line = rfile.readline(128) +        if line == b"": +            raise HttpException("Connection closed prematurely") +        if line != b"\r\n" and line != b"\n": +            try: +                length = int(line, 16) +            except ValueError: +                raise HttpSyntaxException("Invalid chunked encoding length: {}".format(line)) +            total += length +            if total > limit: +                raise HttpException( +                    "HTTP Body too large. Limit is {}, " +                    "chunked content longer than {}".format(limit, total) +                ) +            chunk = rfile.read(length) +            suffix = rfile.readline(5) +            if suffix != b"\r\n": +                raise HttpSyntaxException("Malformed chunked body") +            if length == 0: +                return +            yield chunk diff --git a/netlib/netlib/http/http2/__init__.py b/netlib/netlib/http/http2/__init__.py new file mode 100644 index 00000000..7043d36f --- /dev/null +++ b/netlib/netlib/http/http2/__init__.py @@ -0,0 +1,6 @@ +from __future__ import absolute_import, print_function, division +from .connections import HTTP2Protocol + +__all__ = [ +    "HTTP2Protocol" +] diff --git a/netlib/netlib/http/http2/connections.py b/netlib/netlib/http/http2/connections.py new file mode 100644 index 00000000..52fa7193 --- /dev/null +++ b/netlib/netlib/http/http2/connections.py @@ -0,0 +1,426 @@ +from __future__ import (absolute_import, print_function, division) +import itertools +import time + +from hpack.hpack import Encoder, Decoder +from ... import utils +from .. 
import Headers, Response, Request + +from hyperframe import frame + + +class TCPHandler(object): + +    def __init__(self, rfile, wfile=None): +        self.rfile = rfile +        self.wfile = wfile + + +class HTTP2Protocol(object): + +    ERROR_CODES = utils.BiDi( +        NO_ERROR=0x0, +        PROTOCOL_ERROR=0x1, +        INTERNAL_ERROR=0x2, +        FLOW_CONTROL_ERROR=0x3, +        SETTINGS_TIMEOUT=0x4, +        STREAM_CLOSED=0x5, +        FRAME_SIZE_ERROR=0x6, +        REFUSED_STREAM=0x7, +        CANCEL=0x8, +        COMPRESSION_ERROR=0x9, +        CONNECT_ERROR=0xa, +        ENHANCE_YOUR_CALM=0xb, +        INADEQUATE_SECURITY=0xc, +        HTTP_1_1_REQUIRED=0xd +    ) + +    CLIENT_CONNECTION_PREFACE = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n' + +    HTTP2_DEFAULT_SETTINGS = { +        frame.SettingsFrame.HEADER_TABLE_SIZE: 4096, +        frame.SettingsFrame.ENABLE_PUSH: 1, +        frame.SettingsFrame.MAX_CONCURRENT_STREAMS: None, +        frame.SettingsFrame.INITIAL_WINDOW_SIZE: 2 ** 16 - 1, +        frame.SettingsFrame.MAX_FRAME_SIZE: 2 ** 14, +        frame.SettingsFrame.MAX_HEADER_LIST_SIZE: None, +    } + +    def __init__( +        self, +        tcp_handler=None, +        rfile=None, +        wfile=None, +        is_server=False, +        dump_frames=False, +        encoder=None, +        decoder=None, +        unhandled_frame_cb=None, +    ): +        self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) +        self.is_server = is_server +        self.dump_frames = dump_frames +        self.encoder = encoder or Encoder() +        self.decoder = decoder or Decoder() +        self.unhandled_frame_cb = unhandled_frame_cb + +        self.http2_settings = self.HTTP2_DEFAULT_SETTINGS.copy() +        self.current_stream_id = None +        self.connection_preface_performed = False + +    def read_request( +        self, +        __rfile, +        include_body=True, +        body_size_limit=None, +        allow_empty=False, +    ): +        if body_size_limit is not None: +            raise NotImplementedError() + +        self.perform_connection_preface() + +        timestamp_start = time.time() +        if hasattr(self.tcp_handler.rfile, "reset_timestamps"): +            self.tcp_handler.rfile.reset_timestamps() + +        stream_id, headers, body = self._receive_transmission( +            include_body=include_body, +        ) + +        if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): +            # more accurate timestamp_start +            timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + +        timestamp_end = time.time() + +        authority = headers.get(':authority', b'') +        method = headers.get(':method', 'GET') +        scheme = headers.get(':scheme', 'https') +        path = headers.get(':path', '/') +        host = None +        port = None + +        if path == '*' or path.startswith("/"): +            form_in = "relative" +        elif method == 'CONNECT': +            form_in = "authority" +            if ":" in authority: +                host, port = authority.split(":", 1) +            else: +                host = authority +        else: +            form_in = "absolute" +            # FIXME: verify if path or :host contains what we need +            scheme, host, port, _ = utils.parse_url(path) +            scheme = scheme.decode('ascii') +            host = host.decode('ascii') + +        if host is None: +            host = 'localhost' +        if port is None: +            port = 80 if scheme == 'http' else 443 +        port 
= int(port) + +        request = Request( +            form_in, +            method.encode('ascii'), +            scheme.encode('ascii'), +            host.encode('ascii'), +            port, +            path.encode('ascii'), +            b"HTTP/2.0", +            headers, +            body, +            timestamp_start, +            timestamp_end, +        ) +        request.stream_id = stream_id + +        return request + +    def read_response( +        self, +        __rfile, +        request_method=b'', +        body_size_limit=None, +        include_body=True, +        stream_id=None, +    ): +        if body_size_limit is not None: +            raise NotImplementedError() + +        self.perform_connection_preface() + +        timestamp_start = time.time() +        if hasattr(self.tcp_handler.rfile, "reset_timestamps"): +            self.tcp_handler.rfile.reset_timestamps() + +        stream_id, headers, body = self._receive_transmission( +            stream_id=stream_id, +            include_body=include_body, +        ) + +        if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): +            # more accurate timestamp_start +            timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + +        if include_body: +            timestamp_end = time.time() +        else: +            timestamp_end = None + +        response = Response( +            b"HTTP/2.0", +            int(headers.get(':status', 502)), +            b'', +            headers, +            body, +            timestamp_start=timestamp_start, +            timestamp_end=timestamp_end, +        ) +        response.stream_id = stream_id + +        return response + +    def assemble(self, message): +        if isinstance(message, Request): +            return self.assemble_request(message) +        elif isinstance(message, Response): +            return self.assemble_response(message) +        else: +            raise ValueError("HTTP message not supported.") + +    def assemble_request(self, request): +        assert isinstance(request, Request) + +        authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host +        if self.tcp_handler.address.port != 443: +            authority += ":%d" % self.tcp_handler.address.port + +        headers = request.headers.copy() + +        if ':authority' not in headers: +            headers.fields.insert(0, (b':authority', authority.encode('ascii'))) +        if ':scheme' not in headers: +            headers.fields.insert(0, (b':scheme', request.scheme.encode('ascii'))) +        if ':path' not in headers: +            headers.fields.insert(0, (b':path', request.path.encode('ascii'))) +        if ':method' not in headers: +            headers.fields.insert(0, (b':method', request.method.encode('ascii'))) + +        if hasattr(request, 'stream_id'): +            stream_id = request.stream_id +        else: +            stream_id = self._next_stream_id() + +        return list(itertools.chain( +            self._create_headers(headers, stream_id, end_stream=(request.body is None or len(request.body) == 0)), +            self._create_body(request.body, stream_id))) + +    def assemble_response(self, response): +        assert isinstance(response, Response) + +        headers = response.headers.copy() + +        if ':status' not in headers: +            headers.fields.insert(0, (b':status', str(response.status_code).encode('ascii'))) + +        if hasattr(response, 'stream_id'): +            stream_id = response.stream_id +   
     else: +            stream_id = self._next_stream_id() + +        return list(itertools.chain( +            self._create_headers(headers, stream_id, end_stream=(response.body is None or len(response.body) == 0)), +            self._create_body(response.body, stream_id), +        )) + +    def perform_connection_preface(self, force=False): +        if force or not self.connection_preface_performed: +            if self.is_server: +                self.perform_server_connection_preface(force) +            else: +                self.perform_client_connection_preface(force) + +    def perform_server_connection_preface(self, force=False): +        if force or not self.connection_preface_performed: +            self.connection_preface_performed = True + +            magic_length = len(self.CLIENT_CONNECTION_PREFACE) +            magic = self.tcp_handler.rfile.safe_read(magic_length) +            assert magic == self.CLIENT_CONNECTION_PREFACE + +            frm = frame.SettingsFrame(settings={ +                frame.SettingsFrame.ENABLE_PUSH: 0, +                frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1, +            }) +            self.send_frame(frm, hide=True) +            self._receive_settings(hide=True) + +    def perform_client_connection_preface(self, force=False): +        if force or not self.connection_preface_performed: +            self.connection_preface_performed = True + +            self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) + +            self.send_frame(frame.SettingsFrame(), hide=True) +            self._receive_settings(hide=True)  # server announces own settings +            self._receive_settings(hide=True)  # server acks my settings + +    def send_frame(self, frm, hide=False): +        raw_bytes = frm.serialize() +        self.tcp_handler.wfile.write(raw_bytes) +        self.tcp_handler.wfile.flush() +        if not hide and self.dump_frames:  # pragma no cover +            print(frm.human_readable(">>")) + +    def read_frame(self, hide=False): +        while True: +            frm = utils.http2_read_frame(self.tcp_handler.rfile) +            if not hide and self.dump_frames:  # pragma no cover +                print(frm.human_readable("<<")) + +            if isinstance(frm, frame.PingFrame): +                raw_bytes = frame.PingFrame(flags=['ACK'], payload=frm.payload).serialize() +                self.tcp_handler.wfile.write(raw_bytes) +                self.tcp_handler.wfile.flush() +                continue +            if isinstance(frm, frame.SettingsFrame) and 'ACK' not in frm.flags: +                self._apply_settings(frm.settings, hide) +            if isinstance(frm, frame.DataFrame) and frm.flow_controlled_length > 0: +                self._update_flow_control_window(frm.stream_id, frm.flow_controlled_length) +            return frm + +    def check_alpn(self): +        alp = self.tcp_handler.get_alpn_proto_negotiated() +        if alp != b'h2': +            raise NotImplementedError( +                "HTTP2Protocol can not handle unknown ALP: %s" % alp) +        return True + +    def _handle_unexpected_frame(self, frm): +        if isinstance(frm, frame.SettingsFrame): +            return +        if self.unhandled_frame_cb: +            self.unhandled_frame_cb(frm) + +    def _receive_settings(self, hide=False): +        while True: +            frm = self.read_frame(hide) +            if isinstance(frm, frame.SettingsFrame): +                break +            else: +                self._handle_unexpected_frame(frm) + +    
def _next_stream_id(self): +        if self.current_stream_id is None: +            if self.is_server: +                # servers must use even stream ids +                self.current_stream_id = 2 +            else: +                # clients must use odd stream ids +                self.current_stream_id = 1 +        else: +            self.current_stream_id += 2 +        return self.current_stream_id + +    def _apply_settings(self, settings, hide=False): +        for setting, value in settings.items(): +            old_value = self.http2_settings[setting] +            if not old_value: +                old_value = '-' +            self.http2_settings[setting] = value + +        frm = frame.SettingsFrame(flags=['ACK']) +        self.send_frame(frm, hide) + +    def _update_flow_control_window(self, stream_id, increment): +        frm = frame.WindowUpdateFrame(stream_id=0, window_increment=increment) +        self.send_frame(frm) +        frm = frame.WindowUpdateFrame(stream_id=stream_id, window_increment=increment) +        self.send_frame(frm) + +    def _create_headers(self, headers, stream_id, end_stream=True): +        def frame_cls(chunks): +            for i in chunks: +                if i == 0: +                    yield frame.HeadersFrame, i +                else: +                    yield frame.ContinuationFrame, i + +        header_block_fragment = self.encoder.encode(headers.fields) + +        chunk_size = self.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE] +        chunks = range(0, len(header_block_fragment), chunk_size) +        frms = [frm_cls( +            flags=[], +            stream_id=stream_id, +            data=header_block_fragment[i:i+chunk_size]) for frm_cls, i in frame_cls(chunks)] + +        frms[-1].flags.add('END_HEADERS') +        if end_stream: +            frms[0].flags.add('END_STREAM') + +        if self.dump_frames:  # pragma no cover +            for frm in frms: +                print(frm.human_readable(">>")) + +        return [frm.serialize() for frm in frms] + +    def _create_body(self, body, stream_id): +        if body is None or len(body) == 0: +            return b'' + +        chunk_size = self.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE] +        chunks = range(0, len(body), chunk_size) +        frms = [frame.DataFrame( +            flags=[], +            stream_id=stream_id, +            data=body[i:i+chunk_size]) for i in chunks] +        frms[-1].flags.add('END_STREAM') + +        if self.dump_frames:  # pragma no cover +            for frm in frms: +                print(frm.human_readable(">>")) + +        return [frm.serialize() for frm in frms] + +    def _receive_transmission(self, stream_id=None, include_body=True): +        if not include_body: +            raise NotImplementedError() + +        body_expected = True + +        header_blocks = b'' +        body = b'' + +        while True: +            frm = self.read_frame() +            if ( +                (isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame)) and +                (stream_id is None or frm.stream_id == stream_id) +            ): +                stream_id = frm.stream_id +                header_blocks += frm.data +                if 'END_STREAM' in frm.flags: +                    body_expected = False +                if 'END_HEADERS' in frm.flags: +                    break +            else: +                self._handle_unexpected_frame(frm) + +        while body_expected: +            frm = self.read_frame() +         
   if isinstance(frm, frame.DataFrame) and frm.stream_id == stream_id: +                body += frm.data +                if 'END_STREAM' in frm.flags: +                    break +            else: +                self._handle_unexpected_frame(frm) + +        headers = Headers( +            [[k.encode('ascii'), v.encode('ascii')] for k, v in self.decoder.decode(header_blocks)] +        ) + +        return stream_id, headers, body diff --git a/netlib/netlib/http/message.py b/netlib/netlib/http/message.py new file mode 100644 index 00000000..e3d8ce37 --- /dev/null +++ b/netlib/netlib/http/message.py @@ -0,0 +1,222 @@ +from __future__ import absolute_import, print_function, division + +import warnings + +import six + +from .headers import Headers +from .. import encoding, utils + +CONTENT_MISSING = 0 + +if six.PY2:  # pragma: nocover +    _native = lambda x: x +    _always_bytes = lambda x: x +else: +    # While the HTTP head _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. +    _native = lambda x: x.decode("utf-8", "surrogateescape") +    _always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape") + + +class MessageData(utils.Serializable): +    def __eq__(self, other): +        if isinstance(other, MessageData): +            return self.__dict__ == other.__dict__ +        return False + +    def __ne__(self, other): +        return not self.__eq__(other) + +    def set_state(self, state): +        for k, v in state.items(): +            if k == "headers": +                v = Headers.from_state(v) +            setattr(self, k, v) + +    def get_state(self): +        state = vars(self).copy() +        state["headers"] = state["headers"].get_state() +        return state + +    @classmethod +    def from_state(cls, state): +        state["headers"] = Headers.from_state(state["headers"]) +        return cls(**state) + + +class Message(utils.Serializable): +    def __init__(self, data): +        self.data = data + +    def __eq__(self, other): +        if isinstance(other, Message): +            return self.data == other.data +        return False + +    def __ne__(self, other): +        return not self.__eq__(other) + +    def get_state(self): +        return self.data.get_state() + +    def set_state(self, state): +        self.data.set_state(state) + +    @classmethod +    def from_state(cls, state): +        return cls(**state) + +    @property +    def headers(self): +        """ +        Message headers object + +        Returns: +            netlib.http.Headers +        """ +        return self.data.headers + +    @headers.setter +    def headers(self, h): +        self.data.headers = h + +    @property +    def content(self): +        """ +        The raw (encoded) HTTP message body + +        See also: :py:attr:`text` +        """ +        return self.data.content + +    @content.setter +    def content(self, content): +        self.data.content = content +        if isinstance(content, bytes): +            self.headers["content-length"] = str(len(content)) + +    @property +    def http_version(self): +        """ +        Version string, e.g. 
"HTTP/1.1" +        """ +        return _native(self.data.http_version) + +    @http_version.setter +    def http_version(self, http_version): +        self.data.http_version = _always_bytes(http_version) + +    @property +    def timestamp_start(self): +        """ +        First byte timestamp +        """ +        return self.data.timestamp_start + +    @timestamp_start.setter +    def timestamp_start(self, timestamp_start): +        self.data.timestamp_start = timestamp_start + +    @property +    def timestamp_end(self): +        """ +        Last byte timestamp +        """ +        return self.data.timestamp_end + +    @timestamp_end.setter +    def timestamp_end(self, timestamp_end): +        self.data.timestamp_end = timestamp_end + +    @property +    def text(self): +        """ +        The decoded HTTP message body. +        Decoded contents are not cached, so accessing this attribute repeatedly is relatively expensive. + +        .. note:: +            This is not implemented yet. + +        See also: :py:attr:`content`, :py:class:`decoded` +        """ +        # This attribute should be called text, because that's what requests does. +        raise NotImplementedError() + +    @text.setter +    def text(self, text): +        raise NotImplementedError() + +    def decode(self): +        """ +            Decodes body based on the current Content-Encoding header, then +            removes the header. If there is no Content-Encoding header, no +            action is taken. + +            Returns: +                True, if decoding succeeded. +                False, otherwise. +        """ +        ce = self.headers.get("content-encoding") +        data = encoding.decode(ce, self.content) +        if data is None: +            return False +        self.content = data +        self.headers.pop("content-encoding", None) +        return True + +    def encode(self, e): +        """ +            Encodes body with the encoding e, where e is "gzip", "deflate" or "identity". + +            Returns: +                True, if decoding succeeded. +                False, otherwise. +        """ +        data = encoding.encode(e, self.content) +        if data is None: +            return False +        self.content = data +        self.headers["content-encoding"] = e +        return True + +    # Legacy + +    @property +    def body(self):  # pragma: nocover +        warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) +        return self.content + +    @body.setter +    def body(self, body):  # pragma: nocover +        warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) +        self.content = body + + +class decoded(object): +    """ +    A context manager that decodes a request or response, and then +    re-encodes it with the same encoding after execution of the block. + +    Example: + +    .. 
code-block:: python + +        with decoded(request): +            request.content = request.content.replace("foo", "bar") +    """ + +    def __init__(self, message): +        self.message = message +        ce = message.headers.get("content-encoding") +        if ce in encoding.ENCODINGS: +            self.ce = ce +        else: +            self.ce = None + +    def __enter__(self): +        if self.ce: +            self.message.decode() + +    def __exit__(self, type, value, tb): +        if self.ce: +            self.message.encode(self.ce) diff --git a/netlib/netlib/http/request.py b/netlib/netlib/http/request.py new file mode 100644 index 00000000..0e0f88ce --- /dev/null +++ b/netlib/netlib/http/request.py @@ -0,0 +1,353 @@ +from __future__ import absolute_import, print_function, division + +import warnings + +import six +from six.moves import urllib + +from netlib import utils +from netlib.http import cookies +from netlib.odict import ODict +from .. import encoding +from .headers import Headers +from .message import Message, _native, _always_bytes, MessageData + + +class RequestData(MessageData): +    def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None, +                 timestamp_start=None, timestamp_end=None): +        if not isinstance(headers, Headers): +            headers = Headers(headers) + +        self.first_line_format = first_line_format +        self.method = method +        self.scheme = scheme +        self.host = host +        self.port = port +        self.path = path +        self.http_version = http_version +        self.headers = headers +        self.content = content +        self.timestamp_start = timestamp_start +        self.timestamp_end = timestamp_end + + +class Request(Message): +    """ +    An HTTP request. +    """ +    def __init__(self, *args, **kwargs): +        data = RequestData(*args, **kwargs) +        super(Request, self).__init__(data) + +    def __repr__(self): +        if self.host and self.port: +            hostport = "{}:{}".format(self.host, self.port) +        else: +            hostport = "" +        path = self.path or "" +        return "Request({} {}{})".format( +            self.method, hostport, path +        ) + +    @property +    def first_line_format(self): +        """ +        HTTP request form as defined in `RFC7230 <https://tools.ietf.org/html/rfc7230#section-5.3>`_. + +        origin-form and asterisk-form are subsumed as "relative". +        """ +        return self.data.first_line_format + +    @first_line_format.setter +    def first_line_format(self, first_line_format): +        self.data.first_line_format = first_line_format + +    @property +    def method(self): +        """ +        HTTP request method, e.g. "GET". +        """ +        return _native(self.data.method).upper() + +    @method.setter +    def method(self, method): +        self.data.method = _always_bytes(method) + +    @property +    def scheme(self): +        """ +        HTTP request scheme, which should be "http" or "https". +        """ +        return _native(self.data.scheme) + +    @scheme.setter +    def scheme(self, scheme): +        self.data.scheme = _always_bytes(scheme) + +    @property +    def host(self): +        """ +        Target host. This may be parsed from the raw request +        (e.g. from a ``GET http://example.com/ HTTP/1.1`` request line) +        or inferred from the proxy mode (e.g. an IP in transparent mode). 
+ +        Setting the host attribute also updates the host header, if present. +        """ + +        if six.PY2:  # pragma: nocover +            return self.data.host + +        if not self.data.host: +            return self.data.host +        try: +            return self.data.host.decode("idna") +        except UnicodeError: +            return self.data.host.decode("utf8", "surrogateescape") + +    @host.setter +    def host(self, host): +        if isinstance(host, six.text_type): +            try: +                # There's no non-strict mode for IDNA encoding. +                # We don't want this operation to fail though, so we try +                # utf8 as a last resort. +                host = host.encode("idna", "strict") +            except UnicodeError: +                host = host.encode("utf8", "surrogateescape") + +        self.data.host = host + +        # Update host header +        if "host" in self.headers: +            if host: +                self.headers["host"] = host +            else: +                self.headers.pop("host") + +    @property +    def port(self): +        """ +        Target port +        """ +        return self.data.port + +    @port.setter +    def port(self, port): +        self.data.port = port + +    @property +    def path(self): +        """ +        HTTP request path, e.g. "/index.html". +        Guaranteed to start with a slash. +        """ +        return _native(self.data.path) + +    @path.setter +    def path(self, path): +        self.data.path = _always_bytes(path) + +    @property +    def url(self): +        """ +        The URL string, constructed from the request's URL components +        """ +        return utils.unparse_url(self.scheme, self.host, self.port, self.path) + +    @url.setter +    def url(self, url): +        self.scheme, self.host, self.port, self.path = utils.parse_url(url) + +    @property +    def pretty_host(self): +        """ +        Similar to :py:attr:`host`, but using the Host headers as an additional preferred data source. +        This is useful in transparent mode where :py:attr:`host` is only an IP address, +        but may not reflect the actual destination as the Host header could be spoofed. +        """ +        return self.headers.get("host", self.host) + +    @property +    def pretty_url(self): +        """ +        Like :py:attr:`url`, but using :py:attr:`pretty_host` instead of :py:attr:`host`. +        """ +        if self.first_line_format == "authority": +            return "%s:%d" % (self.pretty_host, self.port) +        return utils.unparse_url(self.scheme, self.pretty_host, self.port, self.path) + +    @property +    def query(self): +        """ +        The request query string as an :py:class:`ODict` object. +        None, if there is no query. +        """ +        _, _, _, _, query, _ = urllib.parse.urlparse(self.url) +        if query: +            return ODict(utils.urldecode(query)) +        return None + +    @query.setter +    def query(self, odict): +        query = utils.urlencode(odict.lst) +        scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) +        self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]) + +    @property +    def cookies(self): +        """ +        The request cookies. +        An empty :py:class:`ODict` object if the cookie monster ate them all. 
+        """ +        ret = ODict() +        for i in self.headers.get_all("Cookie"): +            ret.extend(cookies.parse_cookie_header(i)) +        return ret + +    @cookies.setter +    def cookies(self, odict): +        self.headers["cookie"] = cookies.format_cookie_header(odict) + +    @property +    def path_components(self): +        """ +        The URL's path components as a list of strings. +        Components are unquoted. +        """ +        _, _, path, _, _, _ = urllib.parse.urlparse(self.url) +        return [urllib.parse.unquote(i) for i in path.split("/") if i] + +    @path_components.setter +    def path_components(self, components): +        components = map(lambda x: urllib.parse.quote(x, safe=""), components) +        path = "/" + "/".join(components) +        scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) +        self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]) + +    def anticache(self): +        """ +        Modifies this request to remove headers that might produce a cached +        response. That is, we remove ETags and If-Modified-Since headers. +        """ +        delheaders = [ +            "if-modified-since", +            "if-none-match", +        ] +        for i in delheaders: +            self.headers.pop(i, None) + +    def anticomp(self): +        """ +        Modifies this request to remove headers that will compress the +        resource's data. +        """ +        self.headers["accept-encoding"] = "identity" + +    def constrain_encoding(self): +        """ +        Limits the permissible Accept-Encoding values, based on what we can +        decode appropriately. +        """ +        accept_encoding = self.headers.get("accept-encoding") +        if accept_encoding: +            self.headers["accept-encoding"] = ( +                ', '.join( +                    e +                    for e in encoding.ENCODINGS +                    if e in accept_encoding +                ) +            ) + +    @property +    def urlencoded_form(self): +        """ +        The URL-encoded form data as an :py:class:`ODict` object. +        None if there is no data or the content-type indicates non-form data. +        """ +        is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() +        if self.content and is_valid_content_type: +            return ODict(utils.urldecode(self.content)) +        return None + +    @urlencoded_form.setter +    def urlencoded_form(self, odict): +        """ +        Sets the body to the URL-encoded form data, and adds the appropriate content-type header. +        This will overwrite the existing content if there is one. +        """ +        self.headers["content-type"] = "application/x-www-form-urlencoded" +        self.content = utils.urlencode(odict.lst) + +    @property +    def multipart_form(self): +        """ +        The multipart form data as an :py:class:`ODict` object. +        None if there is no data or the content-type indicates non-form data. 
+        """ +        is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() +        if self.content and is_valid_content_type: +            return ODict(utils.multipartdecode(self.headers,self.content)) +        return None + +    @multipart_form.setter +    def multipart_form(self, value): +        raise NotImplementedError() + +    # Legacy + +    def get_cookies(self):  # pragma: nocover +        warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning) +        return self.cookies + +    def set_cookies(self, odict):  # pragma: nocover +        warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning) +        self.cookies = odict + +    def get_query(self):  # pragma: nocover +        warnings.warn(".get_query is deprecated, use .query instead.", DeprecationWarning) +        return self.query or ODict([]) + +    def set_query(self, odict):  # pragma: nocover +        warnings.warn(".set_query is deprecated, use .query instead.", DeprecationWarning) +        self.query = odict + +    def get_path_components(self):  # pragma: nocover +        warnings.warn(".get_path_components is deprecated, use .path_components instead.", DeprecationWarning) +        return self.path_components + +    def set_path_components(self, lst):  # pragma: nocover +        warnings.warn(".set_path_components is deprecated, use .path_components instead.", DeprecationWarning) +        self.path_components = lst + +    def get_form_urlencoded(self):  # pragma: nocover +        warnings.warn(".get_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) +        return self.urlencoded_form or ODict([]) + +    def set_form_urlencoded(self, odict):  # pragma: nocover +        warnings.warn(".set_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) +        self.urlencoded_form = odict + +    def get_form_multipart(self):  # pragma: nocover +        warnings.warn(".get_form_multipart is deprecated, use .multipart_form instead.", DeprecationWarning) +        return self.multipart_form or ODict([]) + +    @property +    def form_in(self):  # pragma: nocover +        warnings.warn(".form_in is deprecated, use .first_line_format instead.", DeprecationWarning) +        return self.first_line_format + +    @form_in.setter +    def form_in(self, form_in):  # pragma: nocover +        warnings.warn(".form_in is deprecated, use .first_line_format instead.", DeprecationWarning) +        self.first_line_format = form_in + +    @property +    def form_out(self):  # pragma: nocover +        warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) +        return self.first_line_format + +    @form_out.setter +    def form_out(self, form_out):  # pragma: nocover +        warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) +        self.first_line_format = form_out
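# A minimal usage sketch for the Request helpers defined above, assuming a
# netlib.http.Request instance obtained elsewhere (for example, one parsed
# from a captured client connection) that carries an absolute URL, i.e.
# scheme, host and port are set. The query parameter name used below is
# purely illustrative.

from netlib.http import Request


def normalize_for_replay(request):
    # type: (Request) -> None
    """Example: scrub a captured request before replaying it."""
    # .query returns the query string as an ODict of key/value pairs,
    # or None when the URL has no query component.
    q = request.query
    if q is not None:
        # ODict.__delitem__ drops every pair matching the key; assigning the
        # ODict back re-encodes it into the URL via the .query setter.
        del q["utm_source"]
        request.query = q

    # Remove If-Modified-Since / If-None-Match so the origin returns a
    # full response rather than a 304.
    request.anticache()

    # Force "Accept-Encoding: identity" so the body arrives uncompressed.
    request.anticomp()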
\ No newline at end of file diff --git a/netlib/netlib/http/response.py b/netlib/netlib/http/response.py new file mode 100644 index 00000000..8f4d6215 --- /dev/null +++ b/netlib/netlib/http/response.py @@ -0,0 +1,116 @@ +from __future__ import absolute_import, print_function, division + +import warnings + +from . import cookies +from .headers import Headers +from .message import Message, _native, _always_bytes, MessageData +from .. import utils +from ..odict import ODict + + +class ResponseData(MessageData): +    def __init__(self, http_version, status_code, reason=None, headers=None, content=None, +                 timestamp_start=None, timestamp_end=None): +        if not isinstance(headers, Headers): +            headers = Headers(headers) + +        self.http_version = http_version +        self.status_code = status_code +        self.reason = reason +        self.headers = headers +        self.content = content +        self.timestamp_start = timestamp_start +        self.timestamp_end = timestamp_end + + +class Response(Message): +    """ +    An HTTP response. +    """ +    def __init__(self, *args, **kwargs): +        data = ResponseData(*args, **kwargs) +        super(Response, self).__init__(data) + +    def __repr__(self): +        if self.content: +            details = "{}, {}".format( +                self.headers.get("content-type", "unknown content type"), +                utils.pretty_size(len(self.content)) +            ) +        else: +            details = "no content" +        return "Response({status_code} {reason}, {details})".format( +            status_code=self.status_code, +            reason=self.reason, +            details=details +        ) + +    @property +    def status_code(self): +        """ +        HTTP Status Code, e.g. ``200``. +        """ +        return self.data.status_code + +    @status_code.setter +    def status_code(self, status_code): +        self.data.status_code = status_code + +    @property +    def reason(self): +        """ +        HTTP Reason Phrase, e.g. "Not Found". +        This is always :py:obj:`None` for HTTP2 requests, because HTTP2 responses do not contain a reason phrase. +        """ +        return _native(self.data.reason) + +    @reason.setter +    def reason(self, reason): +        self.data.reason = _always_bytes(reason) + +    @property +    def cookies(self): +        """ +        Get the contents of all Set-Cookie headers. + +        A possibly empty :py:class:`ODict`, where keys are cookie name strings, +        and values are [value, attr] lists. Value is a string, and attr is +        an ODictCaseless containing cookie attributes. Within attrs, unary +        attributes (e.g. HTTPOnly) are indicated by a Null value. 
+        """ +        ret = [] +        for header in self.headers.get_all("set-cookie"): +            v = cookies.parse_set_cookie_header(header) +            if v: +                name, value, attrs = v +                ret.append([name, [value, attrs]]) +        return ODict(ret) + +    @cookies.setter +    def cookies(self, odict): +        values = [] +        for i in odict.lst: +            header = cookies.format_set_cookie_header(i[0], i[1][0], i[1][1]) +            values.append(header) +        self.headers.set_all("set-cookie", values) + +    # Legacy + +    def get_cookies(self):  # pragma: nocover +        warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning) +        return self.cookies + +    def set_cookies(self, odict):  # pragma: nocover +        warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning) +        self.cookies = odict + +    @property +    def msg(self):  # pragma: nocover +        warnings.warn(".msg is deprecated, use .reason instead.", DeprecationWarning) +        return self.reason + +    @msg.setter +    def msg(self, reason):  # pragma: nocover +        warnings.warn(".msg is deprecated, use .reason instead.", DeprecationWarning) +        self.reason = reason diff --git a/netlib/netlib/http/status_codes.py b/netlib/netlib/http/status_codes.py new file mode 100644 index 00000000..8a4dc1f5 --- /dev/null +++ b/netlib/netlib/http/status_codes.py @@ -0,0 +1,106 @@ +from __future__ import absolute_import, print_function, division + +CONTINUE = 100 +SWITCHING = 101 +OK = 200 +CREATED = 201 +ACCEPTED = 202 +NON_AUTHORITATIVE_INFORMATION = 203 +NO_CONTENT = 204 +RESET_CONTENT = 205 +PARTIAL_CONTENT = 206 +MULTI_STATUS = 207 + +MULTIPLE_CHOICE = 300 +MOVED_PERMANENTLY = 301 +FOUND = 302 +SEE_OTHER = 303 +NOT_MODIFIED = 304 +USE_PROXY = 305 +TEMPORARY_REDIRECT = 307 + +BAD_REQUEST = 400 +UNAUTHORIZED = 401 +PAYMENT_REQUIRED = 402 +FORBIDDEN = 403 +NOT_FOUND = 404 +NOT_ALLOWED = 405 +NOT_ACCEPTABLE = 406 +PROXY_AUTH_REQUIRED = 407 +REQUEST_TIMEOUT = 408 +CONFLICT = 409 +GONE = 410 +LENGTH_REQUIRED = 411 +PRECONDITION_FAILED = 412 +REQUEST_ENTITY_TOO_LARGE = 413 +REQUEST_URI_TOO_LONG = 414 +UNSUPPORTED_MEDIA_TYPE = 415 +REQUESTED_RANGE_NOT_SATISFIABLE = 416 +EXPECTATION_FAILED = 417 +IM_A_TEAPOT = 418 + +INTERNAL_SERVER_ERROR = 500 +NOT_IMPLEMENTED = 501 +BAD_GATEWAY = 502 +SERVICE_UNAVAILABLE = 503 +GATEWAY_TIMEOUT = 504 +HTTP_VERSION_NOT_SUPPORTED = 505 +INSUFFICIENT_STORAGE_SPACE = 507 +NOT_EXTENDED = 510 + +RESPONSES = { +    # 100 +    CONTINUE: "Continue", +    SWITCHING: "Switching Protocols", + +    # 200 +    OK: "OK", +    CREATED: "Created", +    ACCEPTED: "Accepted", +    NON_AUTHORITATIVE_INFORMATION: "Non-Authoritative Information", +    NO_CONTENT: "No Content", +    RESET_CONTENT: "Reset Content.", +    PARTIAL_CONTENT: "Partial Content", +    MULTI_STATUS: "Multi-Status", + +    # 300 +    MULTIPLE_CHOICE: "Multiple Choices", +    MOVED_PERMANENTLY: "Moved Permanently", +    FOUND: "Found", +    SEE_OTHER: "See Other", +    NOT_MODIFIED: "Not Modified", +    USE_PROXY: "Use Proxy", +    # 306 not defined?? 
+    TEMPORARY_REDIRECT: "Temporary Redirect", + +    # 400 +    BAD_REQUEST: "Bad Request", +    UNAUTHORIZED: "Unauthorized", +    PAYMENT_REQUIRED: "Payment Required", +    FORBIDDEN: "Forbidden", +    NOT_FOUND: "Not Found", +    NOT_ALLOWED: "Method Not Allowed", +    NOT_ACCEPTABLE: "Not Acceptable", +    PROXY_AUTH_REQUIRED: "Proxy Authentication Required", +    REQUEST_TIMEOUT: "Request Time-out", +    CONFLICT: "Conflict", +    GONE: "Gone", +    LENGTH_REQUIRED: "Length Required", +    PRECONDITION_FAILED: "Precondition Failed", +    REQUEST_ENTITY_TOO_LARGE: "Request Entity Too Large", +    REQUEST_URI_TOO_LONG: "Request-URI Too Long", +    UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type", +    REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range not satisfiable", +    EXPECTATION_FAILED: "Expectation Failed", +    IM_A_TEAPOT: "I'm a teapot", + +    # 500 +    INTERNAL_SERVER_ERROR: "Internal Server Error", +    NOT_IMPLEMENTED: "Not Implemented", +    BAD_GATEWAY: "Bad Gateway", +    SERVICE_UNAVAILABLE: "Service Unavailable", +    GATEWAY_TIMEOUT: "Gateway Time-out", +    HTTP_VERSION_NOT_SUPPORTED: "HTTP Version not supported", +    INSUFFICIENT_STORAGE_SPACE: "Insufficient Storage Space", +    NOT_EXTENDED: "Not Extended" +} diff --git a/netlib/netlib/http/user_agents.py b/netlib/netlib/http/user_agents.py new file mode 100644 index 00000000..e8681908 --- /dev/null +++ b/netlib/netlib/http/user_agents.py @@ -0,0 +1,52 @@ +from __future__ import (absolute_import, print_function, division) + +""" +    A small collection of useful user-agent header strings. These should be +    kept reasonably current to reflect common usage. +""" + +# pylint: line-too-long + +# A collection of (name, shortcut, string) tuples. + +UASTRINGS = [ +    ("android", +     "a", +     "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"),  # noqa +    ("blackberry", +     "l", +     "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"),  # noqa +    ("bingbot", +     "b", +     "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"),  # noqa +    ("chrome", +     "c", +     "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"),  # noqa +    ("firefox", +     "f", +     "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"),  # noqa +    ("googlebot", +     "g", +     "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"),  # noqa +    ("ie9", +     "i", +     "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US)"),  # noqa +    ("ipad", +     "p", +     "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9B176 Safari/7534.48.3"),  # noqa +    ("iphone", +     "h", +     "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5"),  # noqa +    ("safari", +     "s", +     "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10"),  # noqa +] + + +def get_by_shortcut(s): +    """ +        Retrieve a user agent entry by shortcut. 
+    """ +    for i in UASTRINGS: +        if s == i[1]: +            return i diff --git a/netlib/netlib/odict.py b/netlib/netlib/odict.py new file mode 100644 index 00000000..1e6e381a --- /dev/null +++ b/netlib/netlib/odict.py @@ -0,0 +1,193 @@ +from __future__ import (absolute_import, print_function, division) +import re +import copy +import six + +from .utils import Serializable + + +def safe_subn(pattern, repl, target, *args, **kwargs): +    """ +        There are Unicode conversion problems with re.subn. We try to smooth +        that over by casting the pattern and replacement to strings. We really +        need a better solution that is aware of the actual content ecoding. +    """ +    return re.subn(str(pattern), str(repl), target, *args, **kwargs) + + +class ODict(Serializable): + +    """ +        A dictionary-like object for managing ordered (key, value) data. Think +        about it as a convenient interface to a list of (key, value) tuples. +    """ + +    def __init__(self, lst=None): +        self.lst = lst or [] + +    def _kconv(self, s): +        return s + +    def __eq__(self, other): +        return self.lst == other.lst + +    def __ne__(self, other): +        return not self.__eq__(other) + +    def __iter__(self): +        return self.lst.__iter__() + +    def __getitem__(self, k): +        """ +            Returns a list of values matching key. +        """ +        ret = [] +        k = self._kconv(k) +        for i in self.lst: +            if self._kconv(i[0]) == k: +                ret.append(i[1]) +        return ret + +    def keys(self): +        return list(set([self._kconv(i[0]) for i in self.lst])) + +    def _filter_lst(self, k, lst): +        k = self._kconv(k) +        new = [] +        for i in lst: +            if self._kconv(i[0]) != k: +                new.append(i) +        return new + +    def __len__(self): +        """ +            Total number of (key, value) pairs. +        """ +        return len(self.lst) + +    def __setitem__(self, k, valuelist): +        """ +            Sets the values for key k. If there are existing values for this +            key, they are cleared. +        """ +        if isinstance(valuelist, six.text_type) or isinstance(valuelist, six.binary_type): +            raise ValueError( +                "Expected list of values instead of string. " +                "Example: odict[b'Host'] = [b'www.example.com']" +            ) +        kc = self._kconv(k) +        new = [] +        for i in self.lst: +            if self._kconv(i[0]) == kc: +                if valuelist: +                    new.append([k, valuelist.pop(0)]) +            else: +                new.append(i) +        while valuelist: +            new.append([k, valuelist.pop(0)]) +        self.lst = new + +    def __delitem__(self, k): +        """ +            Delete all items matching k. 
+        """ +        self.lst = self._filter_lst(k, self.lst) + +    def __contains__(self, k): +        k = self._kconv(k) +        for i in self.lst: +            if self._kconv(i[0]) == k: +                return True +        return False + +    def add(self, key, value, prepend=False): +        if prepend: +            self.lst.insert(0, [key, value]) +        else: +            self.lst.append([key, value]) + +    def get(self, k, d=None): +        if k in self: +            return self[k] +        else: +            return d + +    def get_first(self, k, d=None): +        if k in self: +            return self[k][0] +        else: +            return d + +    def items(self): +        return self.lst[:] + +    def copy(self): +        """ +            Returns a copy of this object. +        """ +        lst = copy.deepcopy(self.lst) +        return self.__class__(lst) + +    def extend(self, other): +        """ +            Add the contents of other, preserving any duplicates. +        """ +        self.lst.extend(other.lst) + +    def __repr__(self): +        return repr(self.lst) + +    def in_any(self, key, value, caseless=False): +        """ +            Do any of the values matching key contain value? + +            If caseless is true, value comparison is case-insensitive. +        """ +        if caseless: +            value = value.lower() +        for i in self[key]: +            if caseless: +                i = i.lower() +            if value in i: +                return True +        return False + +    def replace(self, pattern, repl, *args, **kwargs): +        """ +            Replaces a regular expression pattern with repl in both keys and +            values. Encoded content will be decoded before replacement, and +            re-encoded afterwards. + +            Returns the number of replacements made. +        """ +        nlst, count = [], 0 +        for i in self.lst: +            k, c = safe_subn(pattern, repl, i[0], *args, **kwargs) +            count += c +            v, c = safe_subn(pattern, repl, i[1], *args, **kwargs) +            count += c +            nlst.append([k, v]) +        self.lst = nlst +        return count + +    # Implement the StateObject protocol from mitmproxy +    def get_state(self): +        return [tuple(i) for i in self.lst] + +    def set_state(self, state): +        self.lst = [list(i) for i in state] + +    @classmethod +    def from_state(cls, state): +        return cls([list(i) for i in state]) + + +class ODictCaseless(ODict): + +    """ +        A variant of ODict with "caseless" keys. This version _preserves_ key +        case, but does not consider case when setting or getting items. +    """ + +    def _kconv(self, s): +        return s.lower() diff --git a/netlib/netlib/socks.py b/netlib/netlib/socks.py new file mode 100644 index 00000000..51ad1c63 --- /dev/null +++ b/netlib/netlib/socks.py @@ -0,0 +1,176 @@ +from __future__ import (absolute_import, print_function, division) +import struct +import array +import ipaddress +from . 
import tcp, utils + + +class SocksError(Exception): +    def __init__(self, code, message): +        super(SocksError, self).__init__(message) +        self.code = code + + +VERSION = utils.BiDi( +    SOCKS4=0x04, +    SOCKS5=0x05 +) + +CMD = utils.BiDi( +    CONNECT=0x01, +    BIND=0x02, +    UDP_ASSOCIATE=0x03 +) + +ATYP = utils.BiDi( +    IPV4_ADDRESS=0x01, +    DOMAINNAME=0x03, +    IPV6_ADDRESS=0x04 +) + +REP = utils.BiDi( +    SUCCEEDED=0x00, +    GENERAL_SOCKS_SERVER_FAILURE=0x01, +    CONNECTION_NOT_ALLOWED_BY_RULESET=0x02, +    NETWORK_UNREACHABLE=0x03, +    HOST_UNREACHABLE=0x04, +    CONNECTION_REFUSED=0x05, +    TTL_EXPIRED=0x06, +    COMMAND_NOT_SUPPORTED=0x07, +    ADDRESS_TYPE_NOT_SUPPORTED=0x08, +) + +METHOD = utils.BiDi( +    NO_AUTHENTICATION_REQUIRED=0x00, +    GSSAPI=0x01, +    USERNAME_PASSWORD=0x02, +    NO_ACCEPTABLE_METHODS=0xFF +) + + +class ClientGreeting(object): +    __slots__ = ("ver", "methods") + +    def __init__(self, ver, methods): +        self.ver = ver +        self.methods = array.array("B") +        self.methods.extend(methods) + +    def assert_socks5(self): +        if self.ver != VERSION.SOCKS5: +            if self.ver == ord("G") and len(self.methods) == ord("E"): +                guess = "Probably not a SOCKS request but a regular HTTP request. " +            else: +                guess = "" + +            raise SocksError( +                REP.GENERAL_SOCKS_SERVER_FAILURE, +                guess + "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver +            ) + +    @classmethod +    def from_file(cls, f, fail_early=False): +        """ +        :param fail_early: If true, a SocksError will be raised if the first byte does not indicate socks5. +        """ +        ver, nmethods = struct.unpack("!BB", f.safe_read(2)) +        client_greeting = cls(ver, []) +        if fail_early: +            client_greeting.assert_socks5() +        client_greeting.methods.fromstring(f.safe_read(nmethods)) +        return client_greeting + +    def to_file(self, f): +        f.write(struct.pack("!BB", self.ver, len(self.methods))) +        f.write(self.methods.tostring()) + + +class ServerGreeting(object): +    __slots__ = ("ver", "method") + +    def __init__(self, ver, method): +        self.ver = ver +        self.method = method + +    def assert_socks5(self): +        if self.ver != VERSION.SOCKS5: +            if self.ver == ord("H") and self.method == ord("T"): +                guess = "Probably not a SOCKS request but a regular HTTP response. " +            else: +                guess = "" + +            raise SocksError( +                REP.GENERAL_SOCKS_SERVER_FAILURE, +                guess + "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver +            ) + +    @classmethod +    def from_file(cls, f): +        ver, method = struct.unpack("!BB", f.safe_read(2)) +        return cls(ver, method) + +    def to_file(self, f): +        f.write(struct.pack("!BB", self.ver, self.method)) + + +class Message(object): +    __slots__ = ("ver", "msg", "atyp", "addr") + +    def __init__(self, ver, msg, atyp, addr): +        self.ver = ver +        self.msg = msg +        self.atyp = atyp +        self.addr = tcp.Address.wrap(addr) + +    def assert_socks5(self): +        if self.ver != VERSION.SOCKS5: +            raise SocksError( +                REP.GENERAL_SOCKS_SERVER_FAILURE, +                "Invalid SOCKS version. 
Expected 0x05, got 0x%x" % self.ver +            ) + +    @classmethod +    def from_file(cls, f): +        ver, msg, rsv, atyp = struct.unpack("!BBBB", f.safe_read(4)) +        if rsv != 0x00: +            raise SocksError( +                REP.GENERAL_SOCKS_SERVER_FAILURE, +                "Socks Request: Invalid reserved byte: %s" % rsv +            ) +        if atyp == ATYP.IPV4_ADDRESS: +            # We use tnoa here as ntop is not commonly available on Windows. +            host = ipaddress.IPv4Address(f.safe_read(4)).compressed +            use_ipv6 = False +        elif atyp == ATYP.IPV6_ADDRESS: +            host = ipaddress.IPv6Address(f.safe_read(16)).compressed +            use_ipv6 = True +        elif atyp == ATYP.DOMAINNAME: +            length, = struct.unpack("!B", f.safe_read(1)) +            host = f.safe_read(length) +            if not utils.is_valid_host(host): +                raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Invalid hostname: %s" % host) +            host = host.decode("idna") +            use_ipv6 = False +        else: +            raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, +                             "Socks Request: Unknown ATYP: %s" % atyp) + +        port, = struct.unpack("!H", f.safe_read(2)) +        addr = tcp.Address((host, port), use_ipv6=use_ipv6) +        return cls(ver, msg, atyp, addr) + +    def to_file(self, f): +        f.write(struct.pack("!BBBB", self.ver, self.msg, 0x00, self.atyp)) +        if self.atyp == ATYP.IPV4_ADDRESS: +            f.write(ipaddress.IPv4Address(self.addr.host).packed) +        elif self.atyp == ATYP.IPV6_ADDRESS: +            f.write(ipaddress.IPv6Address(self.addr.host).packed) +        elif self.atyp == ATYP.DOMAINNAME: +            f.write(struct.pack("!B", len(self.addr.host))) +            f.write(self.addr.host.encode("idna")) +        else: +            raise SocksError( +                REP.ADDRESS_TYPE_NOT_SUPPORTED, +                "Unknown ATYP: %s" % self.atyp +            ) +        f.write(struct.pack("!H", self.addr.port)) diff --git a/netlib/netlib/tcp.py b/netlib/netlib/tcp.py new file mode 100644 index 00000000..c8548aea --- /dev/null +++ b/netlib/netlib/tcp.py @@ -0,0 +1,908 @@ +from __future__ import (absolute_import, print_function, division) +import os +import select +import socket +import sys +import threading +import time +import traceback + +import binascii +from six.moves import range + +import certifi +from backports import ssl_match_hostname +import six +import OpenSSL +from OpenSSL import SSL + +from . import certutils, version_check, utils + +# This is a rather hackish way to make sure that +# the latest version of pyOpenSSL is actually installed. 
+from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \ +    TcpTimeout, TcpDisconnect, TcpException + +version_check.check_pyopenssl_version() + +if six.PY2: +    socket_fileobject = socket._fileobject +else: +    socket_fileobject = socket.SocketIO + +EINTR = 4 +HAS_ALPN = OpenSSL._util.lib.Cryptography_HAS_ALPN + +# To enable all SSL methods use: SSLv23 +# then add options to disable certain methods +# https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 +SSL_BASIC_OPTIONS = ( +    SSL.OP_CIPHER_SERVER_PREFERENCE +) +if hasattr(SSL, "OP_NO_COMPRESSION"): +    SSL_BASIC_OPTIONS |= SSL.OP_NO_COMPRESSION + +SSL_DEFAULT_METHOD = SSL.SSLv23_METHOD +SSL_DEFAULT_OPTIONS = ( +    SSL.OP_NO_SSLv2 | +    SSL.OP_NO_SSLv3 | +    SSL_BASIC_OPTIONS +) +if hasattr(SSL, "OP_NO_COMPRESSION"): +    SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION + +""" +Map a reasonable SSL version specification into the format OpenSSL expects. +Don't ask... +https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 +""" +sslversion_choices = { +    "all": (SSL.SSLv23_METHOD, SSL_BASIC_OPTIONS), +    # SSLv23_METHOD + NO_SSLv2 + NO_SSLv3 == TLS 1.0+ +    # TLSv1_METHOD would be TLS 1.0 only +    "secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL_BASIC_OPTIONS)), +    "SSLv2": (SSL.SSLv2_METHOD, SSL_BASIC_OPTIONS), +    "SSLv3": (SSL.SSLv3_METHOD, SSL_BASIC_OPTIONS), +    "TLSv1": (SSL.TLSv1_METHOD, SSL_BASIC_OPTIONS), +    "TLSv1_1": (SSL.TLSv1_1_METHOD, SSL_BASIC_OPTIONS), +    "TLSv1_2": (SSL.TLSv1_2_METHOD, SSL_BASIC_OPTIONS), +} + +class SSLKeyLogger(object): + +    def __init__(self, filename): +        self.filename = filename +        self.f = None +        self.lock = threading.Lock() + +    # required for functools.wraps, which pyOpenSSL uses. +    __name__ = "SSLKeyLogger" + +    def __call__(self, connection, where, ret): +        if where == SSL.SSL_CB_HANDSHAKE_DONE and ret == 1: +            with self.lock: +                if not self.f: +                    d = os.path.dirname(self.filename) +                    if not os.path.isdir(d): +                        os.makedirs(d) +                    self.f = open(self.filename, "ab") +                    self.f.write(b"\r\n") +                client_random = binascii.hexlify(connection.client_random()) +                masterkey = binascii.hexlify(connection.master_key()) +                self.f.write(b"CLIENT_RANDOM %s %s\r\n" % (client_random, masterkey)) +                self.f.flush() + +    def close(self): +        with self.lock: +            if self.f: +                self.f.close() + +    @staticmethod +    def create_logfun(filename): +        if filename: +            return SSLKeyLogger(filename) +        return False + +log_ssl_key = SSLKeyLogger.create_logfun( +    os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE")) + + +class _FileLike(object): +    BLOCKSIZE = 1024 * 32 + +    def __init__(self, o): +        self.o = o +        self._log = None +        self.first_byte_timestamp = None + +    def set_descriptor(self, o): +        self.o = o + +    def __getattr__(self, attr): +        return getattr(self.o, attr) + +    def start_log(self): +        """ +            Starts or resets the log. + +            This will store all bytes read or written. +        """ +        self._log = [] + +    def stop_log(self): +        """ +            Stops the log. 
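For orientation, a minimal sketch of how this logging interface is driven (the byte string is made up; Reader is defined a little further down in this file):

    from io import BytesIO
    from netlib.tcp import Reader

    rfile = Reader(BytesIO(b"abcdef"))
    rfile.start_log()                   # record everything that gets read
    assert rfile.read(3) == b"abc"
    assert rfile.get_log() == b"abc"
    rfile.stop_log()                    # discard the log again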
+        """ +        self._log = None + +    def is_logging(self): +        return self._log is not None + +    def get_log(self): +        """ +            Returns the log as a string. +        """ +        if not self.is_logging(): +            raise ValueError("Not logging!") +        return b"".join(self._log) + +    def add_log(self, v): +        if self.is_logging(): +            self._log.append(v) + +    def reset_timestamps(self): +        self.first_byte_timestamp = None + + +class Writer(_FileLike): + +    def flush(self): +        """ +            May raise TcpDisconnect +        """ +        if hasattr(self.o, "flush"): +            try: +                self.o.flush() +            except (socket.error, IOError) as v: +                raise TcpDisconnect(str(v)) + +    def write(self, v): +        """ +            May raise TcpDisconnect +        """ +        if v: +            self.first_byte_timestamp = self.first_byte_timestamp or time.time() +            try: +                if hasattr(self.o, "sendall"): +                    self.add_log(v) +                    return self.o.sendall(v) +                else: +                    r = self.o.write(v) +                    self.add_log(v[:r]) +                    return r +            except (SSL.Error, socket.error) as e: +                raise TcpDisconnect(str(e)) + + +class Reader(_FileLike): + +    def read(self, length): +        """ +            If length is -1, we read until connection closes. +        """ +        result = b'' +        start = time.time() +        while length == -1 or length > 0: +            if length == -1 or length > self.BLOCKSIZE: +                rlen = self.BLOCKSIZE +            else: +                rlen = length +            try: +                data = self.o.read(rlen) +            except SSL.ZeroReturnError: +                # TLS connection was shut down cleanly +                break +            except (SSL.WantWriteError, SSL.WantReadError): +                # From the OpenSSL docs: +                # If the underlying BIO is non-blocking, SSL_read() will also return when the +                # underlying BIO could not satisfy the needs of SSL_read() to continue the +                # operation. In this case a call to SSL_get_error with the return value of +                # SSL_read() will yield SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE. 
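                # In that case the code below retries in 0.1 s steps until the
                # file object's own timeout has elapsed, and only then raises
                # TcpTimeout.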
+                if (time.time() - start) < self.o.gettimeout(): +                    time.sleep(0.1) +                    continue +                else: +                    raise TcpTimeout() +            except socket.timeout: +                raise TcpTimeout() +            except socket.error as e: +                raise TcpDisconnect(str(e)) +            except SSL.SysCallError as e: +                if e.args == (-1, 'Unexpected EOF'): +                    break +                raise TlsException(str(e)) +            except SSL.Error as e: +                raise TlsException(str(e)) +            self.first_byte_timestamp = self.first_byte_timestamp or time.time() +            if not data: +                break +            result += data +            if length != -1: +                length -= len(data) +        self.add_log(result) +        return result + +    def readline(self, size=None): +        result = b'' +        bytes_read = 0 +        while True: +            if size is not None and bytes_read >= size: +                break +            ch = self.read(1) +            bytes_read += 1 +            if not ch: +                break +            else: +                result += ch +                if ch == b'\n': +                    break +        return result + +    def safe_read(self, length): +        """ +            Like .read, but is guaranteed to either return length bytes, or +            raise an exception. +        """ +        result = self.read(length) +        if length != -1 and len(result) != length: +            if not result: +                raise TcpDisconnect() +            else: +                raise TcpReadIncomplete( +                    "Expected %s bytes, got %s" % (length, len(result)) +                ) +        return result + +    def peek(self, length): +        """ +        Tries to peek into the underlying file object. + +        Returns: +            Up to the next N bytes if peeking is successful. + +        Raises: +            TcpException if there was an error with the socket +            TlsException if there was an error with pyOpenSSL. 
+            NotImplementedError if the underlying file object is not a [pyOpenSSL] socket +        """ +        if isinstance(self.o, socket_fileobject): +            try: +                return self.o._sock.recv(length, socket.MSG_PEEK) +            except socket.error as e: +                raise TcpException(repr(e)) +        elif isinstance(self.o, SSL.Connection): +            try: +                if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15): +                    return self.o.recv(length, socket.MSG_PEEK) +                else: +                    # TODO: remove once a new version is released +                    # Polyfill for pyOpenSSL <= 0.15.1 +                    # Taken from https://github.com/pyca/pyopenssl/commit/1d95dea7fea03c7c0df345a5ea30c12d8a0378d2 +                    buf = SSL._ffi.new("char[]", length) +                    result = SSL._lib.SSL_peek(self.o._ssl, buf, length) +                    self.o._raise_ssl_error(self.o._ssl, result) +                    return SSL._ffi.buffer(buf, result)[:] +            except SSL.Error as e: +                six.reraise(TlsException, TlsException(str(e)), sys.exc_info()[2]) +        else: +            raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") + + +class Address(utils.Serializable): + +    """ +        This class wraps an IPv4/IPv6 tuple to provide named attributes and +        ipv6 information. +    """ + +    def __init__(self, address, use_ipv6=False): +        self.address = tuple(address) +        self.use_ipv6 = use_ipv6 + +    def get_state(self): +        return { +            "address": self.address, +            "use_ipv6": self.use_ipv6 +        } + +    def set_state(self, state): +        self.address = state["address"] +        self.use_ipv6 = state["use_ipv6"] + +    @classmethod +    def from_state(cls, state): +        return Address(**state) + +    @classmethod +    def wrap(cls, t): +        if isinstance(t, cls): +            return t +        else: +            return cls(t) + +    def __call__(self): +        return self.address + +    @property +    def host(self): +        return self.address[0] + +    @property +    def port(self): +        return self.address[1] + +    @property +    def use_ipv6(self): +        return self.family == socket.AF_INET6 + +    @use_ipv6.setter +    def use_ipv6(self, b): +        self.family = socket.AF_INET6 if b else socket.AF_INET + +    def __repr__(self): +        return "{}:{}".format(self.host, self.port) + +    def __str__(self): +        return str(self.address) + +    def __eq__(self, other): +        if not other: +            return False +        other = Address.wrap(other) +        return (self.address, self.family) == (other.address, other.family) + +    def __ne__(self, other): +        return not self.__eq__(other) + +    def __hash__(self): +        return hash(self.address) ^ 42  # different hash than the tuple alone. + + +def ssl_read_select(rlist, timeout): +    """ +    This is a wrapper around select.select() which also works for SSL.Connections +    by taking ssl_connection.pending() into account. + +    Caveats: +        If .pending() > 0 for any of the connections in rlist, we avoid the select syscall +        and **will not include any other connections which may or may not be ready**. + +    Args: +        rlist: wait until ready for reading + +    Returns: +        subset of rlist which is ready for reading. 
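A rough usage sketch (the connection names are made up): the wrapper is called like select.select's read list, but SSL connections with buffered records short-circuit the syscall:

    from netlib.tcp import ssl_read_select

    ready = ssl_read_select([client_conn, server_conn], timeout=10)
    for conn in ready:
        data = conn.recv(4096)   # readable, or has pending TLS data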
+    """ +    return [ +        conn for conn in rlist +        if isinstance(conn, SSL.Connection) and conn.pending() > 0 +    ] or select.select(rlist, (), (), timeout)[0] + + +def close_socket(sock): +    """ +    Does a hard close of a socket, without emitting a RST. +    """ +    try: +        # We already indicate that we close our end. +        # may raise "Transport endpoint is not connected" on Linux +        sock.shutdown(socket.SHUT_WR) + +        # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending +        # readable data could lead to an immediate RST being sent (which is the +        # case on Windows). +        # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html +        # +        # This in turn results in the following issue: If we send an error page +        # to the client and then close the socket, the RST may be received by +        # the client before the error page and the users sees a connection +        # error rather than the error page. Thus, we try to empty the read +        # buffer on Windows first. (see +        # https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988) +        # + +        if os.name == "nt":  # pragma: no cover +            # We cannot rely on the shutdown()-followed-by-read()-eof technique +            # proposed by the page above: Some remote machines just don't send +            # a TCP FIN, which would leave us in the unfortunate situation that +            # recv() would block infinitely. As a workaround, we set a timeout +            # here even if we are in blocking mode. +            sock.settimeout(sock.gettimeout() or 20) + +            # limit at a megabyte so that we don't read infinitely +            for _ in range(1024 ** 3 // 4096): +                # may raise a timeout/disconnect exception. +                if not sock.recv(4096): +                    break + +        # Now we can close the other half as well. +        sock.shutdown(socket.SHUT_RD) + +    except socket.error: +        pass + +    sock.close() + + +class _Connection(object): + +    rbufsize = -1 +    wbufsize = -1 + +    def _makefile(self): +        """ +        Set up .rfile and .wfile attributes from .connection +        """ +        # Ideally, we would use the Buffered IO in Python 3 by default. +        # Unfortunately, the implementation of .peek() is broken for n>1 bytes, +        # as it may just return what's left in the buffer and not all the bytes we want. +        # As a workaround, we just use unbuffered sockets directly. 
+        # https://mail.python.org/pipermail/python-dev/2009-June/089986.html +        if six.PY2: +            self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) +            self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) +        else: +            self.rfile = Reader(socket.SocketIO(self.connection, "rb")) +            self.wfile = Writer(socket.SocketIO(self.connection, "wb")) + +    def __init__(self, connection): +        if connection: +            self.connection = connection +            self._makefile() +        else: +            self.connection = None +            self.rfile = None +            self.wfile = None + +        self.ssl_established = False +        self.finished = False + +    def get_current_cipher(self): +        if not self.ssl_established: +            return None + +        name = self.connection.get_cipher_name() +        bits = self.connection.get_cipher_bits() +        version = self.connection.get_cipher_version() +        return name, bits, version + +    def finish(self): +        self.finished = True +        # If we have an SSL connection, wfile.close == connection.close +        # (We call _FileLike.set_descriptor(conn)) +        # Closing the socket is not our task, therefore we don't call close +        # then. +        if not isinstance(self.connection, SSL.Connection): +            if not getattr(self.wfile, "closed", False): +                try: +                    self.wfile.flush() +                    self.wfile.close() +                except TcpDisconnect: +                    pass + +            self.rfile.close() +        else: +            try: +                self.connection.shutdown() +            except SSL.Error: +                pass + +    def _create_ssl_context(self, +                            method=SSL_DEFAULT_METHOD, +                            options=SSL_DEFAULT_OPTIONS, +                            verify_options=SSL.VERIFY_NONE, +                            ca_path=None, +                            ca_pemfile=None, +                            cipher_list=None, +                            alpn_protos=None, +                            alpn_select=None, +                            alpn_select_callback=None, +                            ): +        """ +        Creates an SSL Context. 
+ +        :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, TLSv1_1_METHOD, or TLSv1_2_METHOD +        :param options: A bit field consisting of OpenSSL.SSL.OP_* values +        :param verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values +        :param ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool +        :param ca_pemfile: Path to a PEM formatted trusted CA certificate +        :param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html +        :rtype : SSL.Context +        """ +        context = SSL.Context(method) +        # Options (NO_SSLv2/3) +        if options is not None: +            context.set_options(options) + +        # Verify Options (NONE/PEER and trusted CAs) +        if verify_options is not None: +            def verify_cert(conn, x509, errno, err_depth, is_cert_verified): +                if not is_cert_verified: +                    self.ssl_verification_error = dict(errno=errno, +                                                       depth=err_depth) +                return is_cert_verified + +            context.set_verify(verify_options, verify_cert) +            if ca_path is None and ca_pemfile is None: +                ca_pemfile = certifi.where() +            context.load_verify_locations(ca_pemfile, ca_path) + +        # Workaround for +        # https://github.com/pyca/pyopenssl/issues/190 +        # https://github.com/mitmproxy/mitmproxy/issues/472 +        # Options already set before are not cleared. +        context.set_mode(SSL._lib.SSL_MODE_AUTO_RETRY) + +        # Cipher List +        if cipher_list: +            try: +                context.set_cipher_list(cipher_list) + +                # TODO: maybe change this to with newer pyOpenSSL APIs +                context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1')) +            except SSL.Error as v: +                raise TlsException("SSL cipher specification error: %s" % str(v)) + +        # SSLKEYLOGFILE +        if log_ssl_key: +            context.set_info_callback(log_ssl_key) + +        if HAS_ALPN: +            if alpn_protos is not None: +                # advertise application layer protocols +                context.set_alpn_protos(alpn_protos) +            elif alpn_select is not None and alpn_select_callback is None: +                # select application layer protocol +                def alpn_select_callback(conn_, options): +                    if alpn_select in options: +                        return bytes(alpn_select) +                    else:  # pragma no cover +                        return options[0] +                context.set_alpn_select_callback(alpn_select_callback) +            elif alpn_select_callback is not None and alpn_select is None: +                context.set_alpn_select_callback(alpn_select_callback) +            elif alpn_select_callback is not None and alpn_select is not None: +                raise TlsException("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).") + +        return context + + +class TCPClient(_Connection): + +    def __init__(self, address, source_address=None): +        super(TCPClient, self).__init__(None) +        self.address = address +        self.source_address = source_address +        self.cert = None +        self.ssl_verification_error = None +        self.sni = None + +    @property +    def address(self): +        return self.__address + +    
@address.setter +    def address(self, address): +        if address: +            self.__address = Address.wrap(address) +        else: +            self.__address = None + +    @property +    def source_address(self): +        return self.__source_address + +    @source_address.setter +    def source_address(self, source_address): +        if source_address: +            self.__source_address = Address.wrap(source_address) +        else: +            self.__source_address = None + +    def close(self): +        # Make sure to close the real socket, not the SSL proxy. +        # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, +        # it tries to renegotiate... +        if isinstance(self.connection, SSL.Connection): +            close_socket(self.connection._socket) +        else: +            close_socket(self.connection) + +    def create_ssl_context(self, cert=None, alpn_protos=None, **sslctx_kwargs): +        context = self._create_ssl_context( +            alpn_protos=alpn_protos, +            **sslctx_kwargs) +        # Client Certs +        if cert: +            try: +                context.use_privatekey_file(cert) +                context.use_certificate_file(cert) +            except SSL.Error as v: +                raise TlsException("SSL client certificate error: %s" % str(v)) +        return context + +    def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs): +        """ +            cert: Path to a file containing both client cert and private key. + +            options: A bit field consisting of OpenSSL.SSL.OP_* values +            verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values +            ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool +            ca_pemfile: Path to a PEM formatted trusted CA certificate +        """ +        verification_mode = sslctx_kwargs.get('verify_options', None) +        if verification_mode == SSL.VERIFY_PEER and not sni: +            raise TlsException("Cannot validate certificate hostname without SNI") + +        context = self.create_ssl_context( +            alpn_protos=alpn_protos, +            **sslctx_kwargs +        ) +        self.connection = SSL.Connection(context, self.connection) +        if sni: +            self.sni = sni +            self.connection.set_tlsext_host_name(sni) +        self.connection.set_connect_state() +        try: +            self.connection.do_handshake() +        except SSL.Error as v: +            if self.ssl_verification_error: +                raise InvalidCertificateException("SSL handshake error: %s" % repr(v)) +            else: +                raise TlsException("SSL handshake error: %s" % repr(v)) +        else: +            # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on +            # certificate validation failure +            if verification_mode == SSL.VERIFY_PEER and self.ssl_verification_error is not None: +                raise InvalidCertificateException("SSL handshake error: certificate verify failed") + +        self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) + +        # Validate TLS Hostname +        try: +            crt = dict( +                subjectAltName=[("DNS", x.decode("ascii", "strict")) for x in self.cert.altnames] +            ) +            if self.cert.cn: +                crt["subject"] = [[["commonName", self.cert.cn.decode("ascii", "strict")]]] +            if sni: +                hostname = 
sni.decode("ascii", "strict") +            else: +                hostname = "no-hostname" +            ssl_match_hostname.match_hostname(crt, hostname) +        except (ValueError, ssl_match_hostname.CertificateError) as e: +            self.ssl_verification_error = dict(depth=0, errno="Invalid Hostname") +            if verification_mode == SSL.VERIFY_PEER: +                raise InvalidCertificateException("Presented certificate for {} is not valid: {}".format(sni, str(e))) + +        self.ssl_established = True +        self.rfile.set_descriptor(self.connection) +        self.wfile.set_descriptor(self.connection) + +    def connect(self): +        try: +            connection = socket.socket(self.address.family, socket.SOCK_STREAM) +            if self.source_address: +                connection.bind(self.source_address()) +            connection.connect(self.address()) +            if not self.source_address: +                self.source_address = Address(connection.getsockname()) +        except (socket.error, IOError) as err: +            raise TcpException( +                'Error connecting to "%s": %s' % +                (self.address.host, err)) +        self.connection = connection +        self._makefile() + +    def settimeout(self, n): +        self.connection.settimeout(n) + +    def gettimeout(self): +        return self.connection.gettimeout() + +    def get_alpn_proto_negotiated(self): +        if HAS_ALPN and self.ssl_established: +            return self.connection.get_alpn_proto_negotiated() +        else: +            return b"" + + +class BaseHandler(_Connection): + +    """ +        The instantiator is expected to call the handle() and finish() methods. +    """ + +    def __init__(self, connection, address, server): +        super(BaseHandler, self).__init__(connection) +        self.address = Address.wrap(address) +        self.server = server +        self.clientcert = None + +    def create_ssl_context(self, +                           cert, key, +                           handle_sni=None, +                           request_client_cert=None, +                           chain_file=None, +                           dhparams=None, +                           **sslctx_kwargs): +        """ +            cert: A certutils.SSLCert object or the path to a certificate +            chain file. + +            handle_sni: SNI handler, should take a connection object. Server +            name can be retrieved like this: + +                    connection.get_servername() + +            And you can specify the connection keys as follows: + +                    new_context = Context(TLSv1_METHOD) +                    new_context.use_privatekey(key) +                    new_context.use_certificate(cert) +                    connection.set_context(new_context) + +            The request_client_cert argument requires some explanation. We're +            supposed to be able to do this with no negative effects - if the +            client has no cert to present, we're notified and proceed as usual. +            Unfortunately, Android seems to have a bug (tested on 4.2.2) - when +            an Android client is asked to present a certificate it does not +            have, it hangs up, which is frankly bogus. Some time down the track +            we may be able to make the proper behaviour the default again, but +            until then we're conservative. 
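Pulling that together, a minimal SNI callback might look like the sketch below; the per-host key/cert lookup is hypothetical, only get_servername() and set_context() are taken from the description above:

    def handle_sni(connection):
        sni = connection.get_servername()
        if sni:
            new_context = SSL.Context(SSL.TLSv1_METHOD)
            new_context.use_privatekey(key_for_host(sni))    # hypothetical lookup
            new_context.use_certificate(cert_for_host(sni))  # hypothetical lookup
            connection.set_context(new_context)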
+        """ + +        context = self._create_ssl_context(**sslctx_kwargs) + +        context.use_privatekey(key) +        if isinstance(cert, certutils.SSLCert): +            context.use_certificate(cert.x509) +        else: +            context.use_certificate_chain_file(cert) + +        if handle_sni: +            # SNI callback happens during do_handshake() +            context.set_tlsext_servername_callback(handle_sni) + +        if request_client_cert: +            def save_cert(conn_, cert, errno_, depth_, preverify_ok_): +                self.clientcert = certutils.SSLCert(cert) +                # Return true to prevent cert verification error +                return True +            context.set_verify(SSL.VERIFY_PEER, save_cert) + +        # Cert Verify +        if chain_file: +            context.load_verify_locations(chain_file) + +        if dhparams: +            SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams) + +        return context + +    def convert_to_ssl(self, cert, key, **sslctx_kwargs): +        """ +        Convert connection to SSL. +        For a list of parameters, see BaseHandler._create_ssl_context(...) +        """ + +        context = self.create_ssl_context( +            cert, +            key, +            **sslctx_kwargs) +        self.connection = SSL.Connection(context, self.connection) +        self.connection.set_accept_state() +        try: +            self.connection.do_handshake() +        except SSL.Error as v: +            raise TlsException("SSL handshake error: %s" % repr(v)) +        self.ssl_established = True +        self.rfile.set_descriptor(self.connection) +        self.wfile.set_descriptor(self.connection) + +    def handle(self):  # pragma: no cover +        raise NotImplementedError + +    def settimeout(self, n): +        self.connection.settimeout(n) + +    def get_alpn_proto_negotiated(self): +        if HAS_ALPN and self.ssl_established: +            return self.connection.get_alpn_proto_negotiated() +        else: +            return b"" + + +class TCPServer(object): +    request_queue_size = 20 + +    def __init__(self, address): +        self.address = Address.wrap(address) +        self.__is_shut_down = threading.Event() +        self.__shutdown_request = False +        self.socket = socket.socket(self.address.family, socket.SOCK_STREAM) +        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) +        self.socket.bind(self.address()) +        self.address = Address.wrap(self.socket.getsockname()) +        self.socket.listen(self.request_queue_size) + +    def connection_thread(self, connection, client_address): +        client_address = Address(client_address) +        try: +            self.handle_client_connection(connection, client_address) +        except: +            self.handle_error(connection, client_address) +        finally: +            close_socket(connection) + +    def serve_forever(self, poll_interval=0.1): +        self.__is_shut_down.clear() +        try: +            while not self.__shutdown_request: +                try: +                    r, w_, e_ = select.select( +                        [self.socket], [], [], poll_interval) +                except select.error as ex:  # pragma: no cover +                    if ex[0] == EINTR: +                        continue +                    else: +                        raise +                if self.socket in r: +                    connection, client_address = self.socket.accept() +                    t = threading.Thread( +       
                 target=self.connection_thread, +                        args=(connection, client_address), +                        name="ConnectionThread (%s:%s -> %s:%s)" % +                             (client_address[0], client_address[1], +                              self.address.host, self.address.port) +                    ) +                    t.setDaemon(1) +                    try: +                        t.start() +                    except threading.ThreadError: +                        self.handle_error(connection, Address(client_address)) +                        connection.close() +        finally: +            self.__shutdown_request = False +            self.__is_shut_down.set() + +    def shutdown(self): +        self.__shutdown_request = True +        self.__is_shut_down.wait() +        self.socket.close() +        self.handle_shutdown() + +    def handle_error(self, connection_, client_address, fp=sys.stderr): +        """ +            Called when handle_client_connection raises an exception. +        """ +        # If a thread has persisted after interpreter exit, the module might be +        # none. +        if traceback: +            exc = six.text_type(traceback.format_exc()) +            print(u'-' * 40, file=fp) +            print( +                u"Error in processing of request from %s" % repr(client_address), file=fp) +            print(exc, file=fp) +            print(u'-' * 40, file=fp) + +    def handle_client_connection(self, conn, client_address):  # pragma: no cover +        """ +            Called after client connection. +        """ +        raise NotImplementedError + +    def handle_shutdown(self): +        """ +            Called after server shutdown. +        """ diff --git a/netlib/netlib/tservers.py b/netlib/netlib/tservers.py new file mode 100644 index 00000000..44ef8063 --- /dev/null +++ b/netlib/netlib/tservers.py @@ -0,0 +1,109 @@ +from __future__ import (absolute_import, print_function, division) + +import threading +from six.moves import queue +from io import StringIO +import OpenSSL + +from netlib import tcp +from netlib import tutils + + +class ServerThread(threading.Thread): + +    def __init__(self, server): +        self.server = server +        threading.Thread.__init__(self) + +    def run(self): +        self.server.serve_forever() + +    def shutdown(self): +        self.server.shutdown() + + +class ServerTestBase(object): +    ssl = None +    handler = None +    addr = ("localhost", 0) + +    @classmethod +    def setup_class(cls): +        cls.q = queue.Queue() +        s = cls.makeserver() +        cls.port = s.address.port +        cls.server = ServerThread(s) +        cls.server.start() + +    @classmethod +    def makeserver(cls): +        return TServer(cls.ssl, cls.q, cls.handler, cls.addr) + +    @classmethod +    def teardown_class(cls): +        cls.server.shutdown() + +    @property +    def last_handler(self): +        return self.server.server.last_handler + + +class TServer(tcp.TCPServer): + +    def __init__(self, ssl, q, handler_klass, addr): +        """ +            ssl: A dictionary of SSL parameters: + +                    cert, key, request_client_cert, cipher_list, +                    dhparams, v3_only +        """ +        tcp.TCPServer.__init__(self, addr) + +        if ssl is True: +            self.ssl = dict() +        elif isinstance(ssl, dict): +            self.ssl = ssl +        else: +            self.ssl = None + +        self.q = q +        self.handler_klass = handler_klass +        
self.last_handler = None + +    def handle_client_connection(self, request, client_address): +        h = self.handler_klass(request, client_address, self) +        self.last_handler = h +        if self.ssl is not None: +            cert = self.ssl.get( +                "cert", +                tutils.test_data.path("data/server.crt")) +            raw_key = self.ssl.get( +                "key", +                tutils.test_data.path("data/server.key")) +            key = OpenSSL.crypto.load_privatekey( +                OpenSSL.crypto.FILETYPE_PEM, +                open(raw_key, "rb").read()) +            if self.ssl.get("v3_only", False): +                method = OpenSSL.SSL.SSLv3_METHOD +                options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 +            else: +                method = OpenSSL.SSL.SSLv23_METHOD +                options = None +            h.convert_to_ssl( +                cert, key, +                method=method, +                options=options, +                handle_sni=getattr(h, "handle_sni", None), +                request_client_cert=self.ssl.get("request_client_cert", None), +                cipher_list=self.ssl.get("cipher_list", None), +                dhparams=self.ssl.get("dhparams", None), +                chain_file=self.ssl.get("chain_file", None), +                alpn_select=self.ssl.get("alpn_select", None) +            ) +        h.handle() +        h.finish() + +    def handle_error(self, connection, client_address, fp=None): +        s = StringIO() +        tcp.TCPServer.handle_error(self, connection, client_address, s) +        self.q.put(s.getvalue()) diff --git a/netlib/netlib/tutils.py b/netlib/netlib/tutils.py new file mode 100644 index 00000000..14b4ef06 --- /dev/null +++ b/netlib/netlib/tutils.py @@ -0,0 +1,130 @@ +from io import BytesIO +import tempfile +import os +import time +import shutil +from contextlib import contextmanager +import six +import sys + +from . import utils, tcp +from .http import Request, Response, Headers + + +def treader(bytes): +    """ +        Construct a tcp.Read object from bytes. +    """ +    fp = BytesIO(bytes) +    return tcp.Reader(fp) + + +@contextmanager +def tmpdir(*args, **kwargs): +    orig_workdir = os.getcwd() +    temp_workdir = tempfile.mkdtemp(*args, **kwargs) +    os.chdir(temp_workdir) + +    yield temp_workdir + +    os.chdir(orig_workdir) +    shutil.rmtree(temp_workdir) + + +def _check_exception(expected, actual, exc_tb): +    if isinstance(expected, six.string_types): +        if expected.lower() not in str(actual).lower(): +            six.reraise(AssertionError, AssertionError( +                "Expected %s, but caught %s" % ( +                    repr(expected), repr(actual) +                ) +            ), exc_tb) +    else: +        if not isinstance(actual, expected): +            six.reraise(AssertionError, AssertionError( +                "Expected %s, but caught %s %s" % ( +                    expected.__name__, actual.__class__.__name__, repr(actual) +                ) +            ), exc_tb) + + +def raises(expected_exception, obj=None, *args, **kwargs): +    """ +        Assert that a callable raises a specified exception. + +        :exc An exception class or a string. If a class, assert that an +        exception of this type is raised. If a string, assert that the string +        occurs in the string representation of the exception, based on a +        case-insenstivie match. + +        :obj A callable object. 
+ +        :args Arguments to be passsed to the callable. + +        :kwargs Arguments to be passed to the callable. +    """ +    if obj is None: +        return RaisesContext(expected_exception) +    else: +        try: +            ret = obj(*args, **kwargs) +        except Exception as actual: +            _check_exception(expected_exception, actual, sys.exc_info()[2]) +        else: +            raise AssertionError("No exception raised. Return value: {}".format(ret)) + + +class RaisesContext(object): +    def __init__(self, expected_exception): +        self.expected_exception = expected_exception + +    def __enter__(self): +        return + +    def __exit__(self, exc_type, exc_val, exc_tb): +        if not exc_type: +            raise AssertionError("No exception raised.") +        else: +            _check_exception(self.expected_exception, exc_val, exc_tb) +        return True + + +test_data = utils.Data(__name__) + + +def treq(**kwargs): +    """ +    Returns: +        netlib.http.Request +    """ +    default = dict( +        first_line_format="relative", +        method=b"GET", +        scheme=b"http", +        host=b"address", +        port=22, +        path=b"/path", +        http_version=b"HTTP/1.1", +        headers=Headers(header="qvalue", content_length="7"), +        content=b"content" +    ) +    default.update(kwargs) +    return Request(**default) + + +def tresp(**kwargs): +    """ +    Returns: +        netlib.http.Response +    """ +    default = dict( +        http_version=b"HTTP/1.1", +        status_code=200, +        reason=b"OK", +        headers=Headers(header_response="svalue", content_length="7"), +        content=b"message", +        timestamp_start=time.time(), +        timestamp_end=time.time(), +    ) +    default.update(kwargs) +    return Response(**default) diff --git a/netlib/netlib/utils.py b/netlib/netlib/utils.py new file mode 100644 index 00000000..d2fc7195 --- /dev/null +++ b/netlib/netlib/utils.py @@ -0,0 +1,416 @@ +from __future__ import absolute_import, print_function, division +import os.path +import re +import codecs +import unicodedata +from abc import ABCMeta, abstractmethod + +import six + +from six.moves import urllib +import hyperframe + + +@six.add_metaclass(ABCMeta) +class Serializable(object): +    """ +    Abstract Base Class that defines an API to save an object's state and restore it later on. +    """ + +    @classmethod +    @abstractmethod +    def from_state(cls, state): +        """ +        Create a new object from the given state. +        """ +        raise NotImplementedError() + +    @abstractmethod +    def get_state(self): +        """ +        Retrieve object state. +        """ +        raise NotImplementedError() + +    @abstractmethod +    def set_state(self, state): +        """ +        Set object state to the given state. 
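As a concrete illustration, tcp.Address (defined in tcp.py earlier in this change) implements the interface; a state round trip looks roughly like this (the address tuple is made up):

    from netlib import tcp

    addr = tcp.Address(("example.com", 443))
    state = addr.get_state()      # {"address": ("example.com", 443), "use_ipv6": False}
    assert tcp.Address.from_state(state) == addr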
+        """ +        raise NotImplementedError() + + +def always_bytes(unicode_or_bytes, *encode_args): +    if isinstance(unicode_or_bytes, six.text_type): +        return unicode_or_bytes.encode(*encode_args) +    return unicode_or_bytes + + +def always_byte_args(*encode_args): +    """Decorator that transparently encodes all arguments passed as unicode""" +    def decorator(fun): +        def _fun(*args, **kwargs): +            args = [always_bytes(arg, *encode_args) for arg in args] +            kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} +            return fun(*args, **kwargs) +        return _fun +    return decorator + + +def native(s, *encoding_opts): +    """ +    Convert :py:class:`bytes` or :py:class:`unicode` to the native +    :py:class:`str` type, using latin1 encoding if conversion is necessary. + +    https://www.python.org/dev/peps/pep-3333/#a-note-on-string-types +    """ +    if not isinstance(s, (six.binary_type, six.text_type)): +        raise TypeError("%r is neither bytes nor unicode" % s) +    if six.PY3: +        if isinstance(s, six.binary_type): +            return s.decode(*encoding_opts) +    else: +        if isinstance(s, six.text_type): +            return s.encode(*encoding_opts) +    return s + + +def isascii(bytes): +    try: +        bytes.decode("ascii") +    except ValueError: +        return False +    return True + + +def clean_bin(s, keep_spacing=True): +    """ +        Cleans binary data to make it safe to display. + +        Args: +            keep_spacing: If False, tabs and newlines will also be replaced. +    """ +    if isinstance(s, six.text_type): +        if keep_spacing: +            keep = u" \n\r\t" +        else: +            keep = u" " +        return u"".join( +            ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else u"." +            for ch in s +        ) +    else: +        if keep_spacing: +            keep = (9, 10, 13)  # \t, \n, \r, +        else: +            keep = () +        return b"".join( +            six.int2byte(ch) if (31 < ch < 127 or ch in keep) else b"." +            for ch in six.iterbytes(s) +        ) + + +def hexdump(s): +    """ +        Returns: +            A generator of (offset, hex, str) tuples +    """ +    for i in range(0, len(s), 16): +        offset = "{:0=10x}".format(i).encode() +        part = s[i:i + 16] +        x = b" ".join("{:0=2x}".format(i).encode() for i in six.iterbytes(part)) +        x = x.ljust(47)  # 16*2 + 15 +        yield (offset, x, clean_bin(part, False)) + + +def setbit(byte, offset, value): +    """ +        Set a bit in a byte to 1 if value is truthy, 0 if not. +    """ +    if value: +        return byte | (1 << offset) +    else: +        return byte & ~(1 << offset) + + +def getbit(byte, offset): +    mask = 1 << offset +    return bool(byte & mask) + + +class BiDi(object): + +    """ +        A wee utility class for keeping bi-directional mappings, like field +        constants in protocols. 
Names are attributes on the object, dict-like +        access maps values to names: + +        CONST = BiDi(a=1, b=2) +        assert CONST.a == 1 +        assert CONST.get_name(1) == "a" +    """ + +    def __init__(self, **kwargs): +        self.names = kwargs +        self.values = {} +        for k, v in kwargs.items(): +            self.values[v] = k +        if len(self.names) != len(self.values): +            raise ValueError("Duplicate values not allowed.") + +    def __getattr__(self, k): +        if k in self.names: +            return self.names[k] +        raise AttributeError("No such attribute: %s", k) + +    def get_name(self, n, default=None): +        return self.values.get(n, default) + + +def pretty_size(size): +    suffixes = [ +        ("B", 2 ** 10), +        ("kB", 2 ** 20), +        ("MB", 2 ** 30), +    ] +    for suf, lim in suffixes: +        if size >= lim: +            continue +        else: +            x = round(size / float(lim / 2 ** 10), 2) +            if x == int(x): +                x = int(x) +            return str(x) + suf + + +class Data(object): + +    def __init__(self, name): +        m = __import__(name) +        dirname, _ = os.path.split(m.__file__) +        self.dirname = os.path.abspath(dirname) + +    def path(self, path): +        """ +            Returns a path to the package data housed at 'path' under this +            module.Path can be a path to a file, or to a directory. + +            This function will raise ValueError if the path does not exist. +        """ +        fullpath = os.path.join(self.dirname, '../test/', path) +        if not os.path.exists(fullpath): +            raise ValueError("dataPath: %s does not exist." % fullpath) +        return fullpath + + +_label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE) + + +def is_valid_host(host): +    """ +    Checks if a hostname is valid. + +    Args: +      host (bytes): The hostname +    """ +    try: +        host.decode("idna") +    except ValueError: +        return False +    if len(host) > 255: +        return False +    if host[-1] == b".": +        host = host[:-1] +    return all(_label_valid.match(x) for x in host.split(b".")) + + +def is_valid_port(port): +    return 0 <= port <= 65535 + + +# PY2 workaround +def decode_parse_result(result, enc): +    if hasattr(result, "decode"): +        return result.decode(enc) +    else: +        return urllib.parse.ParseResult(*[x.decode(enc) for x in result]) + + +# PY2 workaround +def encode_parse_result(result, enc): +    if hasattr(result, "encode"): +        return result.encode(enc) +    else: +        return urllib.parse.ParseResult(*[x.encode(enc) for x in result]) + + +def parse_url(url): +    """ +        URL-parsing function that checks that +            - port is an integer 0-65535 +            - host is a valid IDNA-encoded hostname with no null-bytes +            - path is valid ASCII + +        Args: +            A URL (as bytes or as unicode) + +        Returns: +            A (scheme, host, port, path) tuple + +        Raises: +            ValueError, if the URL is not properly formatted. +    """ +    parsed = urllib.parse.urlparse(url) + +    if not parsed.hostname: +        raise ValueError("No hostname given") + +    if isinstance(url, six.binary_type): +        host = parsed.hostname + +        # this should not raise a ValueError, +        # but we try to be very forgiving here and accept just everything. 
+        # decode_parse_result(parsed, "ascii") +    else: +        host = parsed.hostname.encode("idna") +        parsed = encode_parse_result(parsed, "ascii") + +    port = parsed.port +    if not port: +        port = 443 if parsed.scheme == b"https" else 80 + +    full_path = urllib.parse.urlunparse( +        (b"", b"", parsed.path, parsed.params, parsed.query, parsed.fragment) +    ) +    if not full_path.startswith(b"/"): +        full_path = b"/" + full_path + +    if not is_valid_host(host): +        raise ValueError("Invalid Host") +    if not is_valid_port(port): +        raise ValueError("Invalid Port") + +    return parsed.scheme, host, port, full_path + + +def get_header_tokens(headers, key): +    """ +        Retrieve all tokens for a header key. A number of different headers +        follow a pattern where each header line can containe comma-separated +        tokens, and headers can be set multiple times. +    """ +    if key not in headers: +        return [] +    tokens = headers[key].split(",") +    return [token.strip() for token in tokens] + + +def hostport(scheme, host, port): +    """ +        Returns the host component, with a port specifcation if needed. +    """ +    if (port, scheme) in [(80, "http"), (443, "https"), (80, b"http"), (443, b"https")]: +        return host +    else: +        if isinstance(host, six.binary_type): +            return b"%s:%d" % (host, port) +        else: +            return "%s:%d" % (host, port) + + +def unparse_url(scheme, host, port, path=""): +    """ +    Returns a URL string, constructed from the specified components. + +    Args: +        All args must be str. +    """ +    return "%s://%s%s" % (scheme, hostport(scheme, host, port), path) + + +def urlencode(s): +    """ +        Takes a list of (key, value) tuples and returns a urlencoded string. +    """ +    s = [tuple(i) for i in s] +    return urllib.parse.urlencode(s, False) + + +def urldecode(s): +    """ +        Takes a urlencoded string and returns a list of (key, value) tuples. +    """ +    return urllib.parse.parse_qsl(s, keep_blank_values=True) + + +def parse_content_type(c): +    """ +        A simple parser for content-type values. Returns a (type, subtype, +        parameters) tuple, where type and subtype are strings, and parameters +        is a dict. If the string could not be parsed, return None. + +        E.g. the following string: + +            text/html; charset=UTF-8 + +        Returns: + +            ("text", "html", {"charset": "UTF-8"}) +    """ +    parts = c.split(";", 1) +    ts = parts[0].split("/", 1) +    if len(ts) != 2: +        return None +    d = {} +    if len(parts) == 2: +        for i in parts[1].split(";"): +            clause = i.split("=", 1) +            if len(clause) == 2: +                d[clause[0].strip()] = clause[1].strip() +    return ts[0].lower(), ts[1].lower(), d + + +def multipartdecode(headers, content): +    """ +        Takes a multipart boundary encoded string and returns list of (key, value) tuples. 
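A hedged sketch of the expected round trip; the boundary, field name and value are made up, and it assumes Headers (from netlib.http, as used in tutils.py above) maps the content_type keyword to the content-type header:

    from netlib.http import Headers
    from netlib.utils import multipartdecode

    headers = Headers(content_type="multipart/form-data; boundary=foo")
    body = (b"--foo\r\n"
            b'Content-Disposition: form-data; name="field1"\r\n'
            b"\r\n"
            b"value1\r\n"
            b"--foo--\r\n")
    assert multipartdecode(headers, body) == [(b"field1", b"value1")]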
+    """ +    v = headers.get("content-type") +    if v: +        v = parse_content_type(v) +        if not v: +            return [] +        try: +            boundary = v[2]["boundary"].encode("ascii") +        except (KeyError, UnicodeError): +            return [] + +        rx = re.compile(br'\bname="([^"]+)"') +        r = [] + +        for i in content.split(b"--" + boundary): +            parts = i.splitlines() +            if len(parts) > 1 and parts[0][0:2] != b"--": +                match = rx.search(parts[1]) +                if match: +                    key = match.group(1) +                    value = b"".join(parts[3 + parts[2:].index(b""):]) +                    r.append((key, value)) +        return r +    return [] + + +def http2_read_raw_frame(rfile): +    header = rfile.safe_read(9) +    length = int(codecs.encode(header[:3], 'hex_codec'), 16) + +    if length == 4740180: +        raise ValueError("Length field looks more like HTTP/1.1: %s" % rfile.peek(20)) + +    body = rfile.safe_read(length) +    return [header, body] + +def http2_read_frame(rfile): +    header, body = http2_read_raw_frame(rfile) +    frame, length = hyperframe.frame.Frame.parse_frame_header(header) +    frame.parse_body(memoryview(body)) +    return frame diff --git a/netlib/netlib/version.py b/netlib/netlib/version.py new file mode 100644 index 00000000..bc35c30f --- /dev/null +++ b/netlib/netlib/version.py @@ -0,0 +1,11 @@ +from __future__ import (absolute_import, print_function, division) + +IVERSION = (0, 17) +VERSION = ".".join(str(i) for i in IVERSION) +MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) +NAME = "netlib" +NAMEVERSION = NAME + " " + VERSION + +NEXT_MINORVERSION = list(IVERSION) +NEXT_MINORVERSION[1] += 1 +NEXT_MINORVERSION = ".".join(str(i) for i in NEXT_MINORVERSION[:2]) diff --git a/netlib/netlib/version_check.py b/netlib/netlib/version_check.py new file mode 100644 index 00000000..9cf27eea --- /dev/null +++ b/netlib/netlib/version_check.py @@ -0,0 +1,60 @@ +""" +Having installed a wrong version of pyOpenSSL or netlib is unfortunately a +very common source of error. Check before every start that both versions +are somewhat okay. +""" +from __future__ import division, absolute_import, print_function +import sys +import inspect +import os.path +import six + +import OpenSSL +from . import version + +PYOPENSSL_MIN_VERSION = (0, 15) + + +def check_mitmproxy_version(mitmproxy_version, fp=sys.stderr): +    # We don't introduce backward-incompatible changes in patch versions. Only +    # consider major and minor version. +    if version.IVERSION[:2] != mitmproxy_version[:2]: +        print( +            u"You are using mitmproxy %s with netlib %s. " +            u"Most likely, that won't work - please upgrade!" 
% ( +                mitmproxy_version, version.VERSION +            ), +            file=fp +        ) +        sys.exit(1) + + +def check_pyopenssl_version(min_version=PYOPENSSL_MIN_VERSION, fp=sys.stderr): +    min_version_str = u".".join(six.text_type(x) for x in min_version) +    try: +        v = tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) +    except ValueError: +        print( +            u"Cannot parse pyOpenSSL version: {}" +            u"mitmproxy requires pyOpenSSL {} or greater.".format( +                OpenSSL.__version__, min_version_str +            ), +            file=fp +        ) +        return +    if v < min_version: +        print( +            u"You are using an outdated version of pyOpenSSL: " +            u"mitmproxy requires pyOpenSSL {} or greater.".format(min_version_str), +            file=fp +        ) +        # Some users apparently have multiple versions of pyOpenSSL installed. +        # Report which one we got. +        pyopenssl_path = os.path.dirname(inspect.getfile(OpenSSL)) +        print( +            u"Your pyOpenSSL {} installation is located at {}".format( +                OpenSSL.__version__, pyopenssl_path +            ), +            file=fp +        ) +        sys.exit(1) diff --git a/netlib/netlib/websockets/__init__.py b/netlib/netlib/websockets/__init__.py new file mode 100644 index 00000000..1c143919 --- /dev/null +++ b/netlib/netlib/websockets/__init__.py @@ -0,0 +1,2 @@ +from .frame import * +from .protocol import * diff --git a/netlib/netlib/websockets/frame.py b/netlib/netlib/websockets/frame.py new file mode 100644 index 00000000..fce2c9d3 --- /dev/null +++ b/netlib/netlib/websockets/frame.py @@ -0,0 +1,316 @@ +from __future__ import absolute_import +import os +import struct +import io +import warnings + +import six + +from .protocol import Masker +from netlib import tcp +from netlib import utils + + +MAX_16_BIT_INT = (1 << 16) +MAX_64_BIT_INT = (1 << 64) + +DEFAULT=object() + +OPCODE = utils.BiDi( +    CONTINUE=0x00, +    TEXT=0x01, +    BINARY=0x02, +    CLOSE=0x08, +    PING=0x09, +    PONG=0x0a +) + + +class FrameHeader(object): + +    def __init__( +        self, +        opcode=OPCODE.TEXT, +        payload_length=0, +        fin=False, +        rsv1=False, +        rsv2=False, +        rsv3=False, +        masking_key=DEFAULT, +        mask=DEFAULT, +        length_code=DEFAULT +    ): +        if not 0 <= opcode < 2 ** 4: +            raise ValueError("opcode must be 0-16") +        self.opcode = opcode +        self.payload_length = payload_length +        self.fin = fin +        self.rsv1 = rsv1 +        self.rsv2 = rsv2 +        self.rsv3 = rsv3 + +        if length_code is DEFAULT: +            self.length_code = self._make_length_code(self.payload_length) +        else: +            self.length_code = length_code + +        if mask is DEFAULT and masking_key is DEFAULT: +            self.mask = False +            self.masking_key = b"" +        elif mask is DEFAULT: +            self.mask = 1 +            self.masking_key = masking_key +        elif masking_key is DEFAULT: +            self.mask = mask +            self.masking_key = os.urandom(4) +        else: +            self.mask = mask +            self.masking_key = masking_key + +        if self.masking_key and len(self.masking_key) != 4: +            raise ValueError("Masking key must be 4 bytes.") + +    @classmethod +    def _make_length_code(self, length): +        """ +         A websockets frame contains an initial length_code, and an 
optional +         extended length code to represent the actual length if length code is +         larger than 125 +        """ +        if length <= 125: +            return length +        elif length >= 126 and length <= 65535: +            return 126 +        else: +            return 127 + +    def __repr__(self): +        vals = [ +            "ws frame:", +            OPCODE.get_name(self.opcode, hex(self.opcode)).lower() +        ] +        flags = [] +        for i in ["fin", "rsv1", "rsv2", "rsv3", "mask"]: +            if getattr(self, i): +                flags.append(i) +        if flags: +            vals.extend([":", "|".join(flags)]) +        if self.masking_key: +            vals.append(":key=%s" % repr(self.masking_key)) +        if self.payload_length: +            vals.append(" %s" % utils.pretty_size(self.payload_length)) +        return "".join(vals) + +    def human_readable(self): +        warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) +        return repr(self) + +    def __bytes__(self): +        first_byte = utils.setbit(0, 7, self.fin) +        first_byte = utils.setbit(first_byte, 6, self.rsv1) +        first_byte = utils.setbit(first_byte, 5, self.rsv2) +        first_byte = utils.setbit(first_byte, 4, self.rsv3) +        first_byte = first_byte | self.opcode + +        second_byte = utils.setbit(self.length_code, 7, self.mask) + +        b = six.int2byte(first_byte) + six.int2byte(second_byte) + +        if self.payload_length < 126: +            pass +        elif self.payload_length < MAX_16_BIT_INT: +            # '!H' pack as 16 bit unsigned short +            # add 2 byte extended payload length +            b += struct.pack('!H', self.payload_length) +        elif self.payload_length < MAX_64_BIT_INT: +            # '!Q' = pack as 64 bit unsigned long long +            # add 8 bytes extended payload length +            b += struct.pack('!Q', self.payload_length) +        if self.masking_key: +            b += self.masking_key +        return b + +    if six.PY2: +        __str__ = __bytes__ + +    def to_bytes(self): +        warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) +        return bytes(self) + +    @classmethod +    def from_file(cls, fp): +        """ +          read a websockets frame header +        """ +        first_byte = six.byte2int(fp.safe_read(1)) +        second_byte = six.byte2int(fp.safe_read(1)) + +        fin = utils.getbit(first_byte, 7) +        rsv1 = utils.getbit(first_byte, 6) +        rsv2 = utils.getbit(first_byte, 5) +        rsv3 = utils.getbit(first_byte, 4) +        # grab right-most 4 bits +        opcode = first_byte & 15 +        mask_bit = utils.getbit(second_byte, 7) +        # grab the next 7 bits +        length_code = second_byte & 127 + +        # payload_lengthy > 125 indicates you need to read more bytes +        # to get the actual payload length +        if length_code <= 125: +            payload_length = length_code +        elif length_code == 126: +            payload_length, = struct.unpack("!H", fp.safe_read(2)) +        elif length_code == 127: +            payload_length, = struct.unpack("!Q", fp.safe_read(8)) + +        # masking key only present if mask bit set +        if mask_bit == 1: +            masking_key = fp.safe_read(4) +        else: +            masking_key = None + +        return cls( +            fin=fin, +            rsv1=rsv1, +            rsv2=rsv2, +            
rsv3=rsv3, +            opcode=opcode, +            mask=mask_bit, +            length_code=length_code, +            payload_length=payload_length, +            masking_key=masking_key, +        ) + +    def __eq__(self, other): +        if isinstance(other, FrameHeader): +            return bytes(self) == bytes(other) +        return False + + +class Frame(object): + +    """ +        Represents one websockets frame. +        Constructor takes human readable forms of the frame components +        from_bytes() is also avaliable. + +        WebSockets Frame as defined in RFC6455 + +          0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +         +-+-+-+-+-------+-+-------------+-------------------------------+ +         |F|R|R|R| opcode|M| Payload len |    Extended payload length    | +         |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           | +         |N|V|V|V|       |S|             |   (if payload len==126/127)   | +         | |1|2|3|       |K|             |                               | +         +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + +         |     Extended payload length continued, if payload len == 127  | +         + - - - - - - - - - - - - - - - +-------------------------------+ +         |                               |Masking-key, if MASK set to 1  | +         +-------------------------------+-------------------------------+ +         | Masking-key (continued)       |          Payload Data         | +         +-------------------------------- - - - - - - - - - - - - - - - + +         :                     Payload Data continued ...                : +         + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +         |                     Payload Data continued ...                | +         +---------------------------------------------------------------+ +    """ + +    def __init__(self, payload=b"", **kwargs): +        self.payload = payload +        kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) +        self.header = FrameHeader(**kwargs) + +    @classmethod +    def default(cls, message, from_client=False): +        """ +          Construct a basic websocket frame from some default values. +          Creates a non-fragmented text frame. +        """ +        if from_client: +            mask_bit = 1 +            masking_key = os.urandom(4) +        else: +            mask_bit = 0 +            masking_key = None + +        return cls( +            message, +            fin=1,  # final frame +            opcode=OPCODE.TEXT,  # text +            mask=mask_bit, +            masking_key=masking_key, +        ) + +    @classmethod +    def from_bytes(cls, bytestring): +        """ +          Construct a websocket frame from an in-memory bytestring +          to construct a frame from a stream of bytes, use from_file() directly +        """ +        return cls.from_file(tcp.Reader(io.BytesIO(bytestring))) + +    def __repr__(self): +        ret = repr(self.header) +        if self.payload: +            ret = ret + "\nPayload:\n" + utils.clean_bin(self.payload).decode("ascii") +        return ret + +    def human_readable(self): +        warnings.warn("Frame.to_bytes is deprecated, use bytes(frame) instead.", DeprecationWarning) +        return repr(self) + +    def __bytes__(self): +        """ +            Serialize the frame to wire format. Returns a string. 
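A minimal sketch of the extended payload-length rules implemented by FrameHeader above, assuming the netlib package from this diff is importable; a length_code of 126 signals a 16-bit extended length field, 127 a 64-bit one:

    from io import BytesIO
    from netlib import tcp
    from netlib.websockets import FrameHeader, OPCODE

    header = FrameHeader(opcode=OPCODE.BINARY, payload_length=300)
    assert header.length_code == 126    # 126..65535 bytes use the 16-bit extended form
    raw = bytes(header)
    assert len(raw) == 4                # two fixed header bytes + two bytes of extended length
    assert FrameHeader.from_file(tcp.Reader(BytesIO(raw))) == header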
+        """
+        b = bytes(self.header)
+        if self.header.masking_key:
+            b += Masker(self.header.masking_key)(self.payload)
+        else:
+            b += self.payload
+        return b
+
+    if six.PY2:
+        __str__ = __bytes__
+
+    def to_bytes(self):
+        warnings.warn("Frame.to_bytes is deprecated, use bytes(frame) instead.", DeprecationWarning)
+        return bytes(self)
+
+    def to_file(self, writer):
+        warnings.warn("Frame.to_file is deprecated, use wfile.write(bytes(frame)) instead.", DeprecationWarning)
+        writer.write(bytes(self))
+        writer.flush()
+
+    @classmethod
+    def from_file(cls, fp):
+        """
+          Read a websockets frame sent by a server or client.
+
+          fp is a "file like" object that could be backed by a network
+          stream, a disk, or an in-memory stream reader.
+        """
+        header = FrameHeader.from_file(fp)
+        payload = fp.safe_read(header.payload_length)
+
+        if header.mask == 1 and header.masking_key:
+            payload = Masker(header.masking_key)(payload)
+
+        return cls(
+            payload,
+            fin=header.fin,
+            opcode=header.opcode,
+            mask=header.mask,
+            payload_length=header.payload_length,
+            masking_key=header.masking_key,
+            rsv1=header.rsv1,
+            rsv2=header.rsv2,
+            rsv3=header.rsv3,
+            length_code=header.length_code
+        )
+
+    def __eq__(self, other):
+        if isinstance(other, Frame):
+            return bytes(self) == bytes(other)
+        return False
diff --git a/netlib/netlib/websockets/protocol.py b/netlib/netlib/websockets/protocol.py
new file mode 100644
index 00000000..1e95fa1c
--- /dev/null
+++ b/netlib/netlib/websockets/protocol.py
@@ -0,0 +1,115 @@
+
+
+
+# Collection of utility functions that implement small portions of the RFC 6455
+# WebSockets protocol. Useful for building WebSocket clients and servers.
+#
+# Emphasis is on readability, simplicity and modularity, not performance or
+# completeness.
+#
+# This is a work in progress and does not yet contain all the utilities needed
+# to create fully compliant clients/servers.
+# Spec: https://tools.ietf.org/html/rfc6455
+
+from __future__ import absolute_import
+import base64
+import hashlib
+import os
+
+import binascii
+import six
+from ..http import Headers
+
+# The magic GUID that websocket servers must know to prove they understand
+# RFC 6455.
+websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
+VERSION = "13"
+
+
+class Masker(object):
+
+    """
+        Data sent from the client must be masked to prevent malicious scripts
+        from sending data over the wire in predictable patterns.
+
+        Servers do not have to mask data they send to the client.
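A matching round-trip sketch for the Frame class above (same assumption that the netlib package from this diff is importable); Frame.default builds a masked, final TEXT frame when from_client=True, and parsing the serialized bytes yields an equal frame:

    from netlib.websockets import Frame

    frame = Frame.default(b"hello", from_client=True)  # masked FIN text frame with a random key
    wire = bytes(frame)                                 # header (incl. masking key) + masked payload
    assert Frame.from_bytes(wire) == frame              # __eq__ compares the serialized bytes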
+        https://tools.ietf.org/html/rfc6455#section-5.3 +    """ + +    def __init__(self, key): +        self.key = key +        self.offset = 0 + +    def mask(self, offset, data): +        result = bytearray(data) +        if six.PY2: +            for i in range(len(data)): +                result[i] ^= ord(self.key[offset % 4]) +                offset += 1 +            result = str(result) +        else: + +            for i in range(len(data)): +                result[i] ^= self.key[offset % 4] +                offset += 1 +            result = bytes(result) +        return result + +    def __call__(self, data): +        ret = self.mask(self.offset, data) +        self.offset += len(ret) +        return ret + + +class WebsocketsProtocol(object): + +    def __init__(self): +        pass + +    @classmethod +    def client_handshake_headers(self, key=None, version=VERSION): +        """ +            Create the headers for a valid HTTP upgrade request. If Key is not +            specified, it is generated, and can be found in sec-websocket-key in +            the returned header set. + +            Returns an instance of Headers +        """ +        if not key: +            key = base64.b64encode(os.urandom(16)).decode('ascii') +        return Headers( +            sec_websocket_key=key, +            sec_websocket_version=version, +            connection="Upgrade", +            upgrade="websocket", +        ) + +    @classmethod +    def server_handshake_headers(self, key): +        """ +          The server response is a valid HTTP 101 response. +        """ +        return Headers( +            sec_websocket_accept=self.create_server_nonce(key), +            connection="Upgrade", +            upgrade="websocket" +        ) + + +    @classmethod +    def check_client_handshake(self, headers): +        if headers.get("upgrade") != "websocket": +            return +        return headers.get("sec-websocket-key") + + +    @classmethod +    def check_server_handshake(self, headers): +        if headers.get("upgrade") != "websocket": +            return +        return headers.get("sec-websocket-accept") + + +    @classmethod +    def create_server_nonce(self, client_nonce): +        return base64.b64encode(hashlib.sha1(client_nonce + websockets_magic).digest()) diff --git a/netlib/netlib/wsgi.py b/netlib/netlib/wsgi.py new file mode 100644 index 00000000..d6dfae5d --- /dev/null +++ b/netlib/netlib/wsgi.py @@ -0,0 +1,164 @@ +from __future__ import (absolute_import, print_function, division) +from io import BytesIO, StringIO +import urllib +import time +import traceback + +import six +from six.moves import urllib + +from netlib.utils import always_bytes, native +from . 
import http, tcp + +class ClientConn(object): + +    def __init__(self, address): +        self.address = tcp.Address.wrap(address) + + +class Flow(object): + +    def __init__(self, address, request): +        self.client_conn = ClientConn(address) +        self.request = request + + +class Request(object): + +    def __init__(self, scheme, method, path, http_version, headers, content): +        self.scheme, self.method, self.path = scheme, method, path +        self.headers, self.content = headers, content +        self.http_version = http_version + + +def date_time_string(): +    """Return the current date and time formatted for a message header.""" +    WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] +    MONTHS = [ +        None, +        'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', +        'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec' +    ] +    now = time.time() +    year, month, day, hh, mm, ss, wd, y_, z_ = time.gmtime(now) +    s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( +        WEEKS[wd], +        day, MONTHS[month], year, +        hh, mm, ss +    ) +    return s + + +class WSGIAdaptor(object): + +    def __init__(self, app, domain, port, sversion): +        self.app, self.domain, self.port, self.sversion = app, domain, port, sversion + +    def make_environ(self, flow, errsoc, **extra): +        path = native(flow.request.path, "latin-1") +        if '?' in path: +            path_info, query = native(path, "latin-1").split('?', 1) +        else: +            path_info = path +            query = '' +        environ = { +            'wsgi.version': (1, 0), +            'wsgi.url_scheme': native(flow.request.scheme, "latin-1"), +            'wsgi.input': BytesIO(flow.request.content or b""), +            'wsgi.errors': errsoc, +            'wsgi.multithread': True, +            'wsgi.multiprocess': False, +            'wsgi.run_once': False, +            'SERVER_SOFTWARE': self.sversion, +            'REQUEST_METHOD': native(flow.request.method, "latin-1"), +            'SCRIPT_NAME': '', +            'PATH_INFO': urllib.parse.unquote(path_info), +            'QUERY_STRING': query, +            'CONTENT_TYPE': native(flow.request.headers.get('Content-Type', ''), "latin-1"), +            'CONTENT_LENGTH': native(flow.request.headers.get('Content-Length', ''), "latin-1"), +            'SERVER_NAME': self.domain, +            'SERVER_PORT': str(self.port), +            'SERVER_PROTOCOL': native(flow.request.http_version, "latin-1"), +        } +        environ.update(extra) +        if flow.client_conn.address: +            environ["REMOTE_ADDR"] = native(flow.client_conn.address.host, "latin-1") +            environ["REMOTE_PORT"] = flow.client_conn.address.port + +        for key, value in flow.request.headers.items(): +            key = 'HTTP_' + native(key, "latin-1").upper().replace('-', '_') +            if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'): +                environ[key] = value +        return environ + +    def error_page(self, soc, headers_sent, s): +        """ +            Make a best-effort attempt to write an error page. If headers are +            already sent, we just bung the error into the page. 
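To make make_environ above concrete, a small sketch (assuming the netlib modules from this diff are importable; the address, host and port values are purely illustrative):

    from io import BytesIO
    from netlib import wsgi
    from netlib.http import Headers

    req = wsgi.Request(b"http", b"GET", b"/ping?x=1", b"HTTP/1.1", Headers(host="example.com"), b"")
    flow = wsgi.Flow(("127.0.0.1", 54321), req)
    adaptor = wsgi.WSGIAdaptor(app=None, domain="example.com", port=80, sversion="netlib-test")

    environ = adaptor.make_environ(flow, BytesIO())  # BytesIO stands in for the wsgi.errors stream
    assert environ["REQUEST_METHOD"] == "GET"
    assert environ["PATH_INFO"] == "/ping"
    assert environ["QUERY_STRING"] == "x=1"
    assert environ["REMOTE_ADDR"] == "127.0.0.1"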
+        """ +        c = """ +            <html> +                <h1>Internal Server Error</h1> +                <pre>{err}"</pre> +            </html> +        """.format(err=s).strip().encode() + +        if not headers_sent: +            soc.write(b"HTTP/1.1 500 Internal Server Error\r\n") +            soc.write(b"Content-Type: text/html\r\n") +            soc.write("Content-Length: {length}\r\n".format(length=len(c)).encode()) +            soc.write(b"\r\n") +        soc.write(c) + +    def serve(self, request, soc, **env): +        state = dict( +            response_started=False, +            headers_sent=False, +            status=None, +            headers=None +        ) + +        def write(data): +            if not state["headers_sent"]: +                soc.write("HTTP/1.1 {status}\r\n".format(status=state["status"]).encode()) +                headers = state["headers"] +                if 'server' not in headers: +                    headers["Server"] = self.sversion +                if 'date' not in headers: +                    headers["Date"] = date_time_string() +                soc.write(bytes(headers)) +                soc.write(b"\r\n") +                state["headers_sent"] = True +            if data: +                soc.write(data) +            soc.flush() + +        def start_response(status, headers, exc_info=None): +            if exc_info: +                if state["headers_sent"]: +                    six.reraise(*exc_info) +            elif state["status"]: +                raise AssertionError('Response already started') +            state["status"] = status +            state["headers"] = http.Headers([[always_bytes(k), always_bytes(v)] for k,v in headers]) +            if exc_info: +                self.error_page(soc, state["headers_sent"], traceback.format_tb(exc_info[2])) +                state["headers_sent"] = True + +        errs = six.BytesIO() +        try: +            dataiter = self.app( +                self.make_environ(request, errs, **env), start_response +            ) +            for i in dataiter: +                write(i) +            if not state["headers_sent"]: +                write(b"") +        except Exception as e: +            try: +                s = traceback.format_exc() +                errs.write(s.encode("utf-8", "replace")) +                self.error_page(soc, state["headers_sent"], s) +            except Exception:    # pragma: no cover +                pass +        return errs.getvalue() diff --git a/netlib/requirements.txt b/netlib/requirements.txt new file mode 100644 index 00000000..aefbcb6d --- /dev/null +++ b/netlib/requirements.txt @@ -0,0 +1 @@ +-e .[dev] diff --git a/netlib/setup.cfg b/netlib/setup.cfg new file mode 100644 index 00000000..3480374b --- /dev/null +++ b/netlib/setup.cfg @@ -0,0 +1,2 @@ +[bdist_wheel] +universal=1
\ No newline at end of file diff --git a/netlib/setup.py b/netlib/setup.py new file mode 100644 index 00000000..bcaecad4 --- /dev/null +++ b/netlib/setup.py @@ -0,0 +1,72 @@ +from setuptools import setup, find_packages +from codecs import open +import os +import sys + +from netlib import version + +# Based on https://github.com/pypa/sampleproject/blob/master/setup.py +# and https://python-packaging-user-guide.readthedocs.org/ +# and https://caremad.io/2014/11/distributing-a-cffi-project/ + +here = os.path.abspath(os.path.dirname(__file__)) + +with open(os.path.join(here, 'README.rst'), encoding='utf-8') as f: +    long_description = f.read() + +setup( +    name="netlib", +    version=version.VERSION, +    description="A collection of network utilities used by pathod and mitmproxy.", +    long_description=long_description, +    url="http://github.com/mitmproxy/netlib", +    author="Aldo Cortesi", +    author_email="aldo@corte.si", +    license="MIT", +    classifiers=[ +        "License :: OSI Approved :: MIT License", +        "Development Status :: 3 - Alpha", +        "Operating System :: POSIX", +        "Programming Language :: Python", +        "Programming Language :: Python :: 2", +        "Programming Language :: Python :: 2.7", +        "Programming Language :: Python :: 3", +        "Programming Language :: Python :: 3.5", +        "Programming Language :: Python :: Implementation :: CPython", +        "Programming Language :: Python :: Implementation :: PyPy", +        "Topic :: Internet", +        "Topic :: Internet :: WWW/HTTP", +        "Topic :: Internet :: WWW/HTTP :: HTTP Servers", +        "Topic :: Software Development :: Testing", +        "Topic :: Software Development :: Testing :: Traffic Generation", +    ], +    packages=find_packages(exclude=["test", "test.*"]), +    include_package_data=True, +    zip_safe=False, +    install_requires=[ +        "pyasn1>=0.1.9, <0.2", +        "pyOpenSSL>=0.15.1, <0.16", +        "cryptography>=1.2.2, <1.3", +        "passlib>=1.6.5, <1.7", +        "hpack>=2.1.0, <3.0", +        "hyperframe>=3.2.0, <4.0", +        "six>=1.10.0, <1.11", +        "certifi>=2015.11.20.1",  # no semver here - this should always be on the last release! 
+        "backports.ssl_match_hostname>=3.5.0.1, <3.6", +    ], +    extras_require={ +        # Do not use a range operator here: https://bitbucket.org/pypa/setuptools/issues/380 +        # Ubuntu Trusty and other still ship with setuptools < 17.1 +        ':python_version == "2.7"': [ +            "ipaddress>=1.0.15, <1.1", +        ], +        'dev': [ +            "mock>=1.3.0, <1.4", +            "pytest>=2.8.7, <2.9", +            "pytest-xdist>=1.14, <1.15", +            "pytest-cov>=2.2.1, <2.3", +            "pytest-timeout>=1.0.0, <1.1", +            "coveralls>=1.1, <1.2" +        ] +    }, +) diff --git a/netlib/test/__init__.py b/netlib/test/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/netlib/test/__init__.py diff --git a/netlib/test/data/clientcert/.gitignore b/netlib/test/data/clientcert/.gitignore new file mode 100644 index 00000000..07bc53d2 --- /dev/null +++ b/netlib/test/data/clientcert/.gitignore @@ -0,0 +1,3 @@ +client.crt +client.key +client.req diff --git a/netlib/test/data/clientcert/client.cnf b/netlib/test/data/clientcert/client.cnf new file mode 100644 index 00000000..5046a944 --- /dev/null +++ b/netlib/test/data/clientcert/client.cnf @@ -0,0 +1,5 @@ +[ ssl_client ] +basicConstraints = CA:FALSE +nsCertType = client +keyUsage = digitalSignature, keyEncipherment +extendedKeyUsage = clientAuth diff --git a/netlib/test/data/clientcert/client.pem b/netlib/test/data/clientcert/client.pem new file mode 100644 index 00000000..4927bca2 --- /dev/null +++ b/netlib/test/data/clientcert/client.pem @@ -0,0 +1,42 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAzCpoRjSTfIN24kkNap/GYmP9zVWj0Gk8R5BB/PvvN0OB1Zk0 +EEYPsWCcuhEdK0ehiDZX030doF0DOncKKa6mop/d0x2o+ts42peDhZM6JNUrm6d+ +ZWQVtio33mpp77UMhR093vaA+ExDnmE26kBTVijJ1+fRAVDXG/cmQINEri91Kk/G +3YJ5e45UrohGI5seBZ4vV0xbHtmczFRhYFlGOvYsoIe4Lvz/eFS2pIrTIpYQ2VM/ +SQQl+JFy+NlQRsWG2NrxtKOzMnnDE7YN4I3z5D5eZFo1EtwZ48LNCeSwrEOdfuzP +G5q5qbs5KpE/x85H9umuRwSCIArbMwBYV8a8JwIDAQABAoIBAFE3FV/IDltbmHEP +iky93hbJm+6QgKepFReKpRVTyqb7LaygUvueQyPWQMIriKTsy675nxo8DQr7tQsO +y3YlSZgra/xNMikIB6e82c7K8DgyrDQw/rCqjZB3Xt4VCqsWJDLXnQMSn98lx0g7 +d7Lbf8soUpKWXqfdVpSDTi4fibSX6kshXyfSTpcz4AdoncEpViUfU1xkEEmZrjT8 +1GcCsDC41xdNmzCpqRuZX7DKSFRoB+0hUzsC1oiqM7FD5kixonRd4F5PbRXImIzt +6YCsT2okxTA04jX7yByis7LlOLTlkmLtKQYuc3erOFvwx89s4vW+AeFei+GGNitn +tHfSwbECgYEA7SzV+nN62hAERHlg8cEQT4TxnsWvbronYWcc/ev44eHSPDWL5tPi +GHfSbW6YAq5Wa0I9jMWfXyhOYEC3MZTC5EEeLOB71qVrTwcy/sY66rOrcgjFI76Q +5JFHQ4wy3SWU50KxE0oWJO9LIowprG+pW1vzqC3VF0T7q0FqESrY4LUCgYEA3F7Z +80ndnCUlooJAb+Hfotv7peFf1o6+m1PTRcz1lLnVt5R5lXj86kn+tXEpYZo1RiGR +2rE2N0seeznWCooakHcsBN7/qmFIhhooJNF7yW+JP2I4P2UV5+tJ+8bcs/voUkQD +1x+rGOuMn8nvHBd2+Vharft8eGL2mgooPVI2XusCgYEAlMZpO3+w8pTVeHaDP2MR +7i/AuQ3cbCLNjSX3Y7jgGCFllWspZRRIYXzYPNkA9b2SbBnTLjjRLgnEkFBIGgvs +7O2EFjaCuDRvydUEQhjq4ErwIsopj7B8h0QyZcbOKTbn3uFQ3n68wVJx2Sv/ADHT +FIHrp/WIE96r19Niy34LKXkCgYB2W59VsuOKnMz01l5DeR5C+0HSWxS9SReIl2IO +yEFSKullWyJeLIgyUaGy0990430feKI8whcrZXYumuah7IDN/KOwzhCk8vEfzWao +N7bzfqtJVrh9HA7C7DVlO+6H4JFrtcoWPZUIomJ549w/yz6EN3ckoMC+a/Ck1TW9 +ka1QFwKBgQCywG6TrZz0UmOjyLQZ+8Q4uvZklSW5NAKBkNnyuQ2kd5rzyYgMPE8C +Er8T88fdVIKvkhDyHhwcI7n58xE5Gr7wkwsrk/Hbd9/ZB2GgAPY3cATskK1v1McU +YeX38CU0fUS4aoy26hWQXkViB47IGQ3jWo3ZCtzIJl8DI9/RsBWTnw== +-----END RSA PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIICYDCCAckCAQEwDQYJKoZIhvcNAQEFBQAwKDESMBAGA1UEAxMJbWl0bXByb3h5 +MRIwEAYDVQQKEwltaXRtcHJveHkwHhcNMTMwMTIwMDEwODEzWhcNMTUxMDE3MDEw +ODEzWjBFMQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UE 
+ChMYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOC +AQ8AMIIBCgKCAQEAzCpoRjSTfIN24kkNap/GYmP9zVWj0Gk8R5BB/PvvN0OB1Zk0 +EEYPsWCcuhEdK0ehiDZX030doF0DOncKKa6mop/d0x2o+ts42peDhZM6JNUrm6d+ +ZWQVtio33mpp77UMhR093vaA+ExDnmE26kBTVijJ1+fRAVDXG/cmQINEri91Kk/G +3YJ5e45UrohGI5seBZ4vV0xbHtmczFRhYFlGOvYsoIe4Lvz/eFS2pIrTIpYQ2VM/ +SQQl+JFy+NlQRsWG2NrxtKOzMnnDE7YN4I3z5D5eZFo1EtwZ48LNCeSwrEOdfuzP +G5q5qbs5KpE/x85H9umuRwSCIArbMwBYV8a8JwIDAQABMA0GCSqGSIb3DQEBBQUA +A4GBAFvI+cd47B85PQ970n2dU/PlA2/Hb1ldrrXh2guR4hX6vYx/uuk5yRI/n0Rd +KOXJ3czO0bd2Fpe3ZoNpkW0pOSDej/Q+58ScuJd0gWCT/Sh1eRk6ZdC0kusOuWoY +bPOPMkG45LPgUMFOnZEsfJP6P5mZIxlbCvSMFC25nPHWlct7 +-----END CERTIFICATE----- diff --git a/netlib/test/data/clientcert/make b/netlib/test/data/clientcert/make new file mode 100755 index 00000000..d1caea81 --- /dev/null +++ b/netlib/test/data/clientcert/make @@ -0,0 +1,8 @@ +#!/bin/sh + +openssl genrsa -out client.key 2048 +openssl req -key client.key -new -out client.req +openssl x509 -req -days 365 -in client.req -signkey client.key -out client.crt -extfile client.cnf -extensions ssl_client +openssl x509 -req -days 1000 -in client.req -CA ~/.mitmproxy/mitmproxy-ca.pem -CAkey ~/.mitmproxy/mitmproxy-ca.pem -set_serial 00001 -out client.crt -extensions ssl_client +cat client.key client.crt > client.pem +openssl x509 -text -noout -in client.pem diff --git a/netlib/test/data/dercert b/netlib/test/data/dercert Binary files differnew file mode 100644 index 00000000..370252af --- /dev/null +++ b/netlib/test/data/dercert diff --git a/netlib/test/data/dhparam.pem b/netlib/test/data/dhparam.pem new file mode 100644 index 00000000..afb41672 --- /dev/null +++ b/netlib/test/data/dhparam.pem @@ -0,0 +1,13 @@ +-----BEGIN DH PARAMETERS----- +MIICCAKCAgEAyT6LzpwVFS3gryIo29J5icvgxCnCebcdSe/NHMkD8dKJf8suFCg3 +O2+dguLakSVif/t6dhImxInJk230HmfC8q93hdcg/j8rLGJYDKu3ik6H//BAHKIv +j5O9yjU3rXCfmVJQic2Nne39sg3CreAepEts2TvYHhVv3TEAzEqCtOuTjgDv0ntJ +Gwpj+BJBRQGG9NvprX1YGJ7WOFBP/hWU7d6tgvE6Xa7T/u9QIKpYHMIkcN/l3ZFB +chZEqVlyrcngtSXCROTPcDOQ6Q8QzhaBJS+Z6rcsd7X+haiQqvoFcmaJ08Ks6LQC +ZIL2EtYJw8V8z7C0igVEBIADZBI6OTbuuhDwRw//zU1uq52Oc48CIZlGxTYG/Evq +o9EWAXUYVzWkDSTeBH1r4z/qLPE2cnhtMxbFxuvK53jGB0emy2y1Ei6IhKshJ5qX +IB/aE7SSHyQ3MDHHkCmQJCsOd4Mo26YX61NZ+n501XjqpCBQ2+DfZCBh8Va2wDyv +A2Ryg9SUz8j0AXViRNMJgJrr446yro/FuJZwnQcO3WQnXeqSBnURqKjmqkeFP+d8 +6mk2tqJaY507lRNqtGlLnj7f5RNoBFJDCLBNurVgfvq9TCVWKDIFD4vZRjCrnl6I +rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI= +-----END DH PARAMETERS----- diff --git a/netlib/test/data/htpasswd b/netlib/test/data/htpasswd new file mode 100644 index 00000000..54c95b8c --- /dev/null +++ b/netlib/test/data/htpasswd @@ -0,0 +1 @@ +test:$apr1$/LkYxy3x$WI4.YbiJlu537jLGEW2eu1 diff --git a/netlib/test/data/server.crt b/netlib/test/data/server.crt new file mode 100644 index 00000000..68f61bac --- /dev/null +++ b/netlib/test/data/server.crt @@ -0,0 +1,14 @@ +-----BEGIN CERTIFICATE----- +MIICOzCCAaQCCQDC7f5GsEpo9jANBgkqhkiG9w0BAQUFADBiMQswCQYDVQQGEwJO +WjEOMAwGA1UECBMFT3RhZ28xEDAOBgNVBAcTB0R1bmVkaW4xDzANBgNVBAoTBm5l +dGxpYjEPMA0GA1UECxMGbmV0bGliMQ8wDQYDVQQDEwZuZXRsaWIwHhcNMTIwNjI0 +MjI0MTU0WhcNMjIwNjIyMjI0MTU0WjBiMQswCQYDVQQGEwJOWjEOMAwGA1UECBMF +T3RhZ28xEDAOBgNVBAcTB0R1bmVkaW4xDzANBgNVBAoTBm5ldGxpYjEPMA0GA1UE +CxMGbmV0bGliMQ8wDQYDVQQDEwZuZXRsaWIwgZ8wDQYJKoZIhvcNAQEBBQADgY0A +MIGJAoGBALJSVEl9y3QUSYuXTH0UjBOPQgS0nHmNWej9hjqnA0KWvEnGY+c6yQeP +/rmwswlKw1iVV5o8kRK9Wej88YWQl/hl/xruyeJgGic0+yqY/FcueZxRudwBcWu2 +7+46aEftwLLRF0GwHZxX/HwWME+TcCXGpXGSG2qs921M4iVeBn5hAgMBAAEwDQYJ 
+KoZIhvcNAQEFBQADgYEAODZCihEv2yr8zmmQZDrfqg2ChxAoOXWF5+W2F/0LAUBf +2bHP+K4XE6BJWmadX1xKngj7SWrhmmTDp1gBAvXURoDaScOkB1iOCOHoIyalscTR +0FvSHKqFF8fgSlfqS6eYaSbXU3zQolvwP+URzIVnGDqgQCWPtjMqLD3Kd5tuwos= +-----END CERTIFICATE----- diff --git a/netlib/test/data/server.key b/netlib/test/data/server.key new file mode 100644 index 00000000..b1b658ab --- /dev/null +++ b/netlib/test/data/server.key @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQCyUlRJfct0FEmLl0x9FIwTj0IEtJx5jVno/YY6pwNClrxJxmPn +OskHj/65sLMJSsNYlVeaPJESvVno/PGFkJf4Zf8a7sniYBonNPsqmPxXLnmcUbnc +AXFrtu/uOmhH7cCy0RdBsB2cV/x8FjBPk3AlxqVxkhtqrPdtTOIlXgZ+YQIDAQAB +AoGAQEpGcSiVTYhy64zk2sOprPOdTa0ALSK1I7cjycmk90D5KXAJXLho+f0ETVZT +dioqO6m8J7NmamcyHznyqcDzyNRqD2hEBDGVRJWmpOjIER/JwWLNNbpeVjsMHV8I +40P5rZMOhBPYlwECSC5NtMwaN472fyGNNze8u37IZKiER/ECQQDe1iY5AG3CgkP3 +tEZB3Vtzcn4PoOr3Utyn1YER34lPqAmeAsWUhmAVEfR3N1HDe1VFD9s2BidhBn1a +/Bgqxz4DAkEAzNw0m+uO0WkD7aEYRBW7SbXCX+3xsbVToIWC1jXFG+XDzSWn++c1 +DMXEElzEJxPDA+FzQUvRTml4P92bTAbGywJAS9H7wWtm7Ubbj33UZfbGdhqfz/uF +109naufXedhgZS0c0JnK1oV+Tc0FLEczV9swIUaK5O/lGDtYDcw3AN84NwJBAIw5 +/1jrOOtm8uVp6+5O4dBmthJsEZEPCZtLSG/Qhoe+EvUN3Zq0fL+tb7USAsKs6ERz +wizj9PWzhDhTPMYhrVkCQGIponZHx6VqiFyLgYUH9+gDTjBhYyI+6yMTYzcRweyL +9Suc2NkS3X2Lp+wCjvVZdwGtStp6Vo8z02b3giIsAIY= +-----END RSA PRIVATE KEY----- diff --git a/netlib/test/data/text_cert b/netlib/test/data/text_cert new file mode 100644 index 00000000..36ca33b9 --- /dev/null +++ b/netlib/test/data/text_cert @@ -0,0 +1,145 @@ +-----BEGIN CERTIFICATE----- +MIIadTCCGd6gAwIBAgIGR09PUAFtMA0GCSqGSIb3DQEBBQUAMEYxCzAJBgNVBAYT +AlVTMRMwEQYDVQQKEwpHb29nbGUgSW5jMSIwIAYDVQQDExlHb29nbGUgSW50ZXJu +ZXQgQXV0aG9yaXR5MB4XDTEyMDExNzEyNTUwNFoXDTEzMDExNzEyNTUwNFowTDEL +MAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExEzARBgNVBAoTCkdvb2ds +ZSBJbmMxEzARBgNVBAMTCmdvb2dsZS5jb20wgZ8wDQYJKoZIhvcNAQEBBQADgY0A +MIGJAoGBALofcxR2fud5cyFIeld9pj2vGB5GH0y9tmAYa5t33xbJguKKX/el3tXA +KMNiT1SZzu8ELJ1Ey0GcBAgHA9jVPQd0LGdbEtNIxjblAsWAD/FZlSt8X87h7C5w +2JSefOani0qgQqU6sTdsaCUGZ+Eu7D0lBfT5/Vnl2vV+zI3YmDlpAgMBAAGjghhm +MIIYYjAdBgNVHQ4EFgQUL3+JeC/oL9jZhTp3F550LautzV8wHwYDVR0jBBgwFoAU +v8Aw6/VDET5nup6R+/xq2uNrEiQwWwYDVR0fBFQwUjBQoE6gTIZKaHR0cDovL3d3 +dy5nc3RhdGljLmNvbS9Hb29nbGVJbnRlcm5ldEF1dGhvcml0eS9Hb29nbGVJbnRl +cm5ldEF1dGhvcml0eS5jcmwwZgYIKwYBBQUHAQEEWjBYMFYGCCsGAQUFBzAChkpo +dHRwOi8vd3d3LmdzdGF0aWMuY29tL0dvb2dsZUludGVybmV0QXV0aG9yaXR5L0dv +b2dsZUludGVybmV0QXV0aG9yaXR5LmNydDCCF1kGA1UdEQSCF1AwghdMggpnb29n +bGUuY29tggwqLmdvb2dsZS5jb22CCyouZ29vZ2xlLmFjggsqLmdvb2dsZS5hZIIL +Ki5nb29nbGUuYWWCCyouZ29vZ2xlLmFmggsqLmdvb2dsZS5hZ4ILKi5nb29nbGUu +YW2CCyouZ29vZ2xlLmFzggsqLmdvb2dsZS5hdIILKi5nb29nbGUuYXqCCyouZ29v +Z2xlLmJhggsqLmdvb2dsZS5iZYILKi5nb29nbGUuYmaCCyouZ29vZ2xlLmJnggsq +Lmdvb2dsZS5iaYILKi5nb29nbGUuYmqCCyouZ29vZ2xlLmJzggsqLmdvb2dsZS5i +eYILKi5nb29nbGUuY2GCDCouZ29vZ2xlLmNhdIILKi5nb29nbGUuY2OCCyouZ29v +Z2xlLmNkggsqLmdvb2dsZS5jZoILKi5nb29nbGUuY2eCCyouZ29vZ2xlLmNoggsq +Lmdvb2dsZS5jaYILKi5nb29nbGUuY2yCCyouZ29vZ2xlLmNtggsqLmdvb2dsZS5j +boIOKi5nb29nbGUuY28uYW+CDiouZ29vZ2xlLmNvLmJ3gg4qLmdvb2dsZS5jby5j +a4IOKi5nb29nbGUuY28uY3KCDiouZ29vZ2xlLmNvLmh1gg4qLmdvb2dsZS5jby5p +ZIIOKi5nb29nbGUuY28uaWyCDiouZ29vZ2xlLmNvLmltgg4qLmdvb2dsZS5jby5p +boIOKi5nb29nbGUuY28uamWCDiouZ29vZ2xlLmNvLmpwgg4qLmdvb2dsZS5jby5r +ZYIOKi5nb29nbGUuY28ua3KCDiouZ29vZ2xlLmNvLmxzgg4qLmdvb2dsZS5jby5t +YYIOKi5nb29nbGUuY28ubXqCDiouZ29vZ2xlLmNvLm56gg4qLmdvb2dsZS5jby50 +aIIOKi5nb29nbGUuY28udHqCDiouZ29vZ2xlLmNvLnVngg4qLmdvb2dsZS5jby51 +a4IOKi5nb29nbGUuY28udXqCDiouZ29vZ2xlLmNvLnZlgg4qLmdvb2dsZS5jby52 
+aYIOKi5nb29nbGUuY28uemGCDiouZ29vZ2xlLmNvLnptgg4qLmdvb2dsZS5jby56 +d4IPKi5nb29nbGUuY29tLmFmgg8qLmdvb2dsZS5jb20uYWeCDyouZ29vZ2xlLmNv +bS5haYIPKi5nb29nbGUuY29tLmFygg8qLmdvb2dsZS5jb20uYXWCDyouZ29vZ2xl +LmNvbS5iZIIPKi5nb29nbGUuY29tLmJogg8qLmdvb2dsZS5jb20uYm6CDyouZ29v +Z2xlLmNvbS5ib4IPKi5nb29nbGUuY29tLmJygg8qLmdvb2dsZS5jb20uYnmCDyou +Z29vZ2xlLmNvbS5ieoIPKi5nb29nbGUuY29tLmNugg8qLmdvb2dsZS5jb20uY2+C +DyouZ29vZ2xlLmNvbS5jdYIPKi5nb29nbGUuY29tLmN5gg8qLmdvb2dsZS5jb20u +ZG+CDyouZ29vZ2xlLmNvbS5lY4IPKi5nb29nbGUuY29tLmVngg8qLmdvb2dsZS5j +b20uZXSCDyouZ29vZ2xlLmNvbS5maoIPKi5nb29nbGUuY29tLmdlgg8qLmdvb2ds +ZS5jb20uZ2iCDyouZ29vZ2xlLmNvbS5naYIPKi5nb29nbGUuY29tLmdygg8qLmdv +b2dsZS5jb20uZ3SCDyouZ29vZ2xlLmNvbS5oa4IPKi5nb29nbGUuY29tLmlxgg8q +Lmdvb2dsZS5jb20uam2CDyouZ29vZ2xlLmNvbS5qb4IPKi5nb29nbGUuY29tLmto +gg8qLmdvb2dsZS5jb20ua3eCDyouZ29vZ2xlLmNvbS5sYoIPKi5nb29nbGUuY29t +Lmx5gg8qLmdvb2dsZS5jb20ubXSCDyouZ29vZ2xlLmNvbS5teIIPKi5nb29nbGUu +Y29tLm15gg8qLmdvb2dsZS5jb20ubmGCDyouZ29vZ2xlLmNvbS5uZoIPKi5nb29n +bGUuY29tLm5ngg8qLmdvb2dsZS5jb20ubmmCDyouZ29vZ2xlLmNvbS5ucIIPKi5n +b29nbGUuY29tLm5ygg8qLmdvb2dsZS5jb20ub22CDyouZ29vZ2xlLmNvbS5wYYIP +Ki5nb29nbGUuY29tLnBlgg8qLmdvb2dsZS5jb20ucGiCDyouZ29vZ2xlLmNvbS5w +a4IPKi5nb29nbGUuY29tLnBsgg8qLmdvb2dsZS5jb20ucHKCDyouZ29vZ2xlLmNv +bS5weYIPKi5nb29nbGUuY29tLnFhgg8qLmdvb2dsZS5jb20ucnWCDyouZ29vZ2xl +LmNvbS5zYYIPKi5nb29nbGUuY29tLnNigg8qLmdvb2dsZS5jb20uc2eCDyouZ29v +Z2xlLmNvbS5zbIIPKi5nb29nbGUuY29tLnN2gg8qLmdvb2dsZS5jb20udGqCDyou +Z29vZ2xlLmNvbS50boIPKi5nb29nbGUuY29tLnRygg8qLmdvb2dsZS5jb20udHeC +DyouZ29vZ2xlLmNvbS51YYIPKi5nb29nbGUuY29tLnV5gg8qLmdvb2dsZS5jb20u +dmOCDyouZ29vZ2xlLmNvbS52ZYIPKi5nb29nbGUuY29tLnZuggsqLmdvb2dsZS5j +doILKi5nb29nbGUuY3qCCyouZ29vZ2xlLmRlggsqLmdvb2dsZS5kaoILKi5nb29n +bGUuZGuCCyouZ29vZ2xlLmRtggsqLmdvb2dsZS5keoILKi5nb29nbGUuZWWCCyou +Z29vZ2xlLmVzggsqLmdvb2dsZS5maYILKi5nb29nbGUuZm2CCyouZ29vZ2xlLmZy +ggsqLmdvb2dsZS5nYYILKi5nb29nbGUuZ2WCCyouZ29vZ2xlLmdnggsqLmdvb2ds +ZS5nbIILKi5nb29nbGUuZ22CCyouZ29vZ2xlLmdwggsqLmdvb2dsZS5ncoILKi5n +b29nbGUuZ3mCCyouZ29vZ2xlLmhrggsqLmdvb2dsZS5oboILKi5nb29nbGUuaHKC +CyouZ29vZ2xlLmh0ggsqLmdvb2dsZS5odYILKi5nb29nbGUuaWWCCyouZ29vZ2xl +Lmltgg0qLmdvb2dsZS5pbmZvggsqLmdvb2dsZS5pcYILKi5nb29nbGUuaXOCCyou +Z29vZ2xlLml0gg4qLmdvb2dsZS5pdC5hb4ILKi5nb29nbGUuamWCCyouZ29vZ2xl +Lmpvgg0qLmdvb2dsZS5qb2JzggsqLmdvb2dsZS5qcIILKi5nb29nbGUua2eCCyou +Z29vZ2xlLmtpggsqLmdvb2dsZS5reoILKi5nb29nbGUubGGCCyouZ29vZ2xlLmxp +ggsqLmdvb2dsZS5sa4ILKi5nb29nbGUubHSCCyouZ29vZ2xlLmx1ggsqLmdvb2ds +ZS5sdoILKi5nb29nbGUubWSCCyouZ29vZ2xlLm1lggsqLmdvb2dsZS5tZ4ILKi5n +b29nbGUubWuCCyouZ29vZ2xlLm1sggsqLmdvb2dsZS5tboILKi5nb29nbGUubXOC +CyouZ29vZ2xlLm11ggsqLmdvb2dsZS5tdoILKi5nb29nbGUubXeCCyouZ29vZ2xl +Lm5lgg4qLmdvb2dsZS5uZS5qcIIMKi5nb29nbGUubmV0ggsqLmdvb2dsZS5ubIIL +Ki5nb29nbGUubm+CCyouZ29vZ2xlLm5yggsqLmdvb2dsZS5udYIPKi5nb29nbGUu +b2ZmLmFpggsqLmdvb2dsZS5wa4ILKi5nb29nbGUucGyCCyouZ29vZ2xlLnBuggsq +Lmdvb2dsZS5wc4ILKi5nb29nbGUucHSCCyouZ29vZ2xlLnJvggsqLmdvb2dsZS5y +c4ILKi5nb29nbGUucnWCCyouZ29vZ2xlLnJ3ggsqLmdvb2dsZS5zY4ILKi5nb29n +bGUuc2WCCyouZ29vZ2xlLnNoggsqLmdvb2dsZS5zaYILKi5nb29nbGUuc2uCCyou +Z29vZ2xlLnNtggsqLmdvb2dsZS5zboILKi5nb29nbGUuc2+CCyouZ29vZ2xlLnN0 +ggsqLmdvb2dsZS50ZIILKi5nb29nbGUudGeCCyouZ29vZ2xlLnRrggsqLmdvb2ds +ZS50bIILKi5nb29nbGUudG2CCyouZ29vZ2xlLnRuggsqLmdvb2dsZS50b4ILKi5n +b29nbGUudHCCCyouZ29vZ2xlLnR0ggsqLmdvb2dsZS51c4ILKi5nb29nbGUudXqC +CyouZ29vZ2xlLnZnggsqLmdvb2dsZS52dYILKi5nb29nbGUud3OCCWdvb2dsZS5h +Y4IJZ29vZ2xlLmFkgglnb29nbGUuYWWCCWdvb2dsZS5hZoIJZ29vZ2xlLmFnggln +b29nbGUuYW2CCWdvb2dsZS5hc4IJZ29vZ2xlLmF0gglnb29nbGUuYXqCCWdvb2ds 
+ZS5iYYIJZ29vZ2xlLmJlgglnb29nbGUuYmaCCWdvb2dsZS5iZ4IJZ29vZ2xlLmJp +gglnb29nbGUuYmqCCWdvb2dsZS5ic4IJZ29vZ2xlLmJ5gglnb29nbGUuY2GCCmdv +b2dsZS5jYXSCCWdvb2dsZS5jY4IJZ29vZ2xlLmNkgglnb29nbGUuY2aCCWdvb2ds +ZS5jZ4IJZ29vZ2xlLmNogglnb29nbGUuY2mCCWdvb2dsZS5jbIIJZ29vZ2xlLmNt +gglnb29nbGUuY26CDGdvb2dsZS5jby5hb4IMZ29vZ2xlLmNvLmJ3ggxnb29nbGUu +Y28uY2uCDGdvb2dsZS5jby5jcoIMZ29vZ2xlLmNvLmh1ggxnb29nbGUuY28uaWSC +DGdvb2dsZS5jby5pbIIMZ29vZ2xlLmNvLmltggxnb29nbGUuY28uaW6CDGdvb2ds +ZS5jby5qZYIMZ29vZ2xlLmNvLmpwggxnb29nbGUuY28ua2WCDGdvb2dsZS5jby5r +coIMZ29vZ2xlLmNvLmxzggxnb29nbGUuY28ubWGCDGdvb2dsZS5jby5teoIMZ29v +Z2xlLmNvLm56ggxnb29nbGUuY28udGiCDGdvb2dsZS5jby50eoIMZ29vZ2xlLmNv +LnVnggxnb29nbGUuY28udWuCDGdvb2dsZS5jby51eoIMZ29vZ2xlLmNvLnZlggxn +b29nbGUuY28udmmCDGdvb2dsZS5jby56YYIMZ29vZ2xlLmNvLnptggxnb29nbGUu +Y28ueneCDWdvb2dsZS5jb20uYWaCDWdvb2dsZS5jb20uYWeCDWdvb2dsZS5jb20u +YWmCDWdvb2dsZS5jb20uYXKCDWdvb2dsZS5jb20uYXWCDWdvb2dsZS5jb20uYmSC +DWdvb2dsZS5jb20uYmiCDWdvb2dsZS5jb20uYm6CDWdvb2dsZS5jb20uYm+CDWdv +b2dsZS5jb20uYnKCDWdvb2dsZS5jb20uYnmCDWdvb2dsZS5jb20uYnqCDWdvb2ds +ZS5jb20uY26CDWdvb2dsZS5jb20uY2+CDWdvb2dsZS5jb20uY3WCDWdvb2dsZS5j +b20uY3mCDWdvb2dsZS5jb20uZG+CDWdvb2dsZS5jb20uZWOCDWdvb2dsZS5jb20u +ZWeCDWdvb2dsZS5jb20uZXSCDWdvb2dsZS5jb20uZmqCDWdvb2dsZS5jb20uZ2WC +DWdvb2dsZS5jb20uZ2iCDWdvb2dsZS5jb20uZ2mCDWdvb2dsZS5jb20uZ3KCDWdv +b2dsZS5jb20uZ3SCDWdvb2dsZS5jb20uaGuCDWdvb2dsZS5jb20uaXGCDWdvb2ds +ZS5jb20uam2CDWdvb2dsZS5jb20uam+CDWdvb2dsZS5jb20ua2iCDWdvb2dsZS5j +b20ua3eCDWdvb2dsZS5jb20ubGKCDWdvb2dsZS5jb20ubHmCDWdvb2dsZS5jb20u +bXSCDWdvb2dsZS5jb20ubXiCDWdvb2dsZS5jb20ubXmCDWdvb2dsZS5jb20ubmGC +DWdvb2dsZS5jb20ubmaCDWdvb2dsZS5jb20ubmeCDWdvb2dsZS5jb20ubmmCDWdv +b2dsZS5jb20ubnCCDWdvb2dsZS5jb20ubnKCDWdvb2dsZS5jb20ub22CDWdvb2ds +ZS5jb20ucGGCDWdvb2dsZS5jb20ucGWCDWdvb2dsZS5jb20ucGiCDWdvb2dsZS5j +b20ucGuCDWdvb2dsZS5jb20ucGyCDWdvb2dsZS5jb20ucHKCDWdvb2dsZS5jb20u +cHmCDWdvb2dsZS5jb20ucWGCDWdvb2dsZS5jb20ucnWCDWdvb2dsZS5jb20uc2GC +DWdvb2dsZS5jb20uc2KCDWdvb2dsZS5jb20uc2eCDWdvb2dsZS5jb20uc2yCDWdv +b2dsZS5jb20uc3aCDWdvb2dsZS5jb20udGqCDWdvb2dsZS5jb20udG6CDWdvb2ds +ZS5jb20udHKCDWdvb2dsZS5jb20udHeCDWdvb2dsZS5jb20udWGCDWdvb2dsZS5j +b20udXmCDWdvb2dsZS5jb20udmOCDWdvb2dsZS5jb20udmWCDWdvb2dsZS5jb20u +dm6CCWdvb2dsZS5jdoIJZ29vZ2xlLmN6gglnb29nbGUuZGWCCWdvb2dsZS5kaoIJ +Z29vZ2xlLmRrgglnb29nbGUuZG2CCWdvb2dsZS5keoIJZ29vZ2xlLmVlgglnb29n +bGUuZXOCCWdvb2dsZS5maYIJZ29vZ2xlLmZtgglnb29nbGUuZnKCCWdvb2dsZS5n +YYIJZ29vZ2xlLmdlgglnb29nbGUuZ2eCCWdvb2dsZS5nbIIJZ29vZ2xlLmdtggln +b29nbGUuZ3CCCWdvb2dsZS5ncoIJZ29vZ2xlLmd5gglnb29nbGUuaGuCCWdvb2ds +ZS5oboIJZ29vZ2xlLmhygglnb29nbGUuaHSCCWdvb2dsZS5odYIJZ29vZ2xlLmll +gglnb29nbGUuaW2CC2dvb2dsZS5pbmZvgglnb29nbGUuaXGCCWdvb2dsZS5pc4IJ +Z29vZ2xlLml0ggxnb29nbGUuaXQuYW+CCWdvb2dsZS5qZYIJZ29vZ2xlLmpvggtn +b29nbGUuam9ic4IJZ29vZ2xlLmpwgglnb29nbGUua2eCCWdvb2dsZS5raYIJZ29v +Z2xlLmt6gglnb29nbGUubGGCCWdvb2dsZS5saYIJZ29vZ2xlLmxrgglnb29nbGUu +bHSCCWdvb2dsZS5sdYIJZ29vZ2xlLmx2gglnb29nbGUubWSCCWdvb2dsZS5tZYIJ +Z29vZ2xlLm1ngglnb29nbGUubWuCCWdvb2dsZS5tbIIJZ29vZ2xlLm1ugglnb29n +bGUubXOCCWdvb2dsZS5tdYIJZ29vZ2xlLm12gglnb29nbGUubXeCCWdvb2dsZS5u +ZYIMZ29vZ2xlLm5lLmpwggpnb29nbGUubmV0gglnb29nbGUubmyCCWdvb2dsZS5u +b4IJZ29vZ2xlLm5ygglnb29nbGUubnWCDWdvb2dsZS5vZmYuYWmCCWdvb2dsZS5w +a4IJZ29vZ2xlLnBsgglnb29nbGUucG6CCWdvb2dsZS5wc4IJZ29vZ2xlLnB0ggln +b29nbGUucm+CCWdvb2dsZS5yc4IJZ29vZ2xlLnJ1gglnb29nbGUucneCCWdvb2ds +ZS5zY4IJZ29vZ2xlLnNlgglnb29nbGUuc2iCCWdvb2dsZS5zaYIJZ29vZ2xlLnNr +gglnb29nbGUuc22CCWdvb2dsZS5zboIJZ29vZ2xlLnNvgglnb29nbGUuc3SCCWdv +b2dsZS50ZIIJZ29vZ2xlLnRngglnb29nbGUudGuCCWdvb2dsZS50bIIJZ29vZ2xl 
+LnRtgglnb29nbGUudG6CCWdvb2dsZS50b4IJZ29vZ2xlLnRwgglnb29nbGUudHSC +CWdvb2dsZS51c4IJZ29vZ2xlLnV6gglnb29nbGUudmeCCWdvb2dsZS52dYIJZ29v +Z2xlLndzMA0GCSqGSIb3DQEBBQUAA4GBAJmZ9RyqpUzrP0UcJnHXoLu/AjIEsIvZ +Y9hq/9bLry8InfmvERYHr4hNetkOYlW0FeDZtCpWxdPUgJjmWgKAK6j0goOFavTV +GptkL8gha4p1QUsdLkd36/cvBXeBYSle787veo46N1k4V6Uv2gaDVkre786CNsHv +Q6MYZ5ClQ+kS +-----END CERTIFICATE----- + diff --git a/netlib/test/data/text_cert_2 b/netlib/test/data/text_cert_2 new file mode 100644 index 00000000..ffe8faae --- /dev/null +++ b/netlib/test/data/text_cert_2 @@ -0,0 +1,39 @@ +-----BEGIN CERTIFICATE----- +MIIGujCCBaKgAwIBAgIDAQlEMA0GCSqGSIb3DQEBBQUAMIGMMQswCQYDVQQGEwJJ +TDEWMBQGA1UEChMNU3RhcnRDb20gTHRkLjErMCkGA1UECxMiU2VjdXJlIERpZ2l0 +YWwgQ2VydGlmaWNhdGUgU2lnbmluZzE4MDYGA1UEAxMvU3RhcnRDb20gQ2xhc3Mg +MSBQcmltYXJ5IEludGVybWVkaWF0ZSBTZXJ2ZXIgQ0EwHhcNMTAwMTExMTkyNzM2 +WhcNMTEwMTEyMDkxNDU1WjCBtDEgMB4GA1UEDRMXMTI2ODMyLU1DeExzWTZUbjFn +bTdvOTAxCzAJBgNVBAYTAk5aMR4wHAYDVQQKExVQZXJzb25hIE5vdCBWYWxpZGF0 +ZWQxKTAnBgNVBAsTIFN0YXJ0Q29tIEZyZWUgQ2VydGlmaWNhdGUgTWVtYmVyMRgw +FgYDVQQDEw93d3cuaW5vZGUuY28ubnoxHjAcBgkqhkiG9w0BCQEWD2ppbUBpbm9k +ZS5jby5uejCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAL6ghWlGhqg+ +V0P58R3SvLRiO9OrdekDxzmQbKwQcc05frnF5Z9vT6ga7YOuXVeXxhYCAo0nr6KI ++y/Lx+QHvP5W0nKbs+svzUQErq2ZZFwhh1e1LbVccrNwkHUzKOq0TTaVdU4k8kDQ +zzYF9tTZb+G5Hv1BJjpwYwe8P4cAiPJPrFFOKTySzHqiYsXlx+vR1l1e3zKavhd+ +LVSoLWWXb13yKODq6vnuiHjUJXl8CfVlBhoGotXU4JR5cbuGoW/8+rkwEdX+YoCv +VCqgdx9IkRFB6uWfN6ocUiFvhA0eknO+ewuVfRLiIaSDB8pNyUWVqu4ngFWtWO1O +YZg0I/32BkcCAwEAAaOCAvkwggL1MAkGA1UdEwQCMAAwCwYDVR0PBAQDAgOoMBMG +A1UdJQQMMAoGCCsGAQUFBwMBMB0GA1UdDgQWBBQfaL2Rj6r8iRlBTgppgE7ZZ5WT +UzAfBgNVHSMEGDAWgBTrQjTQmLCrn/Qbawj3zGQu7w4sRTAnBgNVHREEIDAegg93 +d3cuaW5vZGUuY28ubnqCC2lub2RlLmNvLm56MIIBQgYDVR0gBIIBOTCCATUwggEx +BgsrBgEEAYG1NwECATCCASAwLgYIKwYBBQUHAgEWImh0dHA6Ly93d3cuc3RhcnRz +c2wuY29tL3BvbGljeS5wZGYwNAYIKwYBBQUHAgEWKGh0dHA6Ly93d3cuc3RhcnRz +c2wuY29tL2ludGVybWVkaWF0ZS5wZGYwgbcGCCsGAQUFBwICMIGqMBQWDVN0YXJ0 +Q29tIEx0ZC4wAwIBARqBkUxpbWl0ZWQgTGlhYmlsaXR5LCBzZWUgc2VjdGlvbiAq +TGVnYWwgTGltaXRhdGlvbnMqIG9mIHRoZSBTdGFydENvbSBDZXJ0aWZpY2F0aW9u +IEF1dGhvcml0eSBQb2xpY3kgYXZhaWxhYmxlIGF0IGh0dHA6Ly93d3cuc3RhcnRz +c2wuY29tL3BvbGljeS5wZGYwYQYDVR0fBFowWDAqoCigJoYkaHR0cDovL3d3dy5z +dGFydHNzbC5jb20vY3J0MS1jcmwuY3JsMCqgKKAmhiRodHRwOi8vY3JsLnN0YXJ0 +c3NsLmNvbS9jcnQxLWNybC5jcmwwgY4GCCsGAQUFBwEBBIGBMH8wOQYIKwYBBQUH +MAGGLWh0dHA6Ly9vY3NwLnN0YXJ0c3NsLmNvbS9zdWIvY2xhc3MxL3NlcnZlci9j +YTBCBggrBgEFBQcwAoY2aHR0cDovL3d3dy5zdGFydHNzbC5jb20vY2VydHMvc3Vi +LmNsYXNzMS5zZXJ2ZXIuY2EuY3J0MCMGA1UdEgQcMBqGGGh0dHA6Ly93d3cuc3Rh +cnRzc2wuY29tLzANBgkqhkiG9w0BAQUFAAOCAQEAivWID0KT8q1EzWzy+BecsFry +hQhuLFfAsPkHqpNd9OfkRStGBuJlLX+9DQ9TzjqutdY2buNBuDn71buZK+Y5fmjr +28rAT6+WMd+KnCl5WLT5IOS6Z9s3cec5TFQbmOGlepSS9Q6Ts9KsXOHHQvDkQeDq +OV2UqdgXIAyFm5efSL9JXPXntRausNu2s8F2B2rRJe4jPfnUy2LvY8OW1YvjUA++ +vpdWRdfUbJQp55mRfaYMPRnyUm30lAI27QaxgQPFOqDeZUm5llb5eFG/B3f87uhg ++Y1oEykbEvZrIFN4hithioQ0tb+57FKkkG2sW3uemNiQw2qrEo/GAMb1cI50Rg== +-----END CERTIFICATE----- + diff --git a/netlib/test/data/text_cert_weird1 b/netlib/test/data/text_cert_weird1 new file mode 100644 index 00000000..72b09dcb --- /dev/null +++ b/netlib/test/data/text_cert_weird1 @@ -0,0 +1,31 @@ +-----BEGIN CERTIFICATE----- +MIIFNDCCBBygAwIBAgIEDFJFNzANBgkqhkiG9w0BAQUFADCBjDELMAkGA1UEBhMC +REUxHjAcBgNVBAoTFVVuaXZlcnNpdGFldCBNdWVuc3RlcjE6MDgGA1UEAxMxWmVy +dGlmaXppZXJ1bmdzc3RlbGxlIFVuaXZlcnNpdGFldCBNdWVuc3RlciAtIEcwMjEh +MB8GCSqGSIb3DQEJARYSY2FAdW5pLW11ZW5zdGVyLmRlMB4XDTA4MDUyMDEyNDQy +NFoXDTEzMDUxOTEyNDQyNFowezELMAkGA1UEBhMCREUxHjAcBgNVBAoTFVVuaXZl 
+cnNpdGFldCBNdWVuc3RlcjEuMCwGA1UECxMlWmVudHJ1bSBmdWVyIEluZm9ybWF0 +aW9uc3ZlcmFyYmVpdHVuZzEcMBoGA1UEAxMTd3d3LnVuaS1tdWVuc3Rlci5kZTCC +ASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMM0WlCj0ew+tyZ1GurBOqFn +AlChKk4S1F9oDzvp3FwOON4H8YFET7p9ZnoWtkfXSlGNMjekqy67dFlLt1sLusSo +tjNdaOrDLYmnGEgnYAT0RFBvErzIybJoD/Vu3NXyhes+L94R9mEMCwYXmSvG51H9 +c5CvguXBofMchDLCM/U6AYpwu3sST5orV3S1Rsa9sndj8sKJAcw195PYwl6EiEBb +M36ltDBlTYEUAg3Z+VSzB09J3U4vSvguVkDCz+szZh5RG3xlN9mlNfzhf4lHrNgV +0BRbKypa5Uuf81wbMcMMqTxKq+A9ysObpn9J3pNUym+Tn2oqHzGgvwZYB4tzXqUC +AwEAAaOCAawwggGoMAkGA1UdEwQCMAAwCwYDVR0PBAQDAgTwMBMGA1UdJQQMMAoG +CCsGAQUFBwMBMB0GA1UdDgQWBBQ3RFo8awewUTq5TpOFf3jOCEKihzAfBgNVHSME +GDAWgBS+nlGiyZJ8u2CL5rBoZHdaUhmhADAjBgNVHREEHDAagRh3d3dhZG1pbkB1 +bmktbXVlbnN0ZXIuZGUwewYDVR0fBHQwcjA3oDWgM4YxaHR0cDovL2NkcDEucGNh +LmRmbi5kZS93d3UtY2EvcHViL2NybC9nX2NhY3JsLmNybDA3oDWgM4YxaHR0cDov +L2NkcDIucGNhLmRmbi5kZS93d3UtY2EvcHViL2NybC9nX2NhY3JsLmNybDCBlgYI +KwYBBQUHAQEEgYkwgYYwQQYIKwYBBQUHMAKGNWh0dHA6Ly9jZHAxLnBjYS5kZm4u +ZGUvd3d1LWNhL3B1Yi9jYWNlcnQvZ19jYWNlcnQuY3J0MEEGCCsGAQUFBzAChjVo +dHRwOi8vY2RwMi5wY2EuZGZuLmRlL3d3dS1jYS9wdWIvY2FjZXJ0L2dfY2FjZXJ0 +LmNydDANBgkqhkiG9w0BAQUFAAOCAQEAFfNpagtcKUSDKss7TcqjYn99FQ4FtWjE +pGmzYL2zX2wsdCGoVQlGkieL9slbQVEUAnBuqM1LPzUNNe9kZpOPV3Rdhq4y8vyS +xkx3G1v5aGxfPUe8KM8yKIOHRqYefNronHJM0fw7KyjQ73xgbIEgkW+kNXaMLcrb +EPC36O2Zna8GP9FQxJRLgcfQCcYdRKGVn0EtRSkz2ym5Rbh/hrmJBbbC2yJGGMI0 +Vu5A9piK0EZPekZIUmhMQynD9QcMfWhTEFr7YZfx9ktxKDW4spnu7YrgICfZNcCm +tfxmnEAFt6a47u9P0w9lpY8+Sx9MNFfTePym+HP4TYha9bIBes+XnA== +-----END CERTIFICATE----- + diff --git a/netlib/test/data/verificationcerts/9da13359.0 b/netlib/test/data/verificationcerts/9da13359.0 new file mode 100644 index 00000000..b22e4d20 --- /dev/null +++ b/netlib/test/data/verificationcerts/9da13359.0 @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDXTCCAkWgAwIBAgIJAPAfPQGCV/Z4MA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTUxMTAxMTY0ODAxWhcNMTgwODIxMTY0ODAxWjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEArp8LD34JhKCwcQbwIYQMg4+eCgLVN8fwB7+/qOfJbArPs0djFBN+F7c6 +HGvMr24BKUk5u8pn4dPtNurm/vPC8ovNGmcXz62BQJpcMX2veVdRsF7yNwhNacNJ +Arq+70zNMwYBznx0XUxMF6j6nVFf3AW6SU04ylT4Mp3SY/BUUDAdfl1eRo0mPLNS +8rpsN+8YBw1Q7SCuBRVqpOgVIsL88svgQUSOlzvMZPBpG/cmB3BNKNrltwb5iFEI +1jAV7uSj5IcIuNO/246kfsDVPTFMJIzav/CUoidd5UNw+SoFDlzh8sA7L1Bm7D1/ +3KHYSKswGsSR3kynAl10w/SJKDtn8wIDAQABo1AwTjAdBgNVHQ4EFgQUgOcrtxBX +LxbpnOT65d+vpfyWUkgwHwYDVR0jBBgwFoAUgOcrtxBXLxbpnOT65d+vpfyWUkgw +DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAEE9bFmUCA+6cvESKPoi2 +TGSpV652d0xd2U66LpEXeiWRJFLz8YGgoJCx3QFGBscJDXxrLxrBBBV/tCpEqypo +pYIqsawH7M66jpOr83Us3M8JC2eFBZJocMpXxdytWqHik5VKZNx6VQFT8bS7+yVC +VoUKePhlgcg+pmo41qjqieBNKRMh/1tXS77DI1lgO5wZLVrLXcdqWuDpmaQOKJeq +G/nxytCW/YJA7bFn/8Gjy8DYypJSeeaKu7o3P3+ONJHdIMHb+MdcheDBS9AOFSeo +xI0D5EbO9F873O77l7nbD7B0X34HFN0nGczC4poexIpbDFG3hAPekwZ5KC6VwJLc +1Q== +-----END CERTIFICATE----- diff --git a/netlib/test/data/verificationcerts/generate.py b/netlib/test/data/verificationcerts/generate.py new file mode 100644 index 00000000..9203abbb --- /dev/null +++ b/netlib/test/data/verificationcerts/generate.py @@ -0,0 +1,68 @@ +""" +Generate SSL test certificates. 
+""" +import subprocess +import shlex +import os +import shutil + + +ROOT_CA = "trusted-root" +SUBJECT = "/CN=example.mitmproxy.org/" + + +def do(args): +    print("> %s" % args) +    args = shlex.split(args) +    output = subprocess.check_output(args) +    return output + + +def genrsa(cert): +    do("openssl genrsa -out {cert}.key 2048".format(cert=cert)) + + +def sign(cert): +    do("openssl x509 -req -in {cert}.csr " +       "-CA {root_ca}.crt " +       "-CAkey {root_ca}.key " +       "-CAcreateserial " +       "-days 1024 " +       "-out {cert}.crt".format(root_ca=ROOT_CA, cert=cert) +       ) + + +def mkcert(cert, args): +    genrsa(cert) +    do("openssl req -new -nodes -batch " +       "-key {cert}.key " +       "{args} " +       "-out {cert}.csr".format(cert=cert, args=args) +       ) +    sign(cert) +    os.remove("{cert}.csr".format(cert=cert)) + + +# create trusted root CA +genrsa("trusted-root") +do("openssl req -x509 -new -nodes -batch " +   "-key trusted-root.key " +   "-days 1024 " +   "-out trusted-root.crt" +   ) +h = do("openssl x509 -hash -noout -in trusted-root.crt").decode("ascii").strip() +shutil.copyfile("trusted-root.crt", "{}.0".format(h)) + +# create trusted leaf cert. +mkcert("trusted-leaf", "-subj {}".format(SUBJECT)) + +# create self-signed cert +genrsa("self-signed") +do("openssl req -x509 -new -nodes -batch " +   "-key self-signed.key " +   "-subj {} " +   "-days 1024 " +   "-out self-signed.crt".format(SUBJECT) +   ) + + diff --git a/netlib/test/data/verificationcerts/self-signed.crt b/netlib/test/data/verificationcerts/self-signed.crt new file mode 100644 index 00000000..dce2a7e0 --- /dev/null +++ b/netlib/test/data/verificationcerts/self-signed.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDEzCCAfugAwIBAgIJAJ945xt1FRsfMA0GCSqGSIb3DQEBCwUAMCAxHjAcBgNV +BAMMFWV4YW1wbGUubWl0bXByb3h5Lm9yZzAeFw0xNTExMDExNjQ4MDJaFw0xODA4 +MjExNjQ4MDJaMCAxHjAcBgNVBAMMFWV4YW1wbGUubWl0bXByb3h5Lm9yZzCCASIw +DQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBALFxyzPfjgIghOMMnJlW80yB84xC +nJtko3tuyOdozgTCyha2W+NdIKPNZJtWrzN4P0B5PlozCDwfcSYffLs0WZs8LRWv +BfZX8+oX+14qQjKFsiqgO65cTLP3qlPySYPJQQ37vOP1Y5Yf8nQq2mwQdC18hLtT +QOANG6OFoSplpBLsYF+QeoMgqCTa6hrl/5GLmQoDRTjXkv3Sj379AUDMybuBqccm +q5EIqCrE4+xJ8JywJclAVn2YP14baiFrrYCsYYg4sS1Od6xFj+xtpLe7My3AYjB9 +/aeHd8vDiob0cqOW1TFwhqgJKuErfFyg8lZ2hJmStJKyfofWuY/gl/vnvX0CAwEA +AaNQME4wHQYDVR0OBBYEFB8d32zK8eqZIoKw4jXzYzhw4amPMB8GA1UdIwQYMBaA +FB8d32zK8eqZIoKw4jXzYzhw4amPMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEL +BQADggEBAJmo2oKv1OEjZ0Q4yELO6BAnHAkmBKpW+zmLyQa8idxtLVkI9uXk3iqY +GWugkmcUZCTVFRWv/QXQQSex+00IY3x2rdHbtuZwcyKiz2u8WEmfW1rOIwBaFJ1i +v7+SA2aZs6vepN2sE56X54c/YbwQooaKZtOb+djWXYMJrc/Ezj0J7oQIJTptYV8v +/3216yCHRp/KCL7yTLtiw25xKuXNu/gkcd8wZOY9rS2qMUD897MJF0MvgJoauRBd +d4XEYCNKkrIRmfqrkiRQfAZpvpoutH6NCk7KuQYcI0BlOHlsnHHcs/w72EEqHwFq +x6476tW/t8GJDZVD74+pNBcLifXxArE= +-----END CERTIFICATE----- diff --git a/netlib/test/data/verificationcerts/self-signed.key b/netlib/test/data/verificationcerts/self-signed.key new file mode 100644 index 00000000..71a6ad6a --- /dev/null +++ b/netlib/test/data/verificationcerts/self-signed.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAsXHLM9+OAiCE4wycmVbzTIHzjEKcm2Sje27I52jOBMLKFrZb +410go81km1avM3g/QHk+WjMIPB9xJh98uzRZmzwtFa8F9lfz6hf7XipCMoWyKqA7 +rlxMs/eqU/JJg8lBDfu84/Vjlh/ydCrabBB0LXyEu1NA4A0bo4WhKmWkEuxgX5B6 +gyCoJNrqGuX/kYuZCgNFONeS/dKPfv0BQMzJu4GpxyarkQioKsTj7EnwnLAlyUBW +fZg/XhtqIWutgKxhiDixLU53rEWP7G2kt7szLcBiMH39p4d3y8OKhvRyo5bVMXCG +qAkq4St8XKDyVnaEmZK0krJ+h9a5j+CX++e9fQIDAQABAoIBAQCT+FvGbych2PJX 
+0D2KlXqgE0IAdc/YuYymstSwPLKIP9N8KyfnKtK8Jdw+uYOyfRTp8/EuEJ5OXL3j +V6CRD++lRwIlseVb7y5EySjh9oVrUhgn+aSrGucPsHkGNeZeEmbAfWugARLBrvRl +MRMhyHrJL6wT9jIEZInmy9mA3G99IuFW3rS8UR1Yu7zyvhtjvop1xg/wfEUu24Ty +PvMfnwaDcZHCz2tmu2KJvaxSBAG3FKmAqeMvk1Gt5m2keKgw03M+EX0LrM8ybWqn +VwB8tnSyMBLVFLIXMpIiSfpji10+p9fdKFMRF++D6qVwyoxPiIq+yEJapxXiqLea +mkhtJW91AoGBAOvIb7bZvH4wYvi6txs2pygF3ZMjqg/fycnplrmYMrjeeDeeN4v1 +h/5tkN9TeTkHRaN3L7v49NEUDhDyuopLTNfWpYdv63U/BVzvgMm/guacTYkx9whB +OvQ2YekR/WKg7kuyrTZidTDz+mjU+1b8JaWGjiDc6vFwxZA7uWicaGGHAoGBAMCo +y/2AwFGwCR+5bET1nTTyxok6iKo4k6R/7DJe4Bq8VLifoyX3zDlGG/33KN3xVqBU +xnT9gkii1lfX2U+4iM+GOSPl0nG0hOEqEH+vFHszpHybDeNez3FEyIbgOzg6u7sV +NOy+P94L5EMQVEmWp5g6Vm3k9kr92Bd9UacKQPnbAoGAMN8KyMu41i8RVJze9zUM +0K7mjmkGBuRL3x4br7xsRwVVxbF1sfzig0oSjTewGLH5LTi3HC8uD2gowjqNj7yr +4NEM3lXEaDj305uRBkA70bD0IUvJ+FwM7DGZecXQz3Cr8+TFIlCmGc94R+Jddlot +M3IAY69mw0SsroiylYxV1mECgYAcSGtx8rXJCDO+sYTgdsI2ZLGasbogax/ZlWIC +XwU9R4qUc/MKft8/RTiUxvT76BMUhH2B7Tl0GlunF6vyVR/Yf1biGzoSsTKUr40u +gXBbSdCK7mRSjbecZEGf80keTxkCNPHJE4DiwxImej41c2V1JpNLnMI/bhaMFDyp +bgrt4wKBgHFzZgAgM1v07F038tAkIBGrYLukY1ZFBaZoGZ9xHfy/EmLJM3HCHLO5 +8wszMGhMTe2+39EeChwgj0kFaq1YnDiucU74BC57KR1tD59y7l6UnsQXTm4/32j8 +Or6i8GekBibCb97DzzOU0ZK//fNhHTXpDDXsYt5lJUWSmgW+S9Qp +-----END RSA PRIVATE KEY----- diff --git a/netlib/test/data/verificationcerts/trusted-leaf.crt b/netlib/test/data/verificationcerts/trusted-leaf.crt new file mode 100644 index 00000000..6a92de92 --- /dev/null +++ b/netlib/test/data/verificationcerts/trusted-leaf.crt @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC4TCCAckCCQCj6D9oVylb8jANBgkqhkiG9w0BAQsFADBFMQswCQYDVQQGEwJB +VTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0 +cyBQdHkgTHRkMB4XDTE1MTEwMTE2NDgwMloXDTE4MDgyMTE2NDgwMlowIDEeMBwG +A1UEAwwVZXhhbXBsZS5taXRtcHJveHkub3JnMIIBIjANBgkqhkiG9w0BAQEFAAOC +AQ8AMIIBCgKCAQEAy/L5JYHS7QFhSIsjmd6bJTgs2rdqEn6tsmPBVZKZ7SqCAVjW +hPpEu7Q23akmU6Zm9Fp/vENc3jzxQLlEKhrv7eWmFYSOrCYtbJOz3RQorlwjjfdY +LlNQh1wYUXQX3PN3r3dyYtt5vTtXKc8+aP4M4vX7qlbW+4j4LrQfmPjS0XOdYpu3 +wh+i1ZMIhZye3hpCjwnpjTf7/ff45ZFxtkoi1uzEC/+swr1RSvamY8Foe12Re17Z +5ij8ZB0NIdoSk1tDkY3sJ8iNi35+qartl0UYeG9IUXRwDRrPsEKpF4RxY1+X2bdZ +r6PKb/E4CA5JlMvS5SVmrvxjCVqTQBmTjXfxqwIDAQABMA0GCSqGSIb3DQEBCwUA +A4IBAQBmpSZJrTDvzSlo6P7P7x1LoETzHyVjwgPeqGYw6ndGXeJMN9rhhsFvRsiB +I/aHh58MIlSjti7paikDAoFHB3dBvFHR+JUa/ailWEbcZReWRSE3lV6wFiN3G3lU +OyofR7MKnPW7bv8hSqOLqP1mbupXuQFB5M6vPLRwg5VgiCHI/XBiTvzMamzvNAR3 +UHHZtsJkRqzogYm6K9YJaga7jteSx2nNo+ujLwrxeXsLChTyFMJGnVkp5IyKeNfc +qwlzNncb3y+4KnUdNkPEtuydgAxAfuyXufiFBYRcUWbQ5/9ycgF7131ySaj9f/Y2 +kMsv2jg+soKvwwVYCABsk1KSHtfz +-----END CERTIFICATE----- diff --git a/netlib/test/data/verificationcerts/trusted-leaf.key b/netlib/test/data/verificationcerts/trusted-leaf.key new file mode 100644 index 00000000..783ebf1c --- /dev/null +++ b/netlib/test/data/verificationcerts/trusted-leaf.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAy/L5JYHS7QFhSIsjmd6bJTgs2rdqEn6tsmPBVZKZ7SqCAVjW +hPpEu7Q23akmU6Zm9Fp/vENc3jzxQLlEKhrv7eWmFYSOrCYtbJOz3RQorlwjjfdY +LlNQh1wYUXQX3PN3r3dyYtt5vTtXKc8+aP4M4vX7qlbW+4j4LrQfmPjS0XOdYpu3 +wh+i1ZMIhZye3hpCjwnpjTf7/ff45ZFxtkoi1uzEC/+swr1RSvamY8Foe12Re17Z +5ij8ZB0NIdoSk1tDkY3sJ8iNi35+qartl0UYeG9IUXRwDRrPsEKpF4RxY1+X2bdZ +r6PKb/E4CA5JlMvS5SVmrvxjCVqTQBmTjXfxqwIDAQABAoIBAQC956DWq+wbhA1x +3x1nSUBth8E8Z0z9q7dRRFHhvIBXth0X5ADcEa2umj/8ZmSpv2heX2ZRhugSh+yc +t+YgzrRacFwV7ThsU6A4WdBBK2Q19tWke4xAlpOFdtut/Mu7kXkAidiY9ISHD5o5 +9B/I48ZcD3AnTHUiAogV9OL3LbogDD4HasLt4mWkbq8U2thdjxMIvxdg36olJEuo +iAZrAUCPZEXuU89BtvPLUYioe9n90nzkyneGNS0SHxotlEc9ZYK9VTsivtXJb4wB 
+ptDMCp+TH3tjo8BTGnbnoZEybgyyOEd0UTzxK4DlxnvRVWexFY6NXwPFhIxKlB0Y +Bg8NkAkBAoGBAOiRnmbC5QkqrKrTkLx3fghIHPqgEXPPYgHLSuY3UjTlMb3APXpq +vzQnlCn3QuSse/1fWnQj+9vLVbx1XNgKjzk7dQhn5IUY+mGN4lLmoSnTebxvSQ43 +VAgTYjST9JFmJ3wK4KkWDsEsVao8LAx0h5JEQXUTT5xZpFA2MLztYbgfAoGBAOB/ +MvhLMAwlx8+m/zXMEPLk/KOd2dVZ4q5se8bAT/GiGsi8JUcPnCk140ZZabJqryAp +JFzUHIjfVsS9ejAfocDk1JeIm7Uus4um6fQEKIPMBxI/M/UAwYCXAG9ULXqilbO3 +pTdeeuraVKrTu1Z4ea6x4du1JWKcyDfYfsHepcT1AoGBAM2fskV5G7e3G2MOG3IG +1E/OMpEE5WlXenfLnjVdxDkwS4JRbgnGR7d9JurTyzkTp6ylmfwFtLDoXq15ttTs +wSUBBMCh2tIy+201XV2eu++XIpMQca84C/v352RFTH8hqtdpZqkY74KsCDGzcd6x +SQxxfM5efIzoVPb2crEX0MZRAoGAQ2EqFSfL9flo7UQ8GRN0itJ7mUgJV2WxCZT5 +2X9i/y0eSN1feuKOhjfsTPMNLEWk5kwy48GuBs6xpj8Qa10zGUgVHp4bzdeEgAfK +9DhDSLt1694YZBKkAUpRERj8xXAC6nvWFLZAwjhhbRw7gAqMywgMt/q4i85usYRD +F0ESE/kCgYBbc083PcLmlHbkn/d1i4IcLI6wFk+tZYIEVYDid7xDOgZOBcOTTyYB +BrDzNqbKNexKRt7QHVlwR+VOGMdN5P0hf7oH3SMW23OxBKoQe8pUSGF9a4DjCS1v +vCXMekifb9kIhhUWaG71L8+MaOzNBVAmk1+3NzPZgV/YxHjAWWhGHQ== +-----END RSA PRIVATE KEY----- diff --git a/netlib/test/data/verificationcerts/trusted-root.crt b/netlib/test/data/verificationcerts/trusted-root.crt new file mode 100644 index 00000000..b22e4d20 --- /dev/null +++ b/netlib/test/data/verificationcerts/trusted-root.crt @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDXTCCAkWgAwIBAgIJAPAfPQGCV/Z4MA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTUxMTAxMTY0ODAxWhcNMTgwODIxMTY0ODAxWjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEArp8LD34JhKCwcQbwIYQMg4+eCgLVN8fwB7+/qOfJbArPs0djFBN+F7c6 +HGvMr24BKUk5u8pn4dPtNurm/vPC8ovNGmcXz62BQJpcMX2veVdRsF7yNwhNacNJ +Arq+70zNMwYBznx0XUxMF6j6nVFf3AW6SU04ylT4Mp3SY/BUUDAdfl1eRo0mPLNS +8rpsN+8YBw1Q7SCuBRVqpOgVIsL88svgQUSOlzvMZPBpG/cmB3BNKNrltwb5iFEI +1jAV7uSj5IcIuNO/246kfsDVPTFMJIzav/CUoidd5UNw+SoFDlzh8sA7L1Bm7D1/ +3KHYSKswGsSR3kynAl10w/SJKDtn8wIDAQABo1AwTjAdBgNVHQ4EFgQUgOcrtxBX +LxbpnOT65d+vpfyWUkgwHwYDVR0jBBgwFoAUgOcrtxBXLxbpnOT65d+vpfyWUkgw +DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAEE9bFmUCA+6cvESKPoi2 +TGSpV652d0xd2U66LpEXeiWRJFLz8YGgoJCx3QFGBscJDXxrLxrBBBV/tCpEqypo +pYIqsawH7M66jpOr83Us3M8JC2eFBZJocMpXxdytWqHik5VKZNx6VQFT8bS7+yVC +VoUKePhlgcg+pmo41qjqieBNKRMh/1tXS77DI1lgO5wZLVrLXcdqWuDpmaQOKJeq +G/nxytCW/YJA7bFn/8Gjy8DYypJSeeaKu7o3P3+ONJHdIMHb+MdcheDBS9AOFSeo +xI0D5EbO9F873O77l7nbD7B0X34HFN0nGczC4poexIpbDFG3hAPekwZ5KC6VwJLc +1Q== +-----END CERTIFICATE----- diff --git a/netlib/test/data/verificationcerts/trusted-root.key b/netlib/test/data/verificationcerts/trusted-root.key new file mode 100644 index 00000000..05483f77 --- /dev/null +++ b/netlib/test/data/verificationcerts/trusted-root.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEArp8LD34JhKCwcQbwIYQMg4+eCgLVN8fwB7+/qOfJbArPs0dj +FBN+F7c6HGvMr24BKUk5u8pn4dPtNurm/vPC8ovNGmcXz62BQJpcMX2veVdRsF7y +NwhNacNJArq+70zNMwYBznx0XUxMF6j6nVFf3AW6SU04ylT4Mp3SY/BUUDAdfl1e +Ro0mPLNS8rpsN+8YBw1Q7SCuBRVqpOgVIsL88svgQUSOlzvMZPBpG/cmB3BNKNrl +twb5iFEI1jAV7uSj5IcIuNO/246kfsDVPTFMJIzav/CUoidd5UNw+SoFDlzh8sA7 +L1Bm7D1/3KHYSKswGsSR3kynAl10w/SJKDtn8wIDAQABAoIBAFgMzjDzpqz/sbhs +fS0JPp4gDtqRbx3/bSMbJvNuXPxjvzNxLZ5z7cLbmyu1l7Jlz6QXzkrI1vTiPdzR +OcUY+RYANF252iHYJTKEIzS5YX/X7dL3LT9eqlpIJEqCC8Dygw3VW5fY3Xwl+sB7 +blNhMuro4HQRwi8UBUrQlcPa7Ui5BBi323Q6en+VjYctkqpJHzNKPSqPTbsdLaK+ +B0XuXxFatM09rmeRKZCL71Lk1T8N/l0hqEzej7zxgVD7vG/x1kMFN4T3yCmXCbPa +izGHYr1EBHglm4qMNWveXCZiVJ+wmwCjdjqvggyHiZFXE2N0OCrWPhxQPdqFf5y7 
+bUO9U2ECgYEA6GM1UzRnbVpjb20ezFy7dU7rlWM0nHBfG27M3bcXh4HnPpnvKp0/ +8a1WFi4kkRywrNXx8hFEd43vTbdObLpVXScXRKiY3MHmFk4k4hbWuTpmumCubQZO +AWlX6TE0HRKn1wQahgpQcxcWaDN2xJJmRQ1zVmlnNkT48/4kFgRxyykCgYEAwF08 +ngrF35oYoU/x+KKq2NXGeNUzoZMj568dE1oWW0ZFpqCi+DGT+hAbG3yUOBSaPqy9 +zn1obGo0YRlrayvtebz118kG7a/rzY02VcAPlT/GpEhvkZlXTwEK17zRJc1nJrfP +39QAZWZsaOru9NRIg/8HcdG3JPR2MhRD/De9GbsCgYAaiZnBUq6s8jGAu/lUZRKT +JtwIRzfu1XZG77Q9bXcmZlM99t41A5gVxTGbftF2MMyMMDJc7lPfQzocqd4u1GiD +Jr+le4tZSls4GNxlZS5IIL8ycW/5y0qFJr5/RrsoxsSb7UAKJothWTWZ2Karc/xx +zkNpjsfWjrHPSypbyU4lYQKBgFh1R5/BgnatjO/5LGNSok/uFkOQfxqo6BTtYOh6 +P9efO/5A1lBdtBeE+oIsSphzWO7DTtE6uB9Kw2V3Y/83hw+5RjABoG8Cu+OdMURD +eqb+WeFH8g45Pn31E8Bbcq34g5u5YR0jhz8Z13ZzuojZabNRPmIntxmGVSf4S78a +/plrAoGBANMHNng2lyr03nqnHrOM6NXD+60af0YR/YJ+2d/H40RnXxGJ4DXn7F00 +a4vJFPa97uq+xpd0HE+TE+NIrOdVDXPePD2qzBzMTsctGtj30vLzojMOT+Yf/nvO +WxTL5Q8GruJz2Dn0awSZO2z/3A8S1rmpuVZ/jT5NtRrvOSY6hmxF +-----END RSA PRIVATE KEY----- diff --git a/netlib/test/data/verificationcerts/trusted-root.srl b/netlib/test/data/verificationcerts/trusted-root.srl new file mode 100644 index 00000000..4ad962ba --- /dev/null +++ b/netlib/test/data/verificationcerts/trusted-root.srl @@ -0,0 +1 @@ +A3E83F6857295BF2 diff --git a/netlib/test/http/__init__.py b/netlib/test/http/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/netlib/test/http/__init__.py diff --git a/netlib/test/http/http1/__init__.py b/netlib/test/http/http1/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/netlib/test/http/http1/__init__.py diff --git a/netlib/test/http/http1/test_assemble.py b/netlib/test/http/http1/test_assemble.py new file mode 100644 index 00000000..31a62438 --- /dev/null +++ b/netlib/test/http/http1/test_assemble.py @@ -0,0 +1,102 @@ +from __future__ import absolute_import, print_function, division +from netlib.exceptions import HttpException +from netlib.http import CONTENT_MISSING, Headers +from netlib.http.http1.assemble import ( +    assemble_request, assemble_request_head, assemble_response, +    assemble_response_head, _assemble_request_line, _assemble_request_headers, +    _assemble_response_headers, +    assemble_body) +from netlib.tutils import treq, raises, tresp + + +def test_assemble_request(): +    c = assemble_request(treq()) == ( +        b"GET /path HTTP/1.1\r\n" +        b"header: qvalue\r\n" +        b"Host: address:22\r\n" +        b"Content-Length: 7\r\n" +        b"\r\n" +        b"content" +    ) + +    with raises(HttpException): +        assemble_request(treq(content=CONTENT_MISSING)) + + +def test_assemble_request_head(): +    c = assemble_request_head(treq(content="foo")) +    assert b"GET" in c +    assert b"qvalue" in c +    assert b"content-length" in c +    assert b"foo" not in c + + +def test_assemble_response(): +    c = assemble_response(tresp()) == ( +        b"HTTP/1.1 200 OK\r\n" +        b"header-response: svalue\r\n" +        b"Content-Length: 7\r\n" +        b"\r\n" +        b"message" +    ) + +    with raises(HttpException): +        assemble_response(tresp(content=CONTENT_MISSING)) + + +def test_assemble_response_head(): +    c = assemble_response_head(tresp()) +    assert b"200" in c +    assert b"svalue" in c +    assert b"message" not in c + + +def test_assemble_body(): +    c = list(assemble_body(Headers(), [b"body"])) +    assert c == [b"body"] + +    c = list(assemble_body(Headers(transfer_encoding="chunked"), [b"123456789a", b""])) +    assert c == [b"a\r\n123456789a\r\n", b"0\r\n\r\n"] + +    c = 
list(assemble_body(Headers(transfer_encoding="chunked"), [b"123456789a"])) +    assert c == [b"a\r\n123456789a\r\n", b"0\r\n\r\n"] + + +def test_assemble_request_line(): +    assert _assemble_request_line(treq().data) == b"GET /path HTTP/1.1" + +    authority_request = treq(method=b"CONNECT", first_line_format="authority").data +    assert _assemble_request_line(authority_request) == b"CONNECT address:22 HTTP/1.1" + +    absolute_request = treq(first_line_format="absolute").data +    assert _assemble_request_line(absolute_request) == b"GET http://address:22/path HTTP/1.1" + +    with raises(RuntimeError): +        _assemble_request_line(treq(first_line_format="invalid_form").data) + + +def test_assemble_request_headers(): +    # https://github.com/mitmproxy/mitmproxy/issues/186 +    r = treq(content=b"") +    r.headers["Transfer-Encoding"] = "chunked" +    c = _assemble_request_headers(r.data) +    assert b"Transfer-Encoding" in c + + +def test_assemble_request_headers_host_header(): +    r = treq() +    r.headers = Headers() +    c = _assemble_request_headers(r.data) +    assert b"host" in c + +    r.host = None +    c = _assemble_request_headers(r.data) +    assert b"host" not in c + + +def test_assemble_response_headers(): +    # https://github.com/mitmproxy/mitmproxy/issues/186 +    r = tresp(content=b"") +    r.headers["Transfer-Encoding"] = "chunked" +    c = _assemble_response_headers(r) +    assert b"Transfer-Encoding" in c diff --git a/netlib/test/http/http1/test_read.py b/netlib/test/http/http1/test_read.py new file mode 100644 index 00000000..90234070 --- /dev/null +++ b/netlib/test/http/http1/test_read.py @@ -0,0 +1,333 @@ +from __future__ import absolute_import, print_function, division +from io import BytesIO +import textwrap +from mock import Mock +from netlib.exceptions import HttpException, HttpSyntaxException, HttpReadDisconnect, TcpDisconnect +from netlib.http import Headers +from netlib.http.http1.read import ( +    read_request, read_response, read_request_head, +    read_response_head, read_body, connection_close, expected_http_body_size, _get_first_line, +    _read_request_line, _parse_authority_form, _read_response_line, _check_http_version, +    _read_headers, _read_chunked +) +from netlib.tutils import treq, tresp, raises + + +def test_read_request(): +    rfile = BytesIO(b"GET / HTTP/1.1\r\n\r\nskip") +    r = read_request(rfile) +    assert r.method == "GET" +    assert r.content == b"" +    assert r.timestamp_end +    assert rfile.read() == b"skip" + + +def test_read_request_head(): +    rfile = BytesIO( +        b"GET / HTTP/1.1\r\n" +        b"Content-Length: 4\r\n" +        b"\r\n" +        b"skip" +    ) +    rfile.reset_timestamps = Mock() +    rfile.first_byte_timestamp = 42 +    r = read_request_head(rfile) +    assert r.method == "GET" +    assert r.headers["Content-Length"] == "4" +    assert r.content is None +    assert rfile.reset_timestamps.called +    assert r.timestamp_start == 42 +    assert rfile.read() == b"skip" + + +def test_read_response(): +    req = treq() +    rfile = BytesIO(b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody") +    r = read_response(rfile, req) +    assert r.status_code == 418 +    assert r.content == b"body" +    assert r.timestamp_end + + +def test_read_response_head(): +    rfile = BytesIO( +        b"HTTP/1.1 418 I'm a teapot\r\n" +        b"Content-Length: 4\r\n" +        b"\r\n" +        b"skip" +    ) +    rfile.reset_timestamps = Mock() +    rfile.first_byte_timestamp = 42 +    r = read_response_head(rfile) +    assert 
r.status_code == 418 +    assert r.headers["Content-Length"] == "4" +    assert r.content is None +    assert rfile.reset_timestamps.called +    assert r.timestamp_start == 42 +    assert rfile.read() == b"skip" + + +class TestReadBody(object): +    def test_chunked(self): +        rfile = BytesIO(b"3\r\nfoo\r\n0\r\n\r\nbar") +        body = b"".join(read_body(rfile, None)) +        assert body == b"foo" +        assert rfile.read() == b"bar" + +    def test_known_size(self): +        rfile = BytesIO(b"foobar") +        body = b"".join(read_body(rfile, 3)) +        assert body == b"foo" +        assert rfile.read() == b"bar" + +    def test_known_size_limit(self): +        rfile = BytesIO(b"foobar") +        with raises(HttpException): +            b"".join(read_body(rfile, 3, 2)) + +    def test_known_size_too_short(self): +        rfile = BytesIO(b"foo") +        with raises(HttpException): +            b"".join(read_body(rfile, 6)) + +    def test_unknown_size(self): +        rfile = BytesIO(b"foobar") +        body = b"".join(read_body(rfile, -1)) +        assert body == b"foobar" + +    def test_unknown_size_limit(self): +        rfile = BytesIO(b"foobar") +        with raises(HttpException): +            b"".join(read_body(rfile, -1, 3)) + +    def test_max_chunk_size(self): +        rfile = BytesIO(b"123456") +        assert list(read_body(rfile, -1, max_chunk_size=None)) == [b"123456"] +        rfile = BytesIO(b"123456") +        assert list(read_body(rfile, -1, max_chunk_size=1)) == [b"1", b"2", b"3", b"4", b"5", b"6"] + +def test_connection_close(): +    headers = Headers() +    assert connection_close(b"HTTP/1.0", headers) +    assert not connection_close(b"HTTP/1.1", headers) + +    headers["connection"] = "keep-alive" +    assert not connection_close(b"HTTP/1.1", headers) + +    headers["connection"] = "close" +    assert connection_close(b"HTTP/1.1", headers) + +    headers["connection"] = "foobar" +    assert connection_close(b"HTTP/1.0", headers) +    assert not connection_close(b"HTTP/1.1", headers) + +def test_expected_http_body_size(): +    # Expect: 100-continue +    assert expected_http_body_size( +        treq(headers=Headers(expect="100-continue", content_length="42")) +    ) == 0 + +    # http://tools.ietf.org/html/rfc7230#section-3.3 +    assert expected_http_body_size( +        treq(method=b"HEAD"), +        tresp(headers=Headers(content_length="42")) +    ) == 0 +    assert expected_http_body_size( +        treq(method=b"CONNECT"), +        tresp() +    ) == 0 +    for code in (100, 204, 304): +        assert expected_http_body_size( +            treq(), +            tresp(status_code=code) +        ) == 0 + +    # chunked +    assert expected_http_body_size( +        treq(headers=Headers(transfer_encoding="chunked")), +    ) is None + +    # explicit length +    for val in (b"foo", b"-7"): +        with raises(HttpSyntaxException): +            expected_http_body_size( +                treq(headers=Headers(content_length=val)) +            ) +    assert expected_http_body_size( +        treq(headers=Headers(content_length="42")) +    ) == 42 + +    # no length +    assert expected_http_body_size( +        treq(headers=Headers()) +    ) == 0 +    assert expected_http_body_size( +        treq(headers=Headers()), tresp(headers=Headers()) +    ) == -1 + + +def test_get_first_line(): +    rfile = BytesIO(b"foo\r\nbar") +    assert _get_first_line(rfile) == b"foo" + +    rfile = BytesIO(b"\r\nfoo\r\nbar") +    assert _get_first_line(rfile) == b"foo" + +    with 
raises(HttpReadDisconnect): +        rfile = BytesIO(b"") +        _get_first_line(rfile) + +    with raises(HttpReadDisconnect): +        rfile = Mock() +        rfile.readline.side_effect = TcpDisconnect +        _get_first_line(rfile) + + +def test_read_request_line(): +    def t(b): +        return _read_request_line(BytesIO(b)) + +    assert (t(b"GET / HTTP/1.1") == +            ("relative", b"GET", None, None, None, b"/", b"HTTP/1.1")) +    assert (t(b"OPTIONS * HTTP/1.1") == +            ("relative", b"OPTIONS", None, None, None, b"*", b"HTTP/1.1")) +    assert (t(b"CONNECT foo:42 HTTP/1.1") == +            ("authority", b"CONNECT", None, b"foo", 42, None, b"HTTP/1.1")) +    assert (t(b"GET http://foo:42/bar HTTP/1.1") == +            ("absolute", b"GET", b"http", b"foo", 42, b"/bar", b"HTTP/1.1")) + +    with raises(HttpSyntaxException): +        t(b"GET / WTF/1.1") +    with raises(HttpSyntaxException): +        t(b"this is not http") +    with raises(HttpReadDisconnect): +        t(b"") + +def test_parse_authority_form(): +    assert _parse_authority_form(b"foo:42") == (b"foo", 42) +    with raises(HttpSyntaxException): +        _parse_authority_form(b"foo") +    with raises(HttpSyntaxException): +        _parse_authority_form(b"foo:bar") +    with raises(HttpSyntaxException): +        _parse_authority_form(b"foo:99999999") +    with raises(HttpSyntaxException): +        _parse_authority_form(b"f\x00oo:80") + + +def test_read_response_line(): +    def t(b): +        return _read_response_line(BytesIO(b)) + +    assert t(b"HTTP/1.1 200 OK") == (b"HTTP/1.1", 200, b"OK") +    assert t(b"HTTP/1.1 200") == (b"HTTP/1.1", 200, b"") + +    # https://github.com/mitmproxy/mitmproxy/issues/784 +    assert t(b"HTTP/1.1 200 Non-Autoris\xc3\xa9") == (b"HTTP/1.1", 200, b"Non-Autoris\xc3\xa9") + +    with raises(HttpSyntaxException): +        assert t(b"HTTP/1.1") + +    with raises(HttpSyntaxException): +        t(b"HTTP/1.1 OK OK") +    with raises(HttpSyntaxException): +        t(b"WTF/1.1 200 OK") +    with raises(HttpReadDisconnect): +        t(b"") + + +def test_check_http_version(): +    _check_http_version(b"HTTP/0.9") +    _check_http_version(b"HTTP/1.0") +    _check_http_version(b"HTTP/1.1") +    _check_http_version(b"HTTP/2.0") +    with raises(HttpSyntaxException): +        _check_http_version(b"WTF/1.0") +    with raises(HttpSyntaxException): +        _check_http_version(b"HTTP/1.10") +    with raises(HttpSyntaxException): +        _check_http_version(b"HTTP/1.b") + + +class TestReadHeaders(object): +    @staticmethod +    def _read(data): +        return _read_headers(BytesIO(data)) + +    def test_read_simple(self): +        data = ( +            b"Header: one\r\n" +            b"Header2: two\r\n" +            b"\r\n" +        ) +        headers = self._read(data) +        assert headers.fields == [[b"Header", b"one"], [b"Header2", b"two"]] + +    def test_read_multi(self): +        data = ( +            b"Header: one\r\n" +            b"Header: two\r\n" +            b"\r\n" +        ) +        headers = self._read(data) +        assert headers.fields == [[b"Header", b"one"], [b"Header", b"two"]] + +    def test_read_continued(self): +        data = ( +            b"Header: one\r\n" +            b"\ttwo\r\n" +            b"Header2: three\r\n" +            b"\r\n" +        ) +        headers = self._read(data) +        assert headers.fields == [[b"Header", b"one\r\n two"], [b"Header2", b"three"]] + +    def test_read_continued_err(self): +        data = b"\tfoo: bar\r\n" +        
with raises(HttpSyntaxException): +            self._read(data) + +    def test_read_err(self): +        data = b"foo" +        with raises(HttpSyntaxException): +            self._read(data) + +    def test_read_empty_name(self): +        data = b":foo" +        with raises(HttpSyntaxException): +            self._read(data) + +    def test_read_empty_value(self): +        data = b"bar:" +        headers = self._read(data) +        assert headers.fields == [[b"bar", b""]] + +def test_read_chunked(): +    req = treq(content=None) +    req.headers["Transfer-Encoding"] = "chunked" + +    data = b"1\r\na\r\n0\r\n" +    with raises(HttpSyntaxException): +        b"".join(_read_chunked(BytesIO(data))) + +    data = b"1\r\na\r\n0\r\n\r\n" +    assert b"".join(_read_chunked(BytesIO(data))) == b"a" + +    data = b"\r\n\r\n1\r\na\r\n1\r\nb\r\n0\r\n\r\n" +    assert b"".join(_read_chunked(BytesIO(data))) == b"ab" + +    data = b"\r\n" +    with raises("closed prematurely"): +        b"".join(_read_chunked(BytesIO(data))) + +    data = b"1\r\nfoo" +    with raises("malformed chunked body"): +        b"".join(_read_chunked(BytesIO(data))) + +    data = b"foo\r\nfoo" +    with raises(HttpSyntaxException): +        b"".join(_read_chunked(BytesIO(data))) + +    data = b"5\r\naaaaa\r\n0\r\n\r\n" +    with raises("too large"): +        b"".join(_read_chunked(BytesIO(data), limit=2)) diff --git a/netlib/test/http/http2/__init__.py b/netlib/test/http/http2/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/netlib/test/http/http2/__init__.py diff --git a/netlib/test/http/http2/test_connections.py b/netlib/test/http/http2/test_connections.py new file mode 100644 index 00000000..a115fc7c --- /dev/null +++ b/netlib/test/http/http2/test_connections.py @@ -0,0 +1,540 @@ +import OpenSSL +import mock +import codecs + +from hyperframe.frame import * + +from netlib import tcp, http, utils, tservers +from netlib.tutils import raises +from netlib.exceptions import TcpDisconnect +from netlib.http.http2.connections import HTTP2Protocol, TCPHandler + + +class TestTCPHandlerWrapper: +    def test_wrapped(self): +        h = TCPHandler(rfile='foo', wfile='bar') +        p = HTTP2Protocol(h) +        assert p.tcp_handler.rfile == 'foo' +        assert p.tcp_handler.wfile == 'bar' + +    def test_direct(self): +        p = HTTP2Protocol(rfile='foo', wfile='bar') +        assert isinstance(p.tcp_handler, TCPHandler) +        assert p.tcp_handler.rfile == 'foo' +        assert p.tcp_handler.wfile == 'bar' + + +class EchoHandler(tcp.BaseHandler): +    sni = None + +    def handle(self): +        while True: +            v = self.rfile.safe_read(1) +            self.wfile.write(v) +            self.wfile.flush() + + +class TestProtocol: +    @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface") +    @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface") +    def test_perform_connection_preface(self, mock_client_method, mock_server_method): +        protocol = HTTP2Protocol(is_server=False) +        protocol.connection_preface_performed = True + +        protocol.perform_connection_preface() +        assert not mock_client_method.called +        assert not mock_server_method.called + +        protocol.perform_connection_preface(force=True) +        assert mock_client_method.called +        assert not mock_server_method.called + +    @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface") +    
@mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface") +    def test_perform_connection_preface_server(self, mock_client_method, mock_server_method): +        protocol = HTTP2Protocol(is_server=True) +        protocol.connection_preface_performed = True + +        protocol.perform_connection_preface() +        assert not mock_client_method.called +        assert not mock_server_method.called + +        protocol.perform_connection_preface(force=True) +        assert not mock_client_method.called +        assert mock_server_method.called + + +class TestCheckALPNMatch(tservers.ServerTestBase): +    handler = EchoHandler +    ssl = dict( +        alpn_select=b'h2', +    ) + +    if OpenSSL._util.lib.Cryptography_HAS_ALPN: + +        def test_check_alpn(self): +            c = tcp.TCPClient(("127.0.0.1", self.port)) +            c.connect() +            c.convert_to_ssl(alpn_protos=[b'h2']) +            protocol = HTTP2Protocol(c) +            assert protocol.check_alpn() + + +class TestCheckALPNMismatch(tservers.ServerTestBase): +    handler = EchoHandler +    ssl = dict( +        alpn_select=None, +    ) + +    if OpenSSL._util.lib.Cryptography_HAS_ALPN: + +        def test_check_alpn(self): +            c = tcp.TCPClient(("127.0.0.1", self.port)) +            c.connect() +            c.convert_to_ssl(alpn_protos=[b'h2']) +            protocol = HTTP2Protocol(c) +            with raises(NotImplementedError): +                protocol.check_alpn() + + +class TestPerformServerConnectionPreface(tservers.ServerTestBase): +    class handler(tcp.BaseHandler): + +        def handle(self): +            # send magic +            self.wfile.write(codecs.decode('505249202a20485454502f322e300d0a0d0a534d0d0a0d0a', 'hex_codec')) +            self.wfile.flush() + +            # send empty settings frame +            self.wfile.write(codecs.decode('000000040000000000', 'hex_codec')) +            self.wfile.flush() + +            # check empty settings frame +            raw = utils.http2_read_raw_frame(self.rfile) +            assert raw == codecs.decode('00000c040000000000000200000000000300000001', 'hex_codec') + +            # check settings acknowledgement +            raw = utils.http2_read_raw_frame(self.rfile) +            assert raw == codecs.decode('000000040100000000', 'hex_codec') + +            # send settings acknowledgement +            self.wfile.write(codecs.decode('000000040100000000', 'hex_codec')) +            self.wfile.flush() + +    def test_perform_server_connection_preface(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        protocol = HTTP2Protocol(c) + +        assert not protocol.connection_preface_performed +        protocol.perform_server_connection_preface() +        assert protocol.connection_preface_performed + +        with raises(TcpDisconnect): +            protocol.perform_server_connection_preface(force=True) + + +class TestPerformClientConnectionPreface(tservers.ServerTestBase): +    class handler(tcp.BaseHandler): + +        def handle(self): +            # check magic +            assert self.rfile.read(24) == HTTP2Protocol.CLIENT_CONNECTION_PREFACE + +            # check empty settings frame +            assert self.rfile.read(9) ==\ +                codecs.decode('000000040000000000', 'hex_codec') + +            # send empty settings frame +            self.wfile.write(codecs.decode('000000040000000000', 'hex_codec')) +            self.wfile.flush() + +            # check settings 
acknowledgement +            assert self.rfile.read(9) == \ +                codecs.decode('000000040100000000', 'hex_codec') + +            # send settings acknowledgement +            self.wfile.write(codecs.decode('000000040100000000', 'hex_codec')) +            self.wfile.flush() + +    def test_perform_client_connection_preface(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        protocol = HTTP2Protocol(c) + +        assert not protocol.connection_preface_performed +        protocol.perform_client_connection_preface() +        assert protocol.connection_preface_performed + + +class TestClientStreamIds(object): +    c = tcp.TCPClient(("127.0.0.1", 0)) +    protocol = HTTP2Protocol(c) + +    def test_client_stream_ids(self): +        assert self.protocol.current_stream_id is None +        assert self.protocol._next_stream_id() == 1 +        assert self.protocol.current_stream_id == 1 +        assert self.protocol._next_stream_id() == 3 +        assert self.protocol.current_stream_id == 3 +        assert self.protocol._next_stream_id() == 5 +        assert self.protocol.current_stream_id == 5 + + +class TestServerStreamIds(object): +    c = tcp.TCPClient(("127.0.0.1", 0)) +    protocol = HTTP2Protocol(c, is_server=True) + +    def test_server_stream_ids(self): +        assert self.protocol.current_stream_id is None +        assert self.protocol._next_stream_id() == 2 +        assert self.protocol.current_stream_id == 2 +        assert self.protocol._next_stream_id() == 4 +        assert self.protocol.current_stream_id == 4 +        assert self.protocol._next_stream_id() == 6 +        assert self.protocol.current_stream_id == 6 + + +class TestApplySettings(tservers.ServerTestBase): +    class handler(tcp.BaseHandler): +        def handle(self): +            # check settings acknowledgement +            assert self.rfile.read(9) == codecs.decode('000000040100000000', 'hex_codec') +            self.wfile.write("OK") +            self.wfile.flush() +            self.rfile.safe_read(9)  # just to keep the connection alive a bit longer + +    ssl = True + +    def test_apply_settings(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.convert_to_ssl() +        protocol = HTTP2Protocol(c) + +        protocol._apply_settings({ +            SettingsFrame.ENABLE_PUSH: 'foo', +            SettingsFrame.MAX_CONCURRENT_STREAMS: 'bar', +            SettingsFrame.INITIAL_WINDOW_SIZE: 'deadbeef', +        }) + +        assert c.rfile.safe_read(2) == b"OK" + +        assert protocol.http2_settings[ +            SettingsFrame.ENABLE_PUSH] == 'foo' +        assert protocol.http2_settings[ +            SettingsFrame.MAX_CONCURRENT_STREAMS] == 'bar' +        assert protocol.http2_settings[ +            SettingsFrame.INITIAL_WINDOW_SIZE] == 'deadbeef' + + +class TestCreateHeaders(object): +    c = tcp.TCPClient(("127.0.0.1", 0)) + +    def test_create_headers(self): +        headers = http.Headers([ +            (b':method', b'GET'), +            (b':path', b'index.html'), +            (b':scheme', b'https'), +            (b'foo', b'bar')]) + +        bytes = HTTP2Protocol(self.c)._create_headers( +            headers, 1, end_stream=True) +        assert b''.join(bytes) ==\ +            codecs.decode('000014010500000001824488355217caf3a69a3f87408294e7838c767f', 'hex_codec') + +        bytes = HTTP2Protocol(self.c)._create_headers( +            headers, 1, end_stream=False) +        assert b''.join(bytes) ==\ +            
codecs.decode('000014010400000001824488355217caf3a69a3f87408294e7838c767f', 'hex_codec') + +    def test_create_headers_multiple_frames(self): +        headers = http.Headers([ +            (b':method', b'GET'), +            (b':path', b'/'), +            (b':scheme', b'https'), +            (b'foo', b'bar'), +            (b'server', b'version')]) + +        protocol = HTTP2Protocol(self.c) +        protocol.http2_settings[SettingsFrame.MAX_FRAME_SIZE] = 8 +        bytes = protocol._create_headers(headers, 1, end_stream=True) +        assert len(bytes) == 3 +        assert bytes[0] == codecs.decode('000008010100000001828487408294e783', 'hex_codec') +        assert bytes[1] == codecs.decode('0000080900000000018c767f7685ee5b10', 'hex_codec') +        assert bytes[2] == codecs.decode('00000209040000000163d5', 'hex_codec') + + +class TestCreateBody(object): +    c = tcp.TCPClient(("127.0.0.1", 0)) + +    def test_create_body_empty(self): +        protocol = HTTP2Protocol(self.c) +        bytes = protocol._create_body(b'', 1) +        assert b''.join(bytes) == b'' + +    def test_create_body_single_frame(self): +        protocol = HTTP2Protocol(self.c) +        bytes = protocol._create_body(b'foobar', 1) +        assert b''.join(bytes) == codecs.decode('000006000100000001666f6f626172', 'hex_codec') + +    def test_create_body_multiple_frames(self): +        protocol = HTTP2Protocol(self.c) +        protocol.http2_settings[SettingsFrame.MAX_FRAME_SIZE] = 5 +        bytes = protocol._create_body(b'foobarmehm42', 1) +        assert len(bytes) == 3 +        assert bytes[0] == codecs.decode('000005000000000001666f6f6261', 'hex_codec') +        assert bytes[1] == codecs.decode('000005000000000001726d65686d', 'hex_codec') +        assert bytes[2] == codecs.decode('0000020001000000013432', 'hex_codec') + + +class TestReadRequest(tservers.ServerTestBase): +    class handler(tcp.BaseHandler): + +        def handle(self): +            self.wfile.write( +                codecs.decode('000003010400000001828487', 'hex_codec')) +            self.wfile.write( +                codecs.decode('000006000100000001666f6f626172', 'hex_codec')) +            self.wfile.flush() +            self.rfile.safe_read(9)  # just to keep the connection alive a bit longer + +    ssl = True + +    def test_read_request(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.convert_to_ssl() +        protocol = HTTP2Protocol(c, is_server=True) +        protocol.connection_preface_performed = True + +        req = protocol.read_request(NotImplemented) + +        assert req.stream_id +        assert req.headers.fields == [[b':method', b'GET'], [b':path', b'/'], [b':scheme', b'https']] +        assert req.content == b'foobar' + + +class TestReadRequestRelative(tservers.ServerTestBase): +    class handler(tcp.BaseHandler): +        def handle(self): +            self.wfile.write( +                codecs.decode('00000c0105000000014287d5af7e4d5a777f4481f9', 'hex_codec')) +            self.wfile.flush() + +    ssl = True + +    def test_asterisk_form_in(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.convert_to_ssl() +        protocol = HTTP2Protocol(c, is_server=True) +        protocol.connection_preface_performed = True + +        req = protocol.read_request(NotImplemented) + +        assert req.form_in == "relative" +        assert req.method == "OPTIONS" +        assert req.path == "*" + + +class TestReadRequestAbsolute(tservers.ServerTestBase): +    
class handler(tcp.BaseHandler): +        def handle(self): +            self.wfile.write( +                codecs.decode('00001901050000000182448d9d29aee30c0e492c2a1170426366871c92585422e085', 'hex_codec')) +            self.wfile.flush() + +    ssl = True + +    def test_absolute_form_in(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.convert_to_ssl() +        protocol = HTTP2Protocol(c, is_server=True) +        protocol.connection_preface_performed = True + +        req = protocol.read_request(NotImplemented) + +        assert req.form_in == "absolute" +        assert req.scheme == "http" +        assert req.host == "address" +        assert req.port == 22 + + +class TestReadRequestConnect(tservers.ServerTestBase): +    class handler(tcp.BaseHandler): +        def handle(self): +            self.wfile.write( +                codecs.decode('00001b0105000000014287bdab4e9c17b7ff44871c92585422e08541871c92585422e085', 'hex_codec')) +            self.wfile.write( +                codecs.decode('00001d0105000000014287bdab4e9c17b7ff44882f91d35d055c87a741882f91d35d055c87a7', 'hex_codec')) +            self.wfile.flush() + +    ssl = True + +    def test_connect(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.convert_to_ssl() +        protocol = HTTP2Protocol(c, is_server=True) +        protocol.connection_preface_performed = True + +        req = protocol.read_request(NotImplemented) +        assert req.form_in == "authority" +        assert req.method == "CONNECT" +        assert req.host == "address" +        assert req.port == 22 + +        req = protocol.read_request(NotImplemented) +        assert req.form_in == "authority" +        assert req.method == "CONNECT" +        assert req.host == "example.com" +        assert req.port == 443 + + +class TestReadResponse(tservers.ServerTestBase): +    class handler(tcp.BaseHandler): +        def handle(self): +            self.wfile.write( +                codecs.decode('00000801040000002a88628594e78c767f', 'hex_codec')) +            self.wfile.write( +                codecs.decode('00000600010000002a666f6f626172', 'hex_codec')) +            self.wfile.flush() +            self.rfile.safe_read(9)  # just to keep the connection alive a bit longer + +    ssl = True + +    def test_read_response(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.convert_to_ssl() +        protocol = HTTP2Protocol(c) +        protocol.connection_preface_performed = True + +        resp = protocol.read_response(NotImplemented, stream_id=42) + +        assert resp.http_version == "HTTP/2.0" +        assert resp.status_code == 200 +        assert resp.msg == '' +        assert resp.headers.fields == [[b':status', b'200'], [b'etag', b'foobar']] +        assert resp.content == b'foobar' +        assert resp.timestamp_end + + +class TestReadEmptyResponse(tservers.ServerTestBase): +    class handler(tcp.BaseHandler): +        def handle(self): +            self.wfile.write( +                codecs.decode('00000801050000002a88628594e78c767f', 'hex_codec')) +            self.wfile.flush() + +    ssl = True + +    def test_read_empty_response(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.convert_to_ssl() +        protocol = HTTP2Protocol(c) +        protocol.connection_preface_performed = True + +        resp = protocol.read_response(NotImplemented, stream_id=42) + +        assert resp.stream_id == 42 +       
 assert resp.http_version == "HTTP/2.0" +        assert resp.status_code == 200 +        assert resp.msg == '' +        assert resp.headers.fields == [[b':status', b'200'], [b'etag', b'foobar']] +        assert resp.content == b'' + + +class TestAssembleRequest(object): +    c = tcp.TCPClient(("127.0.0.1", 0)) + +    def test_request_simple(self): +        bytes = HTTP2Protocol(self.c).assemble_request(http.Request( +            b'', +            b'GET', +            b'https', +            b'', +            b'', +            b'/', +            b"HTTP/2.0", +            None, +            None, +        )) +        assert len(bytes) == 1 +        assert bytes[0] == codecs.decode('00000d0105000000018284874188089d5c0b8170dc07', 'hex_codec') + +    def test_request_with_stream_id(self): +        req = http.Request( +            b'', +            b'GET', +            b'https', +            b'', +            b'', +            b'/', +            b"HTTP/2.0", +            None, +            None, +        ) +        req.stream_id = 0x42 +        bytes = HTTP2Protocol(self.c).assemble_request(req) +        assert len(bytes) == 1 +        assert bytes[0] == codecs.decode('00000d0105000000428284874188089d5c0b8170dc07', 'hex_codec') + +    def test_request_with_body(self): +        bytes = HTTP2Protocol(self.c).assemble_request(http.Request( +            b'', +            b'GET', +            b'https', +            b'', +            b'', +            b'/', +            b"HTTP/2.0", +            http.Headers([(b'foo', b'bar')]), +            b'foobar', +        )) +        assert len(bytes) == 2 +        assert bytes[0] ==\ +            codecs.decode('0000150104000000018284874188089d5c0b8170dc07408294e7838c767f', 'hex_codec') +        assert bytes[1] ==\ +            codecs.decode('000006000100000001666f6f626172', 'hex_codec') + + +class TestAssembleResponse(object): +    c = tcp.TCPClient(("127.0.0.1", 0)) + +    def test_simple(self): +        bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( +            b"HTTP/2.0", +            200, +        )) +        assert len(bytes) == 1 +        assert bytes[0] ==\ +            codecs.decode('00000101050000000288', 'hex_codec') + +    def test_with_stream_id(self): +        resp = http.Response( +            b"HTTP/2.0", +            200, +        ) +        resp.stream_id = 0x42 +        bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(resp) +        assert len(bytes) == 1 +        assert bytes[0] ==\ +            codecs.decode('00000101050000004288', 'hex_codec') + +    def test_with_body(self): +        bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( +            b"HTTP/2.0", +            200, +            b'', +            http.Headers(foo=b"bar"), +            b'foobar' +        )) +        assert len(bytes) == 2 +        assert bytes[0] ==\ +            codecs.decode('00000901040000000288408294e7838c767f', 'hex_codec') +        assert bytes[1] ==\ +            codecs.decode('000006000100000002666f6f626172', 'hex_codec') diff --git a/netlib/test/http/test_authentication.py b/netlib/test/http/test_authentication.py new file mode 100644 index 00000000..1df7cd9c --- /dev/null +++ b/netlib/test/http/test_authentication.py @@ -0,0 +1,122 @@ +import binascii + +from netlib import tutils +from netlib.http import authentication, Headers + + +def test_parse_http_basic_auth(): +    vals = ("basic", "foo", "bar") +    assert authentication.parse_http_basic_auth( +        
authentication.assemble_http_basic_auth(*vals) +    ) == vals +    assert not authentication.parse_http_basic_auth("") +    assert not authentication.parse_http_basic_auth("foo bar") +    v = "basic " + binascii.b2a_base64(b"foo").decode("ascii") +    assert not authentication.parse_http_basic_auth(v) + + +class TestPassManNonAnon: + +    def test_simple(self): +        p = authentication.PassManNonAnon() +        assert not p.test("", "") +        assert p.test("user", "") + + +class TestPassManHtpasswd: + +    def test_file_errors(self): +        tutils.raises( +            "malformed htpasswd file", +            authentication.PassManHtpasswd, +            tutils.test_data.path("data/server.crt")) + +    def test_simple(self): +        pm = authentication.PassManHtpasswd(tutils.test_data.path("data/htpasswd")) + +        vals = ("basic", "test", "test") +        authentication.assemble_http_basic_auth(*vals) +        assert pm.test("test", "test") +        assert not pm.test("test", "foo") +        assert not pm.test("foo", "test") +        assert not pm.test("test", "") +        assert not pm.test("", "") + + +class TestPassManSingleUser: + +    def test_simple(self): +        pm = authentication.PassManSingleUser("test", "test") +        assert pm.test("test", "test") +        assert not pm.test("test", "foo") +        assert not pm.test("foo", "test") + + +class TestNullProxyAuth: + +    def test_simple(self): +        na = authentication.NullProxyAuth(authentication.PassManNonAnon()) +        assert not na.auth_challenge_headers() +        assert na.authenticate("foo") +        na.clean({}) + + +class TestBasicProxyAuth: + +    def test_simple(self): +        ba = authentication.BasicProxyAuth(authentication.PassManNonAnon(), "test") +        headers = Headers() +        assert ba.auth_challenge_headers() +        assert not ba.authenticate(headers) + +    def test_authenticate_clean(self): +        ba = authentication.BasicProxyAuth(authentication.PassManNonAnon(), "test") + +        headers = Headers() +        vals = ("basic", "foo", "bar") +        headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals) +        assert ba.authenticate(headers) + +        ba.clean(headers) +        assert not ba.AUTH_HEADER in headers + +        headers[ba.AUTH_HEADER] = "" +        assert not ba.authenticate(headers) + +        headers[ba.AUTH_HEADER] = "foo" +        assert not ba.authenticate(headers) + +        vals = ("foo", "foo", "bar") +        headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals) +        assert not ba.authenticate(headers) + +        ba = authentication.BasicProxyAuth(authentication.PassMan(), "test") +        vals = ("basic", "foo", "bar") +        headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals) +        assert not ba.authenticate(headers) + + +class Bunch: +    pass + + +class TestAuthAction: + +    def test_nonanonymous(self): +        m = Bunch() +        aa = authentication.NonanonymousAuthAction(None, "authenticator") +        aa(None, m, None, None) +        assert m.authenticator + +    def test_singleuser(self): +        m = Bunch() +        aa = authentication.SingleuserAuthAction(None, "authenticator") +        aa(None, m, "foo:bar", None) +        assert m.authenticator +        tutils.raises("invalid", aa, None, m, "foo", None) + +    def test_httppasswd(self): +        m = Bunch() +        aa = authentication.HtpasswdAuthAction(None, "authenticator") +        aa(None, m, 
tutils.test_data.path("data/htpasswd"), None) +        assert m.authenticator diff --git a/netlib/test/http/test_cookies.py b/netlib/test/http/test_cookies.py new file mode 100644 index 00000000..34bb64f2 --- /dev/null +++ b/netlib/test/http/test_cookies.py @@ -0,0 +1,218 @@ +from netlib.http import cookies + + +def test_read_token(): +    tokens = [ +        [("foo", 0), ("foo", 3)], +        [("foo", 1), ("oo", 3)], +        [(" foo", 1), ("foo", 4)], +        [(" foo;", 1), ("foo", 4)], +        [(" foo=", 1), ("foo", 4)], +        [(" foo=bar", 1), ("foo", 4)], +    ] +    for q, a in tokens: +        assert cookies._read_token(*q) == a + + +def test_read_quoted_string(): +    tokens = [ +        [('"foo" x', 0), ("foo", 5)], +        [('"f\oo" x', 0), ("foo", 6)], +        [(r'"f\\o" x', 0), (r"f\o", 6)], +        [(r'"f\\" x', 0), (r"f" + '\\', 5)], +        [('"fo\\\"" x', 0), ("fo\"", 6)], +        [('"foo" x', 7), ("", 8)], +    ] +    for q, a in tokens: +        assert cookies._read_quoted_string(*q) == a + + +def test_read_pairs(): +    vals = [ +        [ +            "one", +            [["one", None]] +        ], +        [ +            "one=two", +            [["one", "two"]] +        ], +        [ +            "one=", +            [["one", ""]] +        ], +        [ +            'one="two"', +            [["one", "two"]] +        ], +        [ +            'one="two"; three=four', +            [["one", "two"], ["three", "four"]] +        ], +        [ +            'one="two"; three=four; five', +            [["one", "two"], ["three", "four"], ["five", None]] +        ], +        [ +            'one="\\"two"; three=four', +            [["one", '"two'], ["three", "four"]] +        ], +    ] +    for s, lst in vals: +        ret, off = cookies._read_pairs(s) +        assert ret == lst + + +def test_pairs_roundtrips(): +    pairs = [ +        [ +            "", +            [] +        ], +        [ +            "one=uno", +            [["one", "uno"]] +        ], +        [ +            "one", +            [["one", None]] +        ], +        [ +            "one=uno; two=due", +            [["one", "uno"], ["two", "due"]] +        ], +        [ +            'one="uno"; two="\due"', +            [["one", "uno"], ["two", "due"]] +        ], +        [ +            'one="un\\"o"', +            [["one", 'un"o']] +        ], +        [ +            'one="uno,due"', +            [["one", 'uno,due']] +        ], +        [ +            "one=uno; two; three=tre", +            [["one", "uno"], ["two", None], ["three", "tre"]] +        ], +        [ +            "_lvs2=zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g=; " +            "_rcc2=53VdltWl+Ov6ordflA==;", +            [ +                ["_lvs2", "zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g="], +                ["_rcc2", "53VdltWl+Ov6ordflA=="] +            ] +        ] +    ] +    for s, lst in pairs: +        ret, off = cookies._read_pairs(s) +        assert ret == lst +        s2 = cookies._format_pairs(lst) +        ret, off = cookies._read_pairs(s2) +        assert ret == lst + + +def test_cookie_roundtrips(): +    pairs = [ +        [ +            "one=uno", +            [["one", "uno"]] +        ], +        [ +            "one=uno; two=due", +            [["one", "uno"], ["two", "due"]] +        ], +    ] +    for s, lst in pairs: +        ret = cookies.parse_cookie_header(s) +        assert ret.lst == lst +        s2 = cookies.format_cookie_header(ret) +        ret = cookies.parse_cookie_header(s2) +        assert ret.lst == lst + + 
+def test_parse_set_cookie_pairs(): +    pairs = [ +        [ +            "one=uno", +            [ +                ["one", "uno"] +            ] +        ], +        [ +            "one=un\x20", +            [ +                ["one", "un\x20"] +            ] +        ], +        [ +            "one=uno; foo", +            [ +                ["one", "uno"], +                ["foo", None] +            ] +        ], +        [ +            "mun=1.390.f60; " +            "expires=sun, 11-oct-2015 12:38:31 gmt; path=/; " +            "domain=b.aol.com", +            [ +                ["mun", "1.390.f60"], +                ["expires", "sun, 11-oct-2015 12:38:31 gmt"], +                ["path", "/"], +                ["domain", "b.aol.com"] +            ] +        ], +        [ +            r'rpb=190%3d1%2616726%3d1%2634832%3d1%2634874%3d1; ' +            'domain=.rubiconproject.com; ' +            'expires=mon, 11-may-2015 21:54:57 gmt; ' +            'path=/', +            [ +                ['rpb', r'190%3d1%2616726%3d1%2634832%3d1%2634874%3d1'], +                ['domain', '.rubiconproject.com'], +                ['expires', 'mon, 11-may-2015 21:54:57 gmt'], +                ['path', '/'] +            ] +        ], +    ] +    for s, lst in pairs: +        ret = cookies._parse_set_cookie_pairs(s) +        assert ret == lst +        s2 = cookies._format_set_cookie_pairs(ret) +        ret2 = cookies._parse_set_cookie_pairs(s2) +        assert  ret2 == lst + + +def test_parse_set_cookie_header(): +    vals = [ +        [ +            "", None +        ], +        [ +            ";", None +        ], +        [ +            "one=uno", +            ("one", "uno", []) +        ], +        [ +            "one=uno; foo=bar", +            ("one", "uno", [["foo", "bar"]]) +        ] +    ] +    for s, expected in vals: +        ret = cookies.parse_set_cookie_header(s) +        if expected: +            assert ret[0] == expected[0] +            assert ret[1] == expected[1] +            assert ret[2].lst == expected[2] +            s2 = cookies.format_set_cookie_header(*ret) +            ret2 = cookies.parse_set_cookie_header(s2) +            assert ret2[0] == expected[0] +            assert ret2[1] == expected[1] +            assert ret2[2].lst == expected[2] +        else: +            assert ret is None diff --git a/netlib/test/http/test_headers.py b/netlib/test/http/test_headers.py new file mode 100644 index 00000000..d50fee3e --- /dev/null +++ b/netlib/test/http/test_headers.py @@ -0,0 +1,152 @@ +from netlib.http import Headers +from netlib.tutils import raises + + +class TestHeaders(object): +    def _2host(self): +        return Headers( +            [ +                [b"Host", b"example.com"], +                [b"host", b"example.org"] +            ] +        ) + +    def test_init(self): +        headers = Headers() +        assert len(headers) == 0 + +        headers = Headers([[b"Host", b"example.com"]]) +        assert len(headers) == 1 +        assert headers["Host"] == "example.com" + +        headers = Headers(Host="example.com") +        assert len(headers) == 1 +        assert headers["Host"] == "example.com" + +        headers = Headers( +            [[b"Host", b"invalid"]], +            Host="example.com" +        ) +        assert len(headers) == 1 +        assert headers["Host"] == "example.com" + +        headers = Headers( +            [[b"Host", b"invalid"], [b"Accept", b"text/plain"]], +            Host="example.com" +        ) +        assert len(headers) == 2 +        
assert headers["Host"] == "example.com" +        assert headers["Accept"] == "text/plain" + +        with raises(ValueError): +            Headers([[b"Host", u"not-bytes"]]) + +    def test_getitem(self): +        headers = Headers(Host="example.com") +        assert headers["Host"] == "example.com" +        assert headers["host"] == "example.com" +        with raises(KeyError): +            _ = headers["Accept"] + +        headers = self._2host() +        assert headers["Host"] == "example.com, example.org" + +    def test_str(self): +        headers = Headers(Host="example.com") +        assert bytes(headers) == b"Host: example.com\r\n" + +        headers = Headers([ +            [b"Host", b"example.com"], +            [b"Accept", b"text/plain"] +        ]) +        assert bytes(headers) == b"Host: example.com\r\nAccept: text/plain\r\n" + +        headers = Headers() +        assert bytes(headers) == b"" + +    def test_setitem(self): +        headers = Headers() +        headers["Host"] = "example.com" +        assert "Host" in headers +        assert "host" in headers +        assert headers["Host"] == "example.com" + +        headers["host"] = "example.org" +        assert "Host" in headers +        assert "host" in headers +        assert headers["Host"] == "example.org" + +        headers["accept"] = "text/plain" +        assert len(headers) == 2 +        assert "Accept" in headers +        assert "Host" in headers + +        headers = self._2host() +        assert len(headers.fields) == 2 +        headers["Host"] = "example.com" +        assert len(headers.fields) == 1 +        assert "Host" in headers + +    def test_delitem(self): +        headers = Headers(Host="example.com") +        assert len(headers) == 1 +        del headers["host"] +        assert len(headers) == 0 +        try: +            del headers["host"] +        except KeyError: +            assert True +        else: +            assert False + +        headers = self._2host() +        del headers["Host"] +        assert len(headers) == 0 + +    def test_keys(self): +        headers = Headers(Host="example.com") +        assert list(headers.keys()) == ["Host"] + +        headers = self._2host() +        assert list(headers.keys()) == ["Host"] + +    def test_eq_ne(self): +        headers1 = Headers(Host="example.com") +        headers2 = Headers(host="example.com") +        assert not (headers1 == headers2) +        assert headers1 != headers2 + +        headers1 = Headers(Host="example.com") +        headers2 = Headers(Host="example.com") +        assert headers1 == headers2 +        assert not (headers1 != headers2) + +        assert headers1 != 42 + +    def test_get_all(self): +        headers = self._2host() +        assert headers.get_all("host") == ["example.com", "example.org"] +        assert headers.get_all("accept") == [] + +    def test_set_all(self): +        headers = Headers(Host="example.com") +        headers.set_all("Accept", ["text/plain"]) +        assert len(headers) == 2 +        assert "accept" in headers + +        headers = self._2host() +        headers.set_all("Host", ["example.org"]) +        assert headers["host"] == "example.org" + +        headers.set_all("Host", ["example.org", "example.net"]) +        assert headers["host"] == "example.org, example.net" + +    def test_state(self): +        headers = self._2host() +        assert len(headers.get_state()) == 2 +        assert headers == Headers.from_state(headers.get_state()) + +        headers2 = Headers() +        assert headers 
!= headers2 +        headers2.set_state(headers.get_state()) +        assert headers == headers2 diff --git a/netlib/test/http/test_message.py b/netlib/test/http/test_message.py new file mode 100644 index 00000000..4b1f4630 --- /dev/null +++ b/netlib/test/http/test_message.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, print_function, division + +from netlib.http import decoded, Headers +from netlib.tutils import tresp, raises + + +def _test_passthrough_attr(message, attr): +    assert getattr(message, attr) == getattr(message.data, attr) +    setattr(message, attr, "foo") +    assert getattr(message.data, attr) == "foo" + + +def _test_decoded_attr(message, attr): +    assert getattr(message, attr) == getattr(message.data, attr).decode("utf8") +    # Set str, get raw bytes +    setattr(message, attr, "foo") +    assert getattr(message.data, attr) == b"foo" +    # Set raw bytes, get decoded +    setattr(message.data, attr, b"BAR")  # use uppercase so that we can also cover request.method +    assert getattr(message, attr) == "BAR" +    # Set bytes, get raw bytes +    setattr(message, attr, b"baz") +    assert getattr(message.data, attr) == b"baz" + +    # Set UTF8 +    setattr(message, attr, "Non-Autorisé") +    assert getattr(message.data, attr) == b"Non-Autoris\xc3\xa9" +    # Don't fail on garbage +    setattr(message.data, attr, b"FOO\xFF\x00BAR") +    assert getattr(message, attr).startswith("FOO") +    assert getattr(message, attr).endswith("BAR") +    # foo.bar = foo.bar should not cause any side effects. +    d = getattr(message, attr) +    setattr(message, attr, d) +    assert getattr(message.data, attr) == b"FOO\xFF\x00BAR" + + +class TestMessageData(object): +    def test_eq_ne(self): +        data = tresp(timestamp_start=42, timestamp_end=42).data +        same = tresp(timestamp_start=42, timestamp_end=42).data +        assert data == same +        assert not data != same + +        other = tresp(content=b"foo").data +        assert not data == other +        assert data != other + +        assert data != 0 + + +class TestMessage(object): + +    def test_init(self): +        resp = tresp() +        assert resp.data + +    def test_eq_ne(self): +        resp = tresp(timestamp_start=42, timestamp_end=42) +        same = tresp(timestamp_start=42, timestamp_end=42) +        assert resp == same +        assert not resp != same + +        other = tresp(timestamp_start=0, timestamp_end=0) +        assert not resp == other +        assert resp != other + +        assert resp != 0 + +    def test_content_length_update(self): +        resp = tresp() +        resp.content = b"foo" +        assert resp.data.content == b"foo" +        assert resp.headers["content-length"] == "3" +        resp.content = b"" +        assert resp.data.content == b"" +        assert resp.headers["content-length"] == "0" + +    def test_content_basic(self): +        _test_passthrough_attr(tresp(), "content") + +    def test_headers(self): +        _test_passthrough_attr(tresp(), "headers") + +    def test_timestamp_start(self): +        _test_passthrough_attr(tresp(), "timestamp_start") + +    def test_timestamp_end(self): +        _test_passthrough_attr(tresp(), "timestamp_end") + +    def teste_http_version(self): +        _test_decoded_attr(tresp(), "http_version") + + +class TestDecodedDecorator(object): + +    def test_simple(self): +        r = tresp() +        assert r.content == b"message" +        assert "content-encoding" not in r.headers +        assert 
r.encode("gzip") + +        assert r.headers["content-encoding"] +        assert r.content != b"message" +        with decoded(r): +            assert "content-encoding" not in r.headers +            assert r.content == b"message" +        assert r.headers["content-encoding"] +        assert r.content != b"message" + +    def test_modify(self): +        r = tresp() +        assert "content-encoding" not in r.headers +        assert r.encode("gzip") + +        with decoded(r): +            r.content = b"foo" + +        assert r.content != b"foo" +        r.decode() +        assert r.content == b"foo" + +    def test_unknown_ce(self): +        r = tresp() +        r.headers["content-encoding"] = "zopfli" +        r.content = b"foo" +        with decoded(r): +            assert r.headers["content-encoding"] +            assert r.content == b"foo" +        assert r.headers["content-encoding"] +        assert r.content == b"foo" + +    def test_cannot_decode(self): +        r = tresp() +        assert r.encode("gzip") +        r.content = b"foo" +        with decoded(r): +            assert r.headers["content-encoding"] +            assert r.content == b"foo" +        assert r.headers["content-encoding"] +        assert r.content != b"foo" +        r.decode() +        assert r.content == b"foo" + +    def test_cannot_encode(self): +        r = tresp() +        assert r.encode("gzip") +        with decoded(r): +            r.content = None + +        assert "content-encoding" not in r.headers +        assert r.content is None diff --git a/netlib/test/http/test_request.py b/netlib/test/http/test_request.py new file mode 100644 index 00000000..900b2cd1 --- /dev/null +++ b/netlib/test/http/test_request.py @@ -0,0 +1,238 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, print_function, division + +import six + +from netlib import utils +from netlib.http import Headers +from netlib.odict import ODict +from netlib.tutils import treq, raises +from .test_message import _test_decoded_attr, _test_passthrough_attr + + +class TestRequestData(object): +    def test_init(self): +        with raises(ValueError if six.PY2 else TypeError): +            treq(headers="foobar") + +        assert isinstance(treq(headers=None).headers, Headers) + + +class TestRequestCore(object): +    """ +    Tests for builtins and the attributes that are directly proxied from the data structure +    """ +    def test_repr(self): +        request = treq() +        assert repr(request) == "Request(GET address:22/path)" +        request.host = None +        assert repr(request) == "Request(GET /path)" + +    def test_first_line_format(self): +        _test_passthrough_attr(treq(), "first_line_format") + +    def test_method(self): +        _test_decoded_attr(treq(), "method") + +    def test_scheme(self): +        _test_decoded_attr(treq(), "scheme") + +    def test_port(self): +        _test_passthrough_attr(treq(), "port") + +    def test_path(self): +        _test_decoded_attr(treq(), "path") + +    def test_host(self): +        if six.PY2: +            from unittest import SkipTest +            raise SkipTest() + +        request = treq() +        assert request.host == request.data.host.decode("idna") + +        # Test IDNA encoding +        # Set str, get raw bytes +        request.host = "Ãdna.example" +        assert request.data.host == b"xn--dna-qma.example" +        # Set raw bytes, get decoded +        request.data.host = b"xn--idn-gla.example" +        assert request.host == "idná.example" +        # 
Set bytes, get raw bytes +        request.host = b"xn--dn-qia9b.example" +        assert request.data.host == b"xn--dn-qia9b.example" +        # IDNA encoding is not bijective +        request.host = "fußball" +        assert request.host == "fussball" + +        # Don't fail on garbage +        request.data.host = b"foo\xFF\x00bar" +        assert request.host.startswith("foo") +        assert request.host.endswith("bar") +        # foo.bar = foo.bar should not cause any side effects. +        d = request.host +        request.host = d +        assert request.data.host == b"foo\xFF\x00bar" + +    def test_host_header_update(self): +        request = treq() +        assert "host" not in request.headers +        request.host = "example.com" +        assert "host" not in request.headers + +        request.headers["Host"] = "foo" +        request.host = "example.org" +        assert request.headers["Host"] == "example.org" + + +class TestRequestUtils(object): +    """ +    Tests for additional convenience methods. +    """ +    def test_url(self): +        request = treq() +        assert request.url == "http://address:22/path" + +        request.url = "https://otheraddress:42/foo" +        assert request.scheme == "https" +        assert request.host == "otheraddress" +        assert request.port == 42 +        assert request.path == "/foo" + +        with raises(ValueError): +            request.url = "not-a-url" + +    def test_pretty_host(self): +        request = treq() +        assert request.pretty_host == "address" +        assert request.host == "address" +        request.headers["host"] = "other" +        assert request.pretty_host == "other" +        assert request.host == "address" +        request.host = None +        assert request.pretty_host is None +        assert request.host is None + +        # Invalid IDNA +        request.headers["host"] = ".disqus.com" +        assert request.pretty_host == ".disqus.com" + +    def test_pretty_url(self): +        request = treq() +        assert request.url == "http://address:22/path" +        assert request.pretty_url == "http://address:22/path" +        request.headers["host"] = "other" +        assert request.pretty_url == "http://other:22/path" + +    def test_pretty_url_authority(self): +        request = treq(first_line_format="authority") +        assert request.pretty_url == "address:22" + +    def test_get_query(self): +        request = treq() +        assert request.query is None + +        request.url = "http://localhost:80/foo?bar=42" +        assert request.query.lst == [("bar", "42")] + +    def test_set_query(self): +        request = treq() +        request.query = ODict([]) + +    def test_get_cookies_none(self): +        request = treq() +        request.headers = Headers() +        assert len(request.cookies) == 0 + +    def test_get_cookies_single(self): +        request = treq() +        request.headers = Headers(cookie="cookiename=cookievalue") +        result = request.cookies +        assert len(result) == 1 +        assert result['cookiename'] == ['cookievalue'] + +    def test_get_cookies_double(self): +        request = treq() +        request.headers = Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue") +        result = request.cookies +        assert len(result) == 2 +        assert result['cookiename'] == ['cookievalue'] +        assert result['othercookiename'] == ['othercookievalue'] + +    def test_get_cookies_withequalsign(self): +        request = treq() +        request.headers 
= Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue") +        result = request.cookies +        assert len(result) == 2 +        assert result['cookiename'] == ['coo=kievalue'] +        assert result['othercookiename'] == ['othercookievalue'] + +    def test_set_cookies(self): +        request = treq() +        request.headers = Headers(cookie="cookiename=cookievalue") +        result = request.cookies +        result["cookiename"] = ["foo"] +        request.cookies = result +        assert request.cookies["cookiename"] == ["foo"] + +    def test_get_path_components(self): +        request = treq(path=b"/foo/bar") +        assert request.path_components == ["foo", "bar"] + +    def test_set_path_components(self): +        request = treq() +        request.path_components = ["foo", "baz"] +        assert request.path == "/foo/baz" +        request.path_components = [] +        assert request.path == "/" + +    def test_anticache(self): +        request = treq() +        request.headers["If-Modified-Since"] = "foo" +        request.headers["If-None-Match"] = "bar" +        request.anticache() +        assert "If-Modified-Since" not in request.headers +        assert "If-None-Match" not in request.headers + +    def test_anticomp(self): +        request = treq() +        request.headers["Accept-Encoding"] = "foobar" +        request.anticomp() +        assert request.headers["Accept-Encoding"] == "identity" + +    def test_constrain_encoding(self): +        request = treq() + +        h = request.headers.copy() +        request.constrain_encoding()  # no-op if there is no accept_encoding header. +        assert request.headers == h + +        request.headers["Accept-Encoding"] = "identity, gzip, foo" +        request.constrain_encoding() +        assert "foo" not in request.headers["Accept-Encoding"] +        assert "gzip" in request.headers["Accept-Encoding"] + +    def test_get_urlencoded_form(self): +        request = treq(content="foobar") +        assert request.urlencoded_form is None + +        request.headers["Content-Type"] = "application/x-www-form-urlencoded" +        assert request.urlencoded_form == ODict(utils.urldecode(request.content)) + +    def test_set_urlencoded_form(self): +        request = treq() +        request.urlencoded_form = ODict([('foo', 'bar'), ('rab', 'oof')]) +        assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" +        assert request.content + +    def test_get_multipart_form(self): +        request = treq(content="foobar") +        assert request.multipart_form is None + +        request.headers["Content-Type"] = "multipart/form-data" +        assert request.multipart_form == ODict( +            utils.multipartdecode( +                request.headers, +                request.content +            ) +        ) diff --git a/netlib/test/http/test_response.py b/netlib/test/http/test_response.py new file mode 100644 index 00000000..14588000 --- /dev/null +++ b/netlib/test/http/test_response.py @@ -0,0 +1,102 @@ +from __future__ import absolute_import, print_function, division + +import six + +from netlib.http import Headers +from netlib.odict import ODict, ODictCaseless +from netlib.tutils import raises, tresp +from .test_message import _test_passthrough_attr, _test_decoded_attr + + +class TestResponseData(object): +    def test_init(self): +        with raises(ValueError if six.PY2 else TypeError): +            tresp(headers="foobar") + +        assert isinstance(tresp(headers=None).headers, Headers) + + 
+class TestResponseCore(object): +    """ +    Tests for builtins and the attributes that are directly proxied from the data structure +    """ +    def test_repr(self): +        response = tresp() +        assert repr(response) == "Response(200 OK, unknown content type, 7B)" +        response.content = None +        assert repr(response) == "Response(200 OK, no content)" + +    def test_status_code(self): +        _test_passthrough_attr(tresp(), "status_code") + +    def test_reason(self): +        _test_decoded_attr(tresp(), "reason") + + +class TestResponseUtils(object): +    """ +    Tests for additional convenience methods. +    """ +    def test_get_cookies_none(self): +        resp = tresp() +        resp.headers = Headers() +        assert not resp.cookies + +    def test_get_cookies_empty(self): +        resp = tresp() +        resp.headers = Headers(set_cookie="") +        assert not resp.cookies + +    def test_get_cookies_simple(self): +        resp = tresp() +        resp.headers = Headers(set_cookie="cookiename=cookievalue") +        result = resp.cookies +        assert len(result) == 1 +        assert "cookiename" in result +        assert result["cookiename"][0] == ["cookievalue", ODict()] + +    def test_get_cookies_with_parameters(self): +        resp = tresp() +        resp.headers = Headers(set_cookie="cookiename=cookievalue;domain=example.com;expires=Wed Oct  21 16:29:41 2015;path=/; HttpOnly") +        result = resp.cookies +        assert len(result) == 1 +        assert "cookiename" in result +        assert result["cookiename"][0][0] == "cookievalue" +        attrs = result["cookiename"][0][1] +        assert len(attrs) == 4 +        assert attrs["domain"] == ["example.com"] +        assert attrs["expires"] == ["Wed Oct  21 16:29:41 2015"] +        assert attrs["path"] == ["/"] +        assert attrs["httponly"] == [None] + +    def test_get_cookies_no_value(self): +        resp = tresp() +        resp.headers = Headers(set_cookie="cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/") +        result = resp.cookies +        assert len(result) == 1 +        assert "cookiename" in result +        assert result["cookiename"][0][0] == "" +        assert len(result["cookiename"][0][1]) == 2 + +    def test_get_cookies_twocookies(self): +        resp = tresp() +        resp.headers = Headers([ +            [b"Set-Cookie", b"cookiename=cookievalue"], +            [b"Set-Cookie", b"othercookie=othervalue"] +        ]) +        result = resp.cookies +        assert len(result) == 2 +        assert "cookiename" in result +        assert result["cookiename"][0] == ["cookievalue", ODict()] +        assert "othercookie" in result +        assert result["othercookie"][0] == ["othervalue", ODict()] + +    def test_set_cookies(self): +        resp = tresp() +        v = resp.cookies +        v.add("foo", ["bar", ODictCaseless()]) +        resp.set_cookies(v) + +        v = resp.cookies +        assert len(v) == 1 +        assert v["foo"] == [["bar", ODictCaseless()]] diff --git a/netlib/test/http/test_status_codes.py b/netlib/test/http/test_status_codes.py new file mode 100644 index 00000000..9fea6b70 --- /dev/null +++ b/netlib/test/http/test_status_codes.py @@ -0,0 +1,6 @@ +from netlib.http import status_codes + + +def test_simple(): +    assert status_codes.IM_A_TEAPOT == 418 +    assert status_codes.RESPONSES[418] == "I'm a teapot" diff --git a/netlib/test/http/test_user_agents.py b/netlib/test/http/test_user_agents.py new file mode 100644 index 00000000..0bf1bba7 --- 
/dev/null +++ b/netlib/test/http/test_user_agents.py @@ -0,0 +1,6 @@ +from netlib.http import user_agents + + +def test_get_shortcut(): +    assert user_agents.get_by_shortcut("c")[0] == "chrome" +    assert not user_agents.get_by_shortcut("_") diff --git a/netlib/test/test_certutils.py b/netlib/test/test_certutils.py new file mode 100644 index 00000000..027dcc93 --- /dev/null +++ b/netlib/test/test_certutils.py @@ -0,0 +1,155 @@ +import os +from netlib import certutils, tutils + +# class TestDNTree: +#     def test_simple(self): +#         d = certutils.DNTree() +#         d.add("foo.com", "foo") +#         d.add("bar.com", "bar") +#         assert d.get("foo.com") == "foo" +#         assert d.get("bar.com") == "bar" +#         assert not d.get("oink.com") +#         assert not d.get("oink") +#         assert not d.get("") +#         assert not d.get("oink.oink") +# +#         d.add("*.match.org", "match") +#         assert not d.get("match.org") +#         assert d.get("foo.match.org") == "match" +#         assert d.get("foo.foo.match.org") == "match" +# +#     def test_wildcard(self): +#         d = certutils.DNTree() +#         d.add("foo.com", "foo") +#         assert not d.get("*.foo.com") +#         d.add("*.foo.com", "wild") +# +#         d = certutils.DNTree() +#         d.add("*", "foo") +#         assert d.get("foo.com") == "foo" +#         assert d.get("*.foo.com") == "foo" +#         assert d.get("com") == "foo" + + +class TestCertStore: + +    def test_create_explicit(self): +        with tutils.tmpdir() as d: +            ca = certutils.CertStore.from_store(d, "test") +            assert ca.get_cert(b"foo", []) + +            ca2 = certutils.CertStore.from_store(d, "test") +            assert ca2.get_cert(b"foo", []) + +            assert ca.default_ca.get_serial_number() == ca2.default_ca.get_serial_number() + +    def test_create_no_common_name(self): +        with tutils.tmpdir() as d: +            ca = certutils.CertStore.from_store(d, "test") +            assert ca.get_cert(None, [])[0].cn is None + +    def test_create_tmp(self): +        with tutils.tmpdir() as d: +            ca = certutils.CertStore.from_store(d, "test") +            assert ca.get_cert(b"foo.com", []) +            assert ca.get_cert(b"foo.com", []) +            assert ca.get_cert(b"*.foo.com", []) + +            r = ca.get_cert(b"*.foo.com", []) +            assert r[1] == ca.default_privatekey + +    def test_sans(self): +        with tutils.tmpdir() as d: +            ca = certutils.CertStore.from_store(d, "test") +            c1 = ca.get_cert(b"foo.com", [b"*.bar.com"]) +            ca.get_cert(b"foo.bar.com", []) +            # assert c1 == c2 +            c3 = ca.get_cert(b"bar.com", []) +            assert not c1 == c3 + +    def test_sans_change(self): +        with tutils.tmpdir() as d: +            ca = certutils.CertStore.from_store(d, "test") +            ca.get_cert(b"foo.com", [b"*.bar.com"]) +            cert, key, chain_file = ca.get_cert(b"foo.bar.com", [b"*.baz.com"]) +            assert b"*.baz.com" in cert.altnames + +    def test_overrides(self): +        with tutils.tmpdir() as d: +            ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test") +            ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test") +            assert not ca1.default_ca.get_serial_number( +            ) == ca2.default_ca.get_serial_number() + +            dc = ca2.get_cert(b"foo.com", [b"sans.example.com"]) +            dcp = os.path.join(d, "dc") +            f = 
open(dcp, "wb") +            f.write(dc[0].to_pem()) +            f.close() +            ca1.add_cert_file(b"foo.com", dcp) + +            ret = ca1.get_cert(b"foo.com", []) +            assert ret[0].serial == dc[0].serial + + +class TestDummyCert: + +    def test_with_ca(self): +        with tutils.tmpdir() as d: +            ca = certutils.CertStore.from_store(d, "test") +            r = certutils.dummy_cert( +                ca.default_privatekey, +                ca.default_ca, +                b"foo.com", +                [b"one.com", b"two.com", b"*.three.com"] +            ) +            assert r.cn == b"foo.com" + +            r = certutils.dummy_cert( +                ca.default_privatekey, +                ca.default_ca, +                None, +                [] +            ) +            assert r.cn is None + + +class TestSSLCert: + +    def test_simple(self): +        with open(tutils.test_data.path("data/text_cert"), "rb") as f: +            d = f.read() +        c1 = certutils.SSLCert.from_pem(d) +        assert c1.cn == b"google.com" +        assert len(c1.altnames) == 436 + +        with open(tutils.test_data.path("data/text_cert_2"), "rb") as f: +            d = f.read() +        c2 = certutils.SSLCert.from_pem(d) +        assert c2.cn == b"www.inode.co.nz" +        assert len(c2.altnames) == 2 +        assert c2.digest("sha1") +        assert c2.notbefore +        assert c2.notafter +        assert c2.subject +        assert c2.keyinfo == ("RSA", 2048) +        assert c2.serial +        assert c2.issuer +        assert c2.to_pem() +        assert c2.has_expired is not None + +        assert not c1 == c2 +        assert c1 != c2 + +    def test_err_broken_sans(self): +        with open(tutils.test_data.path("data/text_cert_weird1"), "rb") as f: +            d = f.read() +        c = certutils.SSLCert.from_pem(d) +        # This breaks unless we ignore a decoding error. +        assert c.altnames is not None + +    def test_der(self): +        with open(tutils.test_data.path("data/dercert"), "rb") as f: +            d = f.read() +        s = certutils.SSLCert.from_der(d) +        assert s.cn diff --git a/netlib/test/test_encoding.py b/netlib/test/test_encoding.py new file mode 100644 index 00000000..0ff1aad1 --- /dev/null +++ b/netlib/test/test_encoding.py @@ -0,0 +1,37 @@ +from netlib import encoding + + +def test_identity(): +    assert b"string" == encoding.decode("identity", b"string") +    assert b"string" == encoding.encode("identity", b"string") +    assert not encoding.encode("nonexistent", b"string") +    assert not encoding.decode("nonexistent encoding", b"string") + + +def test_gzip(): +    assert b"string" == encoding.decode( +        "gzip", +        encoding.encode( +            "gzip", +            b"string" +        ) +    ) +    assert encoding.decode("gzip", b"bogus") is None + + +def test_deflate(): +    assert b"string" == encoding.decode( +        "deflate", +        encoding.encode( +            "deflate", +            b"string" +        ) +    ) +    assert b"string" == encoding.decode( +        "deflate", +        encoding.encode( +            "deflate", +            b"string" +        )[2:-4] +    ) +    assert encoding.decode("deflate", b"bogus") is None diff --git a/netlib/test/test_imports.py b/netlib/test/test_imports.py new file mode 100644 index 00000000..b88ef26d --- /dev/null +++ b/netlib/test/test_imports.py @@ -0,0 +1 @@ +# These are actually tests! 
diff --git a/netlib/test/test_odict.py b/netlib/test/test_odict.py new file mode 100644 index 00000000..f0985ef6 --- /dev/null +++ b/netlib/test/test_odict.py @@ -0,0 +1,153 @@ +from netlib import odict, tutils + + +class TestODict(object): + +    def test_repr(self): +        h = odict.ODict() +        h["one"] = ["two"] +        assert repr(h) + +    def test_str_err(self): +        h = odict.ODict() +        with tutils.raises(ValueError): +            h["key"] = u"foo" +        with tutils.raises(ValueError): +            h["key"] = b"foo" + +    def test_getset_state(self): +        od = odict.ODict() +        od.add("foo", 1) +        od.add("foo", 2) +        od.add("bar", 3) +        state = od.get_state() +        nd = odict.ODict.from_state(state) +        assert nd == od +        b = odict.ODict() +        b.set_state(state) +        assert b == od + +    def test_in_any(self): +        od = odict.ODict() +        od["one"] = ["atwoa", "athreea"] +        assert od.in_any("one", "two") +        assert od.in_any("one", "three") +        assert not od.in_any("one", "four") +        assert not od.in_any("nonexistent", "foo") +        assert not od.in_any("one", "TWO") +        assert od.in_any("one", "TWO", True) + +    def test_iter(self): +        od = odict.ODict() +        assert not [i for i in od] +        od.add("foo", 1) +        assert [i for i in od] + +    def test_keys(self): +        od = odict.ODict() +        assert not od.keys() +        od.add("foo", 1) +        assert od.keys() == ["foo"] +        od.add("foo", 2) +        assert od.keys() == ["foo"] +        od.add("bar", 2) +        assert len(od.keys()) == 2 + +    def test_copy(self): +        od = odict.ODict() +        od.add("foo", 1) +        od.add("foo", 2) +        od.add("bar", 3) +        assert od == od.copy() +        assert not od != od.copy() + +    def test_del(self): +        od = odict.ODict() +        od.add("foo", 1) +        od.add("Foo", 2) +        od.add("bar", 3) +        del od["foo"] +        assert len(od.lst) == 2 + +    def test_replace(self): +        od = odict.ODict() +        od.add("one", "two") +        od.add("two", "one") +        assert od.replace("one", "vun") == 2 +        assert od.lst == [ +            ["vun", "two"], +            ["two", "vun"], +        ] + +    def test_get(self): +        od = odict.ODict() +        od.add("one", "two") +        assert od.get("one") == ["two"] +        assert od.get("two") is None + +    def test_get_first(self): +        od = odict.ODict() +        od.add("one", "two") +        od.add("one", "three") +        assert od.get_first("one") == "two" +        assert od.get_first("two") is None + +    def test_extend(self): +        a = odict.ODict([["a", "b"], ["c", "d"]]) +        b = odict.ODict([["a", "b"], ["e", "f"]]) +        a.extend(b) +        assert len(a) == 4 +        assert a["a"] == ["b", "b"] + + +class TestODictCaseless(object): + +    def test_override(self): +        o = odict.ODictCaseless() +        o.add('T', 'application/x-www-form-urlencoded; charset=UTF-8') +        o["T"] = ["foo"] +        assert o["T"] == ["foo"] + +    def test_case_preservation(self): +        od = odict.ODictCaseless() +        od["Foo"] = ["1"] +        assert "foo" in od +        assert od.items()[0][0] == "Foo" +        assert od.get("foo") == ["1"] +        assert od.get("foo", [""]) == ["1"] +        assert od.get("Foo", [""]) == ["1"] +        assert od.get("xx", "yy") == "yy" + +    def test_del(self): +        od = 
odict.ODictCaseless() +        od.add("foo", 1) +        od.add("Foo", 2) +        od.add("bar", 3) +        del od["foo"] +        assert len(od) == 1 + +    def test_keys(self): +        od = odict.ODictCaseless() +        assert not od.keys() +        od.add("foo", 1) +        assert od.keys() == ["foo"] +        od.add("Foo", 2) +        assert od.keys() == ["foo"] +        od.add("bar", 2) +        assert len(od.keys()) == 2 + +    def test_add_order(self): +        od = odict.ODict( +            [ +                ["one", "uno"], +                ["two", "due"], +                ["three", "tre"], +            ] +        ) +        od["two"] = ["foo", "bar"] +        assert od.lst == [ +            ["one", "uno"], +            ["two", "foo"], +            ["three", "tre"], +            ["two", "bar"], +        ] diff --git a/netlib/test/test_socks.py b/netlib/test/test_socks.py new file mode 100644 index 00000000..d95dee41 --- /dev/null +++ b/netlib/test/test_socks.py @@ -0,0 +1,149 @@ +import ipaddress +from io import BytesIO +import socket +from netlib import socks, tcp, tutils + + +def test_client_greeting(): +    raw = tutils.treader(b"\x05\x02\x00\xBE\xEF") +    out = BytesIO() +    msg = socks.ClientGreeting.from_file(raw) +    msg.assert_socks5() +    msg.to_file(out) + +    assert out.getvalue() == raw.getvalue()[:-1] +    assert msg.ver == 5 +    assert len(msg.methods) == 2 +    assert 0xBE in msg.methods +    assert 0xEF not in msg.methods + + +def test_client_greeting_assert_socks5(): +    raw = tutils.treader(b"\x00\x00") +    msg = socks.ClientGreeting.from_file(raw) +    tutils.raises(socks.SocksError, msg.assert_socks5) + +    raw = tutils.treader(b"HTTP/1.1 200 OK" + b" " * 100) +    msg = socks.ClientGreeting.from_file(raw) +    try: +        msg.assert_socks5() +    except socks.SocksError as e: +        assert "Invalid SOCKS version" in str(e) +        assert "HTTP" not in str(e) +    else: +        assert False + +    raw = tutils.treader(b"GET / HTTP/1.1" + b" " * 100) +    msg = socks.ClientGreeting.from_file(raw) +    try: +        msg.assert_socks5() +    except socks.SocksError as e: +        assert "Invalid SOCKS version" in str(e) +        assert "HTTP" in str(e) +    else: +        assert False + +    raw = tutils.treader(b"XX") +    tutils.raises( +        socks.SocksError, +        socks.ClientGreeting.from_file, +        raw, +        fail_early=True) + + +def test_server_greeting(): +    raw = tutils.treader(b"\x05\x02") +    out = BytesIO() +    msg = socks.ServerGreeting.from_file(raw) +    msg.assert_socks5() +    msg.to_file(out) + +    assert out.getvalue() == raw.getvalue() +    assert msg.ver == 5 +    assert msg.method == 0x02 + + +def test_server_greeting_assert_socks5(): +    raw = tutils.treader(b"HTTP/1.1 200 OK" + b" " * 100) +    msg = socks.ServerGreeting.from_file(raw) +    try: +        msg.assert_socks5() +    except socks.SocksError as e: +        assert "Invalid SOCKS version" in str(e) +        assert "HTTP" in str(e) +    else: +        assert False + +    raw = tutils.treader(b"GET / HTTP/1.1" + b" " * 100) +    msg = socks.ServerGreeting.from_file(raw) +    try: +        msg.assert_socks5() +    except socks.SocksError as e: +        assert "Invalid SOCKS version" in str(e) +        assert "HTTP" not in str(e) +    else: +        assert False + + +def test_message(): +    raw = tutils.treader(b"\x05\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") +    out = BytesIO() +    msg = socks.Message.from_file(raw) +    msg.assert_socks5() 
+    assert raw.read(2) == b"\xBE\xEF" +    msg.to_file(out) + +    assert out.getvalue() == raw.getvalue()[:-2] +    assert msg.ver == 5 +    assert msg.msg == 0x01 +    assert msg.atyp == 0x03 +    assert msg.addr == ("example.com", 0xDEAD) + + +def test_message_assert_socks5(): +    raw = tutils.treader(b"\xEE\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") +    msg = socks.Message.from_file(raw) +    tutils.raises(socks.SocksError, msg.assert_socks5) + + +def test_message_ipv4(): +    # Test ATYP=0x01 (IPV4) +    raw = tutils.treader(b"\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") +    out = BytesIO() +    msg = socks.Message.from_file(raw) +    left = raw.read(2) +    assert left == b"\xBE\xEF" +    msg.to_file(out) + +    assert out.getvalue() == raw.getvalue()[:-2] +    assert msg.addr == ("127.0.0.1", 0xDEAD) + + +def test_message_ipv6(): +    # Test ATYP=0x04 (IPV6) +    ipv6_addr = u"2001:db8:85a3:8d3:1319:8a2e:370:7344" + +    raw = tutils.treader( +        b"\x05\x01\x00\x04" + +        ipaddress.IPv6Address(ipv6_addr).packed + +        b"\xDE\xAD\xBE\xEF") +    out = BytesIO() +    msg = socks.Message.from_file(raw) +    assert raw.read(2) == b"\xBE\xEF" +    msg.to_file(out) + +    assert out.getvalue() == raw.getvalue()[:-2] +    assert msg.addr.host == ipv6_addr + + +def test_message_invalid_rsv(): +    raw = tutils.treader(b"\x05\x01\xFF\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") +    tutils.raises(socks.SocksError, socks.Message.from_file, raw) + + +def test_message_unknown_atyp(): +    raw = tutils.treader(b"\x05\x02\x00\x02\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") +    tutils.raises(socks.SocksError, socks.Message.from_file, raw) + +    m = socks.Message(5, 1, 0x02, tcp.Address(("example.com", 5050))) +    tutils.raises(socks.SocksError, m.to_file, BytesIO()) diff --git a/netlib/test/test_tcp.py b/netlib/test/test_tcp.py new file mode 100644 index 00000000..2b091ef0 --- /dev/null +++ b/netlib/test/test_tcp.py @@ -0,0 +1,795 @@ +from io import BytesIO +from six.moves import queue +import time +import socket +import random +import os +import threading +import mock + +from OpenSSL import SSL +import OpenSSL + +from netlib import tcp, certutils, tutils, tservers +from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \ +    TcpTimeout, TcpDisconnect, TcpException, NetlibException + + +class EchoHandler(tcp.BaseHandler): +    sni = None + +    def handle_sni(self, connection): +        self.sni = connection.get_servername() + +    def handle(self): +        v = self.rfile.readline() +        self.wfile.write(v) +        self.wfile.flush() + + +class ClientCipherListHandler(tcp.BaseHandler): +    sni = None + +    def handle(self): +        self.wfile.write("%s" % self.connection.get_cipher_list()) +        self.wfile.flush() + + +class HangHandler(tcp.BaseHandler): + +    def handle(self): +        while True: +            time.sleep(1) + + +class ALPNHandler(tcp.BaseHandler): +    sni = None + +    def handle(self): +        alp = self.get_alpn_proto_negotiated() +        if alp: +            self.wfile.write(alp) +        else: +            self.wfile.write(b"NONE") +        self.wfile.flush() + + +class TestServer(tservers.ServerTestBase): +    handler = EchoHandler + +    def test_echo(self): +        testval = b"echo!\n" +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.wfile.write(testval) +        c.wfile.flush() +        assert c.rfile.readline() == testval + +    def test_thread_start_error(self): +       
 with mock.patch.object(threading.Thread, "start", side_effect=threading.ThreadError("nonewthread")) as m: +            c = tcp.TCPClient(("127.0.0.1", self.port)) +            c.connect() +            assert not c.rfile.read(1) +            assert m.called +            assert "nonewthread" in self.q.get_nowait() +        self.test_echo() + + +class TestServerBind(tservers.ServerTestBase): + +    class handler(tcp.BaseHandler): + +        def handle(self): +            self.wfile.write(str(self.connection.getpeername()).encode()) +            self.wfile.flush() + +    def test_bind(self): +        """ Test to bind to a given random port. Try again if the random port turned out to be blocked. """ +        for i in range(20): +            random_port = random.randrange(1024, 65535) +            try: +                c = tcp.TCPClient( +                    ("127.0.0.1", self.port), source_address=( +                        "127.0.0.1", random_port)) +                c.connect() +                assert c.rfile.readline() == str(("127.0.0.1", random_port)).encode() +                return +            except TcpException:  # port probably already in use +                pass + + +class TestServerIPv6(tservers.ServerTestBase): +    handler = EchoHandler +    addr = tcp.Address(("localhost", 0), use_ipv6=True) + +    def test_echo(self): +        testval = b"echo!\n" +        c = tcp.TCPClient(tcp.Address(("::1", self.port), use_ipv6=True)) +        c.connect() +        c.wfile.write(testval) +        c.wfile.flush() +        assert c.rfile.readline() == testval + + +class TestEcho(tservers.ServerTestBase): +    handler = EchoHandler + +    def test_echo(self): +        testval = b"echo!\n" +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.wfile.write(testval) +        c.wfile.flush() +        assert c.rfile.readline() == testval + + +class HardDisconnectHandler(tcp.BaseHandler): + +    def handle(self): +        self.connection.close() + + +class TestFinishFail(tservers.ServerTestBase): + +    """ +        This tests a difficult-to-trigger exception in the .finish() method of +        the handler. 
+    """ +    handler = EchoHandler + +    def test_disconnect_in_finish(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.wfile.write(b"foo\n") +        c.wfile.flush = mock.Mock(side_effect=TcpDisconnect) +        c.finish() + + +class TestServerSSL(tservers.ServerTestBase): +    handler = EchoHandler +    ssl = dict( +        cipher_list="AES256-SHA", +        chain_file=tutils.test_data.path("data/server.crt") +    ) + +    def test_echo(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.convert_to_ssl(sni=b"foo.com", options=SSL.OP_ALL) +        testval = b"echo!\n" +        c.wfile.write(testval) +        c.wfile.flush() +        assert c.rfile.readline() == testval + +    def test_get_current_cipher(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        assert not c.get_current_cipher() +        c.convert_to_ssl(sni=b"foo.com") +        ret = c.get_current_cipher() +        assert ret +        assert "AES" in ret[0] + + +class TestSSLv3Only(tservers.ServerTestBase): +    handler = EchoHandler +    ssl = dict( +        request_client_cert=False, +        v3_only=True +    ) + +    def test_failure(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        tutils.raises(TlsException, c.convert_to_ssl, sni=b"foo.com") + + +class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): +    handler = EchoHandler + +    ssl = dict( +        cert=tutils.test_data.path("data/verificationcerts/self-signed.crt"), +        key=tutils.test_data.path("data/verificationcerts/self-signed.key") +    ) + +    def test_mode_default_should_pass(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() + +        c.convert_to_ssl() + +        # Verification errors should be saved even if connection isn't aborted +        # aborted +        assert c.ssl_verification_error is not None + +        testval = b"echo!\n" +        c.wfile.write(testval) +        c.wfile.flush() +        assert c.rfile.readline() == testval + +    def test_mode_none_should_pass(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() + +        c.convert_to_ssl(verify_options=SSL.VERIFY_NONE) + +        # Verification errors should be saved even if connection isn't aborted +        assert c.ssl_verification_error is not None + +        testval = b"echo!\n" +        c.wfile.write(testval) +        c.wfile.flush() +        assert c.rfile.readline() == testval + +    def test_mode_strict_should_fail(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() + +        with tutils.raises(InvalidCertificateException): +            c.convert_to_ssl( +                sni=b"example.mitmproxy.org", +                verify_options=SSL.VERIFY_PEER, +                ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") +            ) + +        assert c.ssl_verification_error is not None + +        # Unknown issuing certificate authority for first certificate +        assert c.ssl_verification_error['errno'] == 18 +        assert c.ssl_verification_error['depth'] == 0 + + +class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase): +    handler = EchoHandler + +    ssl = dict( +        cert=tutils.test_data.path("data/verificationcerts/trusted-leaf.crt"), +        key=tutils.test_data.path("data/verificationcerts/trusted-leaf.key") +    ) + +    def 
test_should_fail_without_sni(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() + +        with tutils.raises(TlsException): +            c.convert_to_ssl( +                verify_options=SSL.VERIFY_PEER, +                ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") +            ) + +    def test_should_fail(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() + +        with tutils.raises(InvalidCertificateException): +            c.convert_to_ssl( +                sni=b"mitmproxy.org", +                verify_options=SSL.VERIFY_PEER, +                ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") +            ) + +        assert c.ssl_verification_error is not None + + +class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase): +    handler = EchoHandler + +    ssl = dict( +        cert=tutils.test_data.path("data/verificationcerts/trusted-leaf.crt"), +        key=tutils.test_data.path("data/verificationcerts/trusted-leaf.key") +    ) + +    def test_mode_strict_w_pemfile_should_pass(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() + +        c.convert_to_ssl( +            sni=b"example.mitmproxy.org", +            verify_options=SSL.VERIFY_PEER, +            ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") +        ) + +        assert c.ssl_verification_error is None + +        testval = b"echo!\n" +        c.wfile.write(testval) +        c.wfile.flush() +        assert c.rfile.readline() == testval + +    def test_mode_strict_w_cadir_should_pass(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() + +        c.convert_to_ssl( +            sni=b"example.mitmproxy.org", +            verify_options=SSL.VERIFY_PEER, +            ca_path=tutils.test_data.path("data/verificationcerts/") +        ) + +        assert c.ssl_verification_error is None + +        testval = b"echo!\n" +        c.wfile.write(testval) +        c.wfile.flush() +        assert c.rfile.readline() == testval + + +class TestSSLClientCert(tservers.ServerTestBase): + +    class handler(tcp.BaseHandler): +        sni = None + +        def handle_sni(self, connection): +            self.sni = connection.get_servername() + +        def handle(self): +            self.wfile.write(b"%d\n" % self.clientcert.serial) +            self.wfile.flush() + +    ssl = dict( +        request_client_cert=True, +        v3_only=False +    ) + +    def test_clientcert(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.convert_to_ssl( +            cert=tutils.test_data.path("data/clientcert/client.pem")) +        assert c.rfile.readline().strip() == b"1" + +    def test_clientcert_err(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        tutils.raises( +            TlsException, +            c.convert_to_ssl, +            cert=tutils.test_data.path("data/clientcert/make") +        ) + + +class TestSNI(tservers.ServerTestBase): + +    class handler(tcp.BaseHandler): +        sni = None + +        def handle_sni(self, connection): +            self.sni = connection.get_servername() + +        def handle(self): +            self.wfile.write(self.sni) +            self.wfile.flush() + +    ssl = True + +    def test_echo(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.convert_to_ssl(sni=b"foo.com") +        
assert c.sni == b"foo.com" +        assert c.rfile.readline() == b"foo.com" + + +class TestServerCipherList(tservers.ServerTestBase): +    handler = ClientCipherListHandler +    ssl = dict( +        cipher_list='RC4-SHA' +    ) + +    def test_echo(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.convert_to_ssl(sni=b"foo.com") +        assert c.rfile.readline() == b"['RC4-SHA']" + + +class TestServerCurrentCipher(tservers.ServerTestBase): + +    class handler(tcp.BaseHandler): +        sni = None + +        def handle(self): +            self.wfile.write(str(self.get_current_cipher()).encode()) +            self.wfile.flush() + +    ssl = dict( +        cipher_list='RC4-SHA' +    ) + +    def test_echo(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.convert_to_ssl(sni=b"foo.com") +        assert b"RC4-SHA" in c.rfile.readline() + + +class TestServerCipherListError(tservers.ServerTestBase): +    handler = ClientCipherListHandler +    ssl = dict( +        cipher_list='bogus' +    ) + +    def test_echo(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        tutils.raises("handshake error", c.convert_to_ssl, sni=b"foo.com") + + +class TestClientCipherListError(tservers.ServerTestBase): +    handler = ClientCipherListHandler +    ssl = dict( +        cipher_list='RC4-SHA' +    ) + +    def test_echo(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        tutils.raises( +            "cipher specification", +            c.convert_to_ssl, +            sni=b"foo.com", +            cipher_list="bogus") + + +class TestSSLDisconnect(tservers.ServerTestBase): + +    class handler(tcp.BaseHandler): + +        def handle(self): +            self.finish() + +    ssl = True + +    def test_echo(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.convert_to_ssl() +        # Excercise SSL.ZeroReturnError +        c.rfile.read(10) +        c.close() +        tutils.raises(TcpDisconnect, c.wfile.write, b"foo") +        tutils.raises(queue.Empty, self.q.get_nowait) + + +class TestSSLHardDisconnect(tservers.ServerTestBase): +    handler = HardDisconnectHandler +    ssl = True + +    def test_echo(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.convert_to_ssl() +        # Exercise SSL.SysCallError +        c.rfile.read(10) +        c.close() +        tutils.raises(TcpDisconnect, c.wfile.write, b"foo") + + +class TestDisconnect(tservers.ServerTestBase): + +    def test_echo(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.rfile.read(10) +        c.wfile.write(b"foo") +        c.close() +        c.close() + + +class TestServerTimeOut(tservers.ServerTestBase): + +    class handler(tcp.BaseHandler): + +        def handle(self): +            self.timeout = False +            self.settimeout(0.01) +            try: +                self.rfile.read(10) +            except TcpTimeout: +                self.timeout = True + +    def test_timeout(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        time.sleep(0.3) +        assert self.last_handler.timeout + + +class TestTimeOut(tservers.ServerTestBase): +    handler = HangHandler + +    def test_timeout(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.settimeout(0.1) +        assert c.gettimeout() == 0.1 +        
tutils.raises(TcpTimeout, c.rfile.read, 10) + + +class TestALPNClient(tservers.ServerTestBase): +    handler = ALPNHandler +    ssl = dict( +        alpn_select=b"bar" +    ) + +    if OpenSSL._util.lib.Cryptography_HAS_ALPN: +        def test_alpn(self): +            c = tcp.TCPClient(("127.0.0.1", self.port)) +            c.connect() +            c.convert_to_ssl(alpn_protos=[b"foo", b"bar", b"fasel"]) +            assert c.get_alpn_proto_negotiated() == b"bar" +            assert c.rfile.readline().strip() == b"bar" + +        def test_no_alpn(self): +            c = tcp.TCPClient(("127.0.0.1", self.port)) +            c.connect() +            c.convert_to_ssl() +            assert c.get_alpn_proto_negotiated() == b"" +            assert c.rfile.readline().strip() == b"NONE" + +    else: +        def test_none_alpn(self): +            c = tcp.TCPClient(("127.0.0.1", self.port)) +            c.connect() +            c.convert_to_ssl(alpn_protos=[b"foo", b"bar", b"fasel"]) +            assert c.get_alpn_proto_negotiated() == b"" +            assert c.rfile.readline() == b"NONE" + + +class TestNoSSLNoALPNClient(tservers.ServerTestBase): +    handler = ALPNHandler + +    def test_no_ssl_no_alpn(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        assert c.get_alpn_proto_negotiated() == b"" +        assert c.rfile.readline().strip() == b"NONE" + + +class TestSSLTimeOut(tservers.ServerTestBase): +    handler = HangHandler +    ssl = True + +    def test_timeout_client(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.convert_to_ssl() +        c.settimeout(0.1) +        tutils.raises(TcpTimeout, c.rfile.read, 10) + + +class TestDHParams(tservers.ServerTestBase): +    handler = HangHandler +    ssl = dict( +        dhparams=certutils.CertStore.load_dhparam( +            tutils.test_data.path("data/dhparam.pem"), +        ), +        cipher_list="DHE-RSA-AES256-SHA" +    ) + +    def test_dhparams(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.convert_to_ssl() +        ret = c.get_current_cipher() +        assert ret[0] == "DHE-RSA-AES256-SHA" + +    def test_create_dhparams(self): +        with tutils.tmpdir() as d: +            filename = os.path.join(d, "dhparam.pem") +            certutils.CertStore.load_dhparam(filename) +            assert os.path.exists(filename) + + +class TestTCPClient: + +    def test_conerr(self): +        c = tcp.TCPClient(("127.0.0.1", 0)) +        tutils.raises(TcpException, c.connect) + + +class TestFileLike: + +    def test_blocksize(self): +        s = BytesIO(b"1234567890abcdefghijklmnopqrstuvwxyz") +        s = tcp.Reader(s) +        s.BLOCKSIZE = 2 +        assert s.read(1) == b"1" +        assert s.read(2) == b"23" +        assert s.read(3) == b"456" +        assert s.read(4) == b"7890" +        d = s.read(-1) +        assert d.startswith(b"abc") and d.endswith(b"xyz") + +    def test_wrap(self): +        s = BytesIO(b"foobar\nfoobar") +        s.flush() +        s = tcp.Reader(s) +        assert s.readline() == b"foobar\n" +        assert s.readline() == b"foobar" +        # Test __getattr__ +        assert s.isatty + +    def test_limit(self): +        s = BytesIO(b"foobar\nfoobar") +        s = tcp.Reader(s) +        assert s.readline(3) == b"foo" + +    def test_limitless(self): +        s = BytesIO(b"f" * (50 * 1024)) +        s = tcp.Reader(s) +        ret = s.read(-1) +        assert len(ret) == 50 * 1024 + +    def 
test_readlog(self): +        s = BytesIO(b"foobar\nfoobar") +        s = tcp.Reader(s) +        assert not s.is_logging() +        s.start_log() +        assert s.is_logging() +        s.readline() +        assert s.get_log() == b"foobar\n" +        s.read(1) +        assert s.get_log() == b"foobar\nf" +        s.start_log() +        assert s.get_log() == b"" +        s.read(1) +        assert s.get_log() == b"o" +        s.stop_log() +        tutils.raises(ValueError, s.get_log) + +    def test_writelog(self): +        s = BytesIO() +        s = tcp.Writer(s) +        s.start_log() +        assert s.is_logging() +        s.write(b"x") +        assert s.get_log() == b"x" +        s.write(b"x") +        assert s.get_log() == b"xx" + +    def test_writer_flush_error(self): +        s = BytesIO() +        s = tcp.Writer(s) +        o = mock.MagicMock() +        o.flush = mock.MagicMock(side_effect=socket.error) +        s.o = o +        tutils.raises(TcpDisconnect, s.flush) + +    def test_reader_read_error(self): +        s = BytesIO(b"foobar\nfoobar") +        s = tcp.Reader(s) +        o = mock.MagicMock() +        o.read = mock.MagicMock(side_effect=socket.error) +        s.o = o +        tutils.raises(TcpDisconnect, s.read, 10) + +    def test_reset_timestamps(self): +        s = BytesIO(b"foobar\nfoobar") +        s = tcp.Reader(s) +        s.first_byte_timestamp = 500 +        s.reset_timestamps() +        assert not s.first_byte_timestamp + +    def test_first_byte_timestamp_updated_on_read(self): +        s = BytesIO(b"foobar\nfoobar") +        s = tcp.Reader(s) +        s.read(1) +        assert s.first_byte_timestamp +        expected = s.first_byte_timestamp +        s.read(5) +        assert s.first_byte_timestamp == expected + +    def test_first_byte_timestamp_updated_on_readline(self): +        s = BytesIO(b"foobar\nfoobar\nfoobar") +        s = tcp.Reader(s) +        s.readline() +        assert s.first_byte_timestamp +        expected = s.first_byte_timestamp +        s.readline() +        assert s.first_byte_timestamp == expected + +    def test_read_ssl_error(self): +        s = mock.MagicMock() +        s.read = mock.MagicMock(side_effect=SSL.Error()) +        s = tcp.Reader(s) +        tutils.raises(TlsException, s.read, 1) + +    def test_read_syscall_ssl_error(self): +        s = mock.MagicMock() +        s.read = mock.MagicMock(side_effect=SSL.SysCallError()) +        s = tcp.Reader(s) +        tutils.raises(TlsException, s.read, 1) + +    def test_reader_readline_disconnect(self): +        o = mock.MagicMock() +        o.read = mock.MagicMock(side_effect=socket.error) +        s = tcp.Reader(o) +        tutils.raises(TcpDisconnect, s.readline, 10) + +    def test_reader_incomplete_error(self): +        s = BytesIO(b"foobar") +        s = tcp.Reader(s) +        tutils.raises(TcpReadIncomplete, s.safe_read, 10) + + +class TestPeek(tservers.ServerTestBase): +    handler = EchoHandler + +    def _connect(self, c): +        c.connect() + +    def test_peek(self): +        testval = b"peek!\n" +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        self._connect(c) +        c.wfile.write(testval) +        c.wfile.flush() + +        assert c.rfile.peek(4) == b"peek" +        assert c.rfile.peek(6) == b"peek!\n" +        assert c.rfile.readline() == testval + +        c.close() +        with tutils.raises(NetlibException): +            if c.rfile.peek(1) == b"": +                # Workaround for Python 2 on Unix: +                # Peeking a closed connection does not 
raise an exception here. +                raise NetlibException() + + +class TestPeekSSL(TestPeek): +    ssl = True + +    def _connect(self, c): +        c.connect() +        c.convert_to_ssl() + + +class TestAddress: + +    def test_simple(self): +        a = tcp.Address("localhost", True) +        assert a.use_ipv6 +        b = tcp.Address("foo.com", True) +        assert not a == b +        assert str(b) == str(tuple("foo.com")) +        c = tcp.Address("localhost", True) +        assert a == c +        assert not a != c +        assert repr(a) + + +class TestSSLKeyLogger(tservers.ServerTestBase): +    handler = EchoHandler +    ssl = dict( +        cipher_list="AES256-SHA" +    ) + +    def test_log(self): +        testval = b"echo!\n" +        _logfun = tcp.log_ssl_key + +        with tutils.tmpdir() as d: +            logfile = os.path.join(d, "foo", "bar", "logfile") +            tcp.log_ssl_key = tcp.SSLKeyLogger(logfile) + +            c = tcp.TCPClient(("127.0.0.1", self.port)) +            c.connect() +            c.convert_to_ssl() +            c.wfile.write(testval) +            c.wfile.flush() +            assert c.rfile.readline() == testval +            c.finish() + +            tcp.log_ssl_key.close() +            with open(logfile, "rb") as f: +                assert f.read().count(b"CLIENT_RANDOM") == 2 + +        tcp.log_ssl_key = _logfun + +    def test_create_logfun(self): +        assert isinstance( +            tcp.SSLKeyLogger.create_logfun("test"), +            tcp.SSLKeyLogger) +        assert not tcp.SSLKeyLogger.create_logfun(False) diff --git a/netlib/test/test_utils.py b/netlib/test/test_utils.py new file mode 100644 index 00000000..b096e5bc --- /dev/null +++ b/netlib/test/test_utils.py @@ -0,0 +1,141 @@ +from netlib import utils, tutils +from netlib.http import Headers + +def test_bidi(): +    b = utils.BiDi(a=1, b=2) +    assert b.a == 1 +    assert b.get_name(1) == "a" +    assert b.get_name(5) is None +    tutils.raises(AttributeError, getattr, b, "c") +    tutils.raises(ValueError, utils.BiDi, one=1, two=1) + + +def test_hexdump(): +    assert list(utils.hexdump(b"one\0" * 10)) + + +def test_clean_bin(): +    assert utils.clean_bin(b"one") == b"one" +    assert utils.clean_bin(b"\00ne") == b".ne" +    assert utils.clean_bin(b"\nne") == b"\nne" +    assert utils.clean_bin(b"\nne", False) == b".ne" +    assert utils.clean_bin(u"\u2605".encode("utf8")) == b"..." 
+ +    assert utils.clean_bin(u"one") == u"one" +    assert utils.clean_bin(u"\00ne") == u".ne" +    assert utils.clean_bin(u"\nne") == u"\nne" +    assert utils.clean_bin(u"\nne", False) == u".ne" +    assert utils.clean_bin(u"\u2605") == u"\u2605" + + +def test_pretty_size(): +    assert utils.pretty_size(100) == "100B" +    assert utils.pretty_size(1024) == "1kB" +    assert utils.pretty_size(1024 + (1024 / 2.0)) == "1.5kB" +    assert utils.pretty_size(1024 * 1024) == "1MB" + + +def test_parse_url(): +    with tutils.raises(ValueError): +        utils.parse_url("") + +    s, h, po, pa = utils.parse_url(b"http://foo.com:8888/test") +    assert s == b"http" +    assert h == b"foo.com" +    assert po == 8888 +    assert pa == b"/test" + +    s, h, po, pa = utils.parse_url("http://foo/bar") +    assert s == b"http" +    assert h == b"foo" +    assert po == 80 +    assert pa == b"/bar" + +    s, h, po, pa = utils.parse_url(b"http://user:pass@foo/bar") +    assert s == b"http" +    assert h == b"foo" +    assert po == 80 +    assert pa == b"/bar" + +    s, h, po, pa = utils.parse_url(b"http://foo") +    assert pa == b"/" + +    s, h, po, pa = utils.parse_url(b"https://foo") +    assert po == 443 + +    with tutils.raises(ValueError): +        utils.parse_url(b"https://foo:bar") + +    # Invalid IDNA +    with tutils.raises(ValueError): +        utils.parse_url("http://\xfafoo") +    # Invalid PATH +    with tutils.raises(ValueError): +        utils.parse_url("http:/\xc6/localhost:56121") +    # Null byte in host +    with tutils.raises(ValueError): +        utils.parse_url("http://foo\0") +    # Port out of range +    _, _, port, _ = utils.parse_url("http://foo:999999") +    assert port == 80 +    # Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt +    with tutils.raises(ValueError): +        utils.parse_url('http://lo[calhost') + + +def test_unparse_url(): +    assert utils.unparse_url("http", "foo.com", 99, "") == "http://foo.com:99" +    assert utils.unparse_url("http", "foo.com", 80, "/bar") == "http://foo.com/bar" +    assert utils.unparse_url("https", "foo.com", 80, "") == "https://foo.com:80" +    assert utils.unparse_url("https", "foo.com", 443, "") == "https://foo.com" + + +def test_urlencode(): +    assert utils.urlencode([('foo', 'bar')]) + + +def test_urldecode(): +    s = "one=two&three=four" +    assert len(utils.urldecode(s)) == 2 + + +def test_get_header_tokens(): +    headers = Headers() +    assert utils.get_header_tokens(headers, "foo") == [] +    headers["foo"] = "bar" +    assert utils.get_header_tokens(headers, "foo") == ["bar"] +    headers["foo"] = "bar, voing" +    assert utils.get_header_tokens(headers, "foo") == ["bar", "voing"] +    headers.set_all("foo", ["bar, voing", "oink"]) +    assert utils.get_header_tokens(headers, "foo") == ["bar", "voing", "oink"] + + +def test_multipartdecode(): +    boundary = 'somefancyboundary' +    headers = Headers( +        content_type='multipart/form-data; boundary=' + boundary +    ) +    content = ( +        "--{0}\n" +        "Content-Disposition: form-data; name=\"field1\"\n\n" +        "value1\n" +        "--{0}\n" +        "Content-Disposition: form-data; name=\"field2\"\n\n" +        "value2\n" +        "--{0}--".format(boundary).encode() +    ) + +    form = utils.multipartdecode(headers, content) + +    assert len(form) == 2 +    assert form[0] == (b"field1", b"value1") +    assert form[1] == (b"field2", b"value2") + + +def test_parse_content_type(): +    p = utils.parse_content_type +    assert p("text/html") 
== ("text", "html", {}) +    assert p("text") is None + +    v = p("text/html; charset=UTF-8") +    assert v == ('text', 'html', {'charset': 'UTF-8'}) diff --git a/netlib/test/test_version_check.py b/netlib/test/test_version_check.py new file mode 100644 index 00000000..ec2396fe --- /dev/null +++ b/netlib/test/test_version_check.py @@ -0,0 +1,38 @@ +from io import StringIO +import mock +from netlib import version_check, version + + +@mock.patch("sys.exit") +def test_check_mitmproxy_version(sexit): +    fp = StringIO() +    version_check.check_mitmproxy_version(version.IVERSION, fp=fp) +    assert not fp.getvalue() +    assert not sexit.called + +    b = (version.IVERSION[0] - 1, version.IVERSION[1]) +    version_check.check_mitmproxy_version(b, fp=fp) +    assert fp.getvalue() +    assert sexit.called + + +@mock.patch("sys.exit") +def test_check_pyopenssl_version(sexit): +    fp = StringIO() +    version_check.check_pyopenssl_version(fp=fp) +    assert not fp.getvalue() +    assert not sexit.called + +    version_check.check_pyopenssl_version((9999,), fp=fp) +    assert "outdated" in fp.getvalue() +    assert sexit.called + + +@mock.patch("sys.exit") +@mock.patch("OpenSSL.__version__") +def test_unparseable_pyopenssl_version(version, sexit): +    version.split.return_value = ["foo", "bar"] +    fp = StringIO() +    version_check.check_pyopenssl_version(fp=fp) +    assert "Cannot parse" in fp.getvalue() +    assert not sexit.called diff --git a/netlib/test/test_wsgi.py b/netlib/test/test_wsgi.py new file mode 100644 index 00000000..8c782b27 --- /dev/null +++ b/netlib/test/test_wsgi.py @@ -0,0 +1,106 @@ +from io import BytesIO +import sys +from netlib import wsgi +from netlib.http import Headers + + +def tflow(): +    headers = Headers(test=b"value") +    req = wsgi.Request("http", "GET", "/", "HTTP/1.1", headers, "") +    return wsgi.Flow(("127.0.0.1", 8888), req) + + +class ExampleApp: +     +    def __init__(self): +        self.called = False + +    def __call__(self, environ, start_response): +        self.called = True +        status = '200 OK' +        response_headers = [('Content-type', 'text/plain')] +        start_response(status, response_headers) +        return [b'Hello', b' world!\n'] + + +class TestWSGI: + +    def test_make_environ(self): +        w = wsgi.WSGIAdaptor(None, "foo", 80, "version") +        tf = tflow() +        assert w.make_environ(tf, None) + +        tf.request.path = "/foo?bar=voing" +        r = w.make_environ(tf, None) +        assert r["QUERY_STRING"] == "bar=voing" + +    def test_serve(self): +        ta = ExampleApp() +        w = wsgi.WSGIAdaptor(ta, "foo", 80, "version") +        f = tflow() +        f.request.host = "foo" +        f.request.port = 80 + +        wfile = BytesIO() +        err = w.serve(f, wfile) +        assert ta.called +        assert not err + +        val = wfile.getvalue() +        assert b"Hello world" in val +        assert b"Server:" in val + +    def _serve(self, app): +        w = wsgi.WSGIAdaptor(app, "foo", 80, "version") +        f = tflow() +        f.request.host = "foo" +        f.request.port = 80 +        wfile = BytesIO() +        w.serve(f, wfile) +        return wfile.getvalue() + +    def test_serve_empty_body(self): +        def app(environ, start_response): +            status = '200 OK' +            response_headers = [('Foo', 'bar')] +            start_response(status, response_headers) +            return [] +        assert self._serve(app) + +    def test_serve_double_start(self): +        def 
app(environ, start_response): +            try: +                raise ValueError("foo") +            except: +                sys.exc_info() +            status = '200 OK' +            response_headers = [('Content-type', 'text/plain')] +            start_response(status, response_headers) +            start_response(status, response_headers) +        assert b"Internal Server Error" in self._serve(app) + +    def test_serve_single_err(self): +        def app(environ, start_response): +            try: +                raise ValueError("foo") +            except: +                ei = sys.exc_info() +            status = '200 OK' +            response_headers = [('Content-type', 'text/plain')] +            start_response(status, response_headers, ei) +            yield b"" +        assert b"Internal Server Error" in self._serve(app) + +    def test_serve_double_err(self): +        def app(environ, start_response): +            try: +                raise ValueError("foo") +            except: +                ei = sys.exc_info() +            status = '200 OK' +            response_headers = [('Content-type', 'text/plain')] +            start_response(status, response_headers) +            yield b"aaa" +            start_response(status, response_headers, ei) +            yield b"bbb" +        assert b"Internal Server Error" in self._serve(app) diff --git a/netlib/test/tools/getcertnames b/netlib/test/tools/getcertnames new file mode 100644 index 00000000..e33619f7 --- /dev/null +++ b/netlib/test/tools/getcertnames @@ -0,0 +1,27 @@ +#!/usr/bin/env python +import sys +sys.path.insert(0, "../../") +from netlib import tcp + + +def get_remote_cert(host, port, sni): +    c = tcp.TCPClient((host, port)) +    c.connect() +    c.convert_to_ssl(sni=sni) +    return c.cert + +if len(sys.argv) > 2: +    port = int(sys.argv[2]) +else: +    port = 443 +if len(sys.argv) > 3: +    sni = sys.argv[3] +else: +    sni = None + +cert = get_remote_cert(sys.argv[1], port, sni) +print "CN:", cert.cn +if cert.altnames: +    print "SANs:", +    for i in cert.altnames: +        print "\t", i diff --git a/netlib/test/websockets/__init__.py b/netlib/test/websockets/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/netlib/test/websockets/__init__.py diff --git a/netlib/test/websockets/test_websockets.py b/netlib/test/websockets/test_websockets.py new file mode 100644 index 00000000..d53f0d83 --- /dev/null +++ b/netlib/test/websockets/test_websockets.py @@ -0,0 +1,266 @@ +import os + +from netlib.http.http1 import read_response, read_request + +from netlib import tcp, websockets, http, tutils, tservers +from netlib.http import status_codes +from netlib.tutils import treq + +from netlib.exceptions import * + + +class WebSocketsEchoHandler(tcp.BaseHandler): + +    def __init__(self, connection, address, server): +        super(WebSocketsEchoHandler, self).__init__( +            connection, address, server +        ) +        self.protocol = websockets.WebsocketsProtocol() +        self.handshake_done = False + +    def handle(self): +        while True: +            if not self.handshake_done: +                self.handshake() +            else: +                self.read_next_message() + +    def read_next_message(self): +        frame = websockets.Frame.from_file(self.rfile) +        self.on_message(frame.payload) + +    def send_message(self, message): +        frame = websockets.Frame.default(message, from_client=False) +        frame.to_file(self.wfile) + +    def handshake(self): + +      
  req = read_request(self.rfile) +        key = self.protocol.check_client_handshake(req.headers) + +        preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) +        self.wfile.write(preamble.encode() + b"\r\n") +        headers = self.protocol.server_handshake_headers(key) +        self.wfile.write(str(headers) + "\r\n") +        self.wfile.flush() +        self.handshake_done = True + +    def on_message(self, message): +        if message is not None: +            self.send_message(message) + + +class WebSocketsClient(tcp.TCPClient): + +    def __init__(self, address, source_address=None): +        super(WebSocketsClient, self).__init__(address, source_address) +        self.protocol = websockets.WebsocketsProtocol() +        self.client_nonce = None + +    def connect(self): +        super(WebSocketsClient, self).connect() + +        preamble = b'GET / HTTP/1.1' +        self.wfile.write(preamble + b"\r\n") +        headers = self.protocol.client_handshake_headers() +        self.client_nonce = headers["sec-websocket-key"].encode("ascii") +        self.wfile.write(bytes(headers) + b"\r\n") +        self.wfile.flush() + +        resp = read_response(self.rfile, treq(method=b"GET")) +        server_nonce = self.protocol.check_server_handshake(resp.headers) + +        if not server_nonce == self.protocol.create_server_nonce(self.client_nonce): +            self.close() + +    def read_next_message(self): +        return websockets.Frame.from_file(self.rfile).payload + +    def send_message(self, message): +        frame = websockets.Frame.default(message, from_client=True) +        frame.to_file(self.wfile) + + +class TestWebSockets(tservers.ServerTestBase): +    handler = WebSocketsEchoHandler + +    def __init__(self): +        self.protocol = websockets.WebsocketsProtocol() + +    def random_bytes(self, n=100): +        return os.urandom(n) + +    def echo(self, msg): +        client = WebSocketsClient(("127.0.0.1", self.port)) +        client.connect() +        client.send_message(msg) +        response = client.read_next_message() +        assert response == msg + +    def test_simple_echo(self): +        self.echo(b"hello I'm the client") + +    def test_frame_sizes(self): +        # length can fit in the the 7 bit payload length +        small_msg = self.random_bytes(100) +        # 50kb, sligthly larger than can fit in a 7 bit int +        medium_msg = self.random_bytes(50000) +        # 150kb, slightly larger than can fit in a 16 bit int +        large_msg = self.random_bytes(150000) + +        self.echo(small_msg) +        self.echo(medium_msg) +        self.echo(large_msg) + +    def test_default_builder(self): +        """ +          default builder should always generate valid frames +        """ +        msg = self.random_bytes() +        client_frame = websockets.Frame.default(msg, from_client=True) +        server_frame = websockets.Frame.default(msg, from_client=False) + +    def test_serialization_bijection(self): +        """ +          Ensure that various frame types can be serialized/deserialized back +          and forth between to_bytes() and from_bytes() +        """ +        for is_client in [True, False]: +            for num_bytes in [100, 50000, 150000]: +                frame = websockets.Frame.default( +                    self.random_bytes(num_bytes), is_client +                ) +                frame2 = websockets.Frame.from_bytes( +                    frame.to_bytes() +                ) +                assert frame == frame2 + +        
bytes = b'\x81\x03cba' +        assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes + +    def test_check_server_handshake(self): +        headers = self.protocol.server_handshake_headers("key") +        assert self.protocol.check_server_handshake(headers) +        headers["Upgrade"] = "not_websocket" +        assert not self.protocol.check_server_handshake(headers) + +    def test_check_client_handshake(self): +        headers = self.protocol.client_handshake_headers("key") +        assert self.protocol.check_client_handshake(headers) == "key" +        headers["Upgrade"] = "not_websocket" +        assert not self.protocol.check_client_handshake(headers) + + +class BadHandshakeHandler(WebSocketsEchoHandler): + +    def handshake(self): + +        client_hs = read_request(self.rfile) +        self.protocol.check_client_handshake(client_hs.headers) + +        preamble = 'HTTP/1.1 101 %s\r\n' % status_codes.RESPONSES.get(101) +        self.wfile.write(preamble.encode()) +        headers = self.protocol.server_handshake_headers(b"malformed key") +        self.wfile.write(bytes(headers) + b"\r\n") +        self.wfile.flush() +        self.handshake_done = True + + +class TestBadHandshake(tservers.ServerTestBase): + +    """ +      Ensure that the client disconnects if the server handshake is malformed +    """ +    handler = BadHandshakeHandler + +    def test(self): +        with tutils.raises(TcpDisconnect): +            client = WebSocketsClient(("127.0.0.1", self.port)) +            client.connect() +            client.send_message(b"hello") + + +class TestFrameHeader: + +    def test_roundtrip(self): +        def round(*args, **kwargs): +            f = websockets.FrameHeader(*args, **kwargs) +            f2 = websockets.FrameHeader.from_file(tutils.treader(bytes(f))) +            assert f == f2 +        round() +        round(fin=1) +        round(rsv1=1) +        round(rsv2=1) +        round(rsv3=1) +        round(payload_length=1) +        round(payload_length=100) +        round(payload_length=1000) +        round(payload_length=10000) +        round(opcode=websockets.OPCODE.PING) +        round(masking_key=b"test") + +    def test_human_readable(self): +        f = websockets.FrameHeader( +            masking_key=b"test", +            fin=True, +            payload_length=10 +        ) +        assert repr(f) +        f = websockets.FrameHeader() +        assert repr(f) + +    def test_funky(self): +        f = websockets.FrameHeader(masking_key=b"test", mask=False) +        raw = bytes(f) +        f2 = websockets.FrameHeader.from_file(tutils.treader(raw)) +        assert not f2.mask + +    def test_violations(self): +        tutils.raises("opcode", websockets.FrameHeader, opcode=17) +        tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x") + +    def test_automask(self): +        f = websockets.FrameHeader(mask=True) +        assert f.masking_key + +        f = websockets.FrameHeader(masking_key=b"foob") +        assert f.mask + +        f = websockets.FrameHeader(masking_key=b"foob", mask=0) +        assert not f.mask +        assert f.masking_key + + +class TestFrame: + +    def test_roundtrip(self): +        def round(*args, **kwargs): +            f = websockets.Frame(*args, **kwargs) +            raw = bytes(f) +            f2 = websockets.Frame.from_file(tutils.treader(raw)) +            assert f == f2 +        round(b"test") +        round(b"test", fin=1) +        round(b"test", rsv1=1) +        round(b"test", opcode=websockets.OPCODE.PING) +    
    round(b"test", masking_key=b"test") + +    def test_human_readable(self): +        f = websockets.Frame() +        assert repr(f) + + +def test_masker(): +    tests = [ +        [b"a"], +        [b"four"], +        [b"fourf"], +        [b"fourfive"], +        [b"a", b"aasdfasdfa", b"asdf"], +        [b"a" * 50, b"aasdfasdfa", b"asdf"], +    ] +    for i in tests: +        m = websockets.Masker(b"abcd") +        data = b"".join([m(t) for t in i]) +        data2 = websockets.Masker(b"abcd")(data) +        assert data2 == b"".join(i)  | 
