From f964d49853a3f0d22e0f6d4cff7cfbc49008e40e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 20 Oct 2016 11:02:52 +1300 Subject: netlib.certutils -> mitmproxy.certs --- mitmproxy/certs.py | 481 ++++++++++++++++++++++++++++++++++++++++++ mitmproxy/connections.py | 6 +- mitmproxy/proxy/config.py | 4 +- netlib/certutils.py | 481 ------------------------------------------ netlib/tcp.py | 12 +- pathod/pathoc.py | 12 +- pathod/pathod.py | 22 +- test/mitmproxy/test_certs.py | 181 ++++++++++++++++ test/mitmproxy/test_server.py | 33 +-- test/netlib/test_certutils.py | 180 ---------------- test/netlib/test_tcp.py | 53 ++--- 11 files changed, 738 insertions(+), 727 deletions(-) create mode 100644 mitmproxy/certs.py delete mode 100644 netlib/certutils.py create mode 100644 test/mitmproxy/test_certs.py delete mode 100644 test/netlib/test_certutils.py diff --git a/mitmproxy/certs.py b/mitmproxy/certs.py new file mode 100644 index 00000000..9cb8a40e --- /dev/null +++ b/mitmproxy/certs.py @@ -0,0 +1,481 @@ +import os +import ssl +import time +import datetime +import ipaddress + +import sys +from pyasn1.type import univ, constraint, char, namedtype, tag +from pyasn1.codec.der.decoder import decode +from pyasn1.error import PyAsn1Error +import OpenSSL + +from mitmproxy.types import serializable + +# Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815 + +DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3 +# Generated with "openssl dhparam". It's too slow to generate this on startup. +DEFAULT_DHPARAM = b""" +-----BEGIN DH PARAMETERS----- +MIICCAKCAgEAyT6LzpwVFS3gryIo29J5icvgxCnCebcdSe/NHMkD8dKJf8suFCg3 +O2+dguLakSVif/t6dhImxInJk230HmfC8q93hdcg/j8rLGJYDKu3ik6H//BAHKIv +j5O9yjU3rXCfmVJQic2Nne39sg3CreAepEts2TvYHhVv3TEAzEqCtOuTjgDv0ntJ +Gwpj+BJBRQGG9NvprX1YGJ7WOFBP/hWU7d6tgvE6Xa7T/u9QIKpYHMIkcN/l3ZFB +chZEqVlyrcngtSXCROTPcDOQ6Q8QzhaBJS+Z6rcsd7X+haiQqvoFcmaJ08Ks6LQC +ZIL2EtYJw8V8z7C0igVEBIADZBI6OTbuuhDwRw//zU1uq52Oc48CIZlGxTYG/Evq +o9EWAXUYVzWkDSTeBH1r4z/qLPE2cnhtMxbFxuvK53jGB0emy2y1Ei6IhKshJ5qX +IB/aE7SSHyQ3MDHHkCmQJCsOd4Mo26YX61NZ+n501XjqpCBQ2+DfZCBh8Va2wDyv +A2Ryg9SUz8j0AXViRNMJgJrr446yro/FuJZwnQcO3WQnXeqSBnURqKjmqkeFP+d8 +6mk2tqJaY507lRNqtGlLnj7f5RNoBFJDCLBNurVgfvq9TCVWKDIFD4vZRjCrnl6I +rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI= +-----END DH PARAMETERS----- +""" + + +def create_ca(o, cn, exp): + key = OpenSSL.crypto.PKey() + key.generate_key(OpenSSL.crypto.TYPE_RSA, 2048) + cert = OpenSSL.crypto.X509() + cert.set_serial_number(int(time.time() * 10000)) + cert.set_version(2) + cert.get_subject().CN = cn + cert.get_subject().O = o + cert.gmtime_adj_notBefore(-3600 * 48) + cert.gmtime_adj_notAfter(exp) + cert.set_issuer(cert.get_subject()) + cert.set_pubkey(key) + cert.add_extensions([ + OpenSSL.crypto.X509Extension( + b"basicConstraints", + True, + b"CA:TRUE" + ), + OpenSSL.crypto.X509Extension( + b"nsCertType", + False, + b"sslCA" + ), + OpenSSL.crypto.X509Extension( + b"extendedKeyUsage", + False, + b"serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" + ), + OpenSSL.crypto.X509Extension( + b"keyUsage", + True, + b"keyCertSign, cRLSign" + ), + OpenSSL.crypto.X509Extension( + b"subjectKeyIdentifier", + False, + b"hash", + subject=cert + ), + ]) + cert.sign(key, "sha256") + return key, cert + + +def dummy_cert(privkey, cacert, commonname, sans): + """ + Generates a dummy certificate. + + privkey: CA private key + cacert: CA certificate + commonname: Common name for the generated certificate. + sans: A list of Subject Alternate Names. + + Returns cert if operation succeeded, None if not. + """ + ss = [] + for i in sans: + try: + ipaddress.ip_address(i.decode("ascii")) + except ValueError: + ss.append(b"DNS: %s" % i) + else: + ss.append(b"IP: %s" % i) + ss = b", ".join(ss) + + cert = OpenSSL.crypto.X509() + cert.gmtime_adj_notBefore(-3600 * 48) + cert.gmtime_adj_notAfter(DEFAULT_EXP) + cert.set_issuer(cacert.get_subject()) + if commonname is not None: + cert.get_subject().CN = commonname + cert.set_serial_number(int(time.time() * 10000)) + if ss: + cert.set_version(2) + cert.add_extensions( + [OpenSSL.crypto.X509Extension(b"subjectAltName", False, ss)]) + cert.set_pubkey(cacert.get_pubkey()) + cert.sign(privkey, "sha256") + return SSLCert(cert) + + +# DNTree did not pass TestCertStore.test_sans_change and is temporarily replaced by a simple dict. +# +# class _Node(UserDict.UserDict): +# def __init__(self): +# UserDict.UserDict.__init__(self) +# self.value = None +# +# +# class DNTree: +# """ +# Domain store that knows about wildcards. DNS wildcards are very +# restricted - the only valid variety is an asterisk on the left-most +# domain component, i.e.: +# +# *.foo.com +# """ +# def __init__(self): +# self.d = _Node() +# +# def add(self, dn, cert): +# parts = dn.split(".") +# parts.reverse() +# current = self.d +# for i in parts: +# current = current.setdefault(i, _Node()) +# current.value = cert +# +# def get(self, dn): +# parts = dn.split(".") +# current = self.d +# for i in reversed(parts): +# if i in current: +# current = current[i] +# elif "*" in current: +# return current["*"].value +# else: +# return None +# return current.value + + +class CertStoreEntry: + + def __init__(self, cert, privatekey, chain_file): + self.cert = cert + self.privatekey = privatekey + self.chain_file = chain_file + + +class CertStore: + + """ + Implements an in-memory certificate store. + """ + STORE_CAP = 100 + + def __init__( + self, + default_privatekey, + default_ca, + default_chain_file, + dhparams): + self.default_privatekey = default_privatekey + self.default_ca = default_ca + self.default_chain_file = default_chain_file + self.dhparams = dhparams + self.certs = dict() + self.expire_queue = [] + + def expire(self, entry): + self.expire_queue.append(entry) + if len(self.expire_queue) > self.STORE_CAP: + d = self.expire_queue.pop(0) + for k, v in list(self.certs.items()): + if v == d: + del self.certs[k] + + @staticmethod + def load_dhparam(path): + + # netlib<=0.10 doesn't generate a dhparam file. + # Create it now if neccessary. + if not os.path.exists(path): + with open(path, "wb") as f: + f.write(DEFAULT_DHPARAM) + + bio = OpenSSL.SSL._lib.BIO_new_file(path.encode(sys.getfilesystemencoding()), b"r") + if bio != OpenSSL.SSL._ffi.NULL: + bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free) + dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams( + bio, + OpenSSL.SSL._ffi.NULL, + OpenSSL.SSL._ffi.NULL, + OpenSSL.SSL._ffi.NULL) + dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free) + return dh + + @classmethod + def from_store(cls, path, basename): + ca_path = os.path.join(path, basename + "-ca.pem") + if not os.path.exists(ca_path): + key, ca = cls.create_store(path, basename) + else: + with open(ca_path, "rb") as f: + raw = f.read() + ca = OpenSSL.crypto.load_certificate( + OpenSSL.crypto.FILETYPE_PEM, + raw) + key = OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, + raw) + dh_path = os.path.join(path, basename + "-dhparam.pem") + dh = cls.load_dhparam(dh_path) + return cls(key, ca, ca_path, dh) + + @staticmethod + def create_store(path, basename, o=None, cn=None, expiry=DEFAULT_EXP): + if not os.path.exists(path): + os.makedirs(path) + + o = o or basename + cn = cn or basename + + key, ca = create_ca(o=o, cn=cn, exp=expiry) + # Dump the CA plus private key + with open(os.path.join(path, basename + "-ca.pem"), "wb") as f: + f.write( + OpenSSL.crypto.dump_privatekey( + OpenSSL.crypto.FILETYPE_PEM, + key)) + f.write( + OpenSSL.crypto.dump_certificate( + OpenSSL.crypto.FILETYPE_PEM, + ca)) + + # Dump the certificate in PEM format + with open(os.path.join(path, basename + "-ca-cert.pem"), "wb") as f: + f.write( + OpenSSL.crypto.dump_certificate( + OpenSSL.crypto.FILETYPE_PEM, + ca)) + + # Create a .cer file with the same contents for Android + with open(os.path.join(path, basename + "-ca-cert.cer"), "wb") as f: + f.write( + OpenSSL.crypto.dump_certificate( + OpenSSL.crypto.FILETYPE_PEM, + ca)) + + # Dump the certificate in PKCS12 format for Windows devices + with open(os.path.join(path, basename + "-ca-cert.p12"), "wb") as f: + p12 = OpenSSL.crypto.PKCS12() + p12.set_certificate(ca) + p12.set_privatekey(key) + f.write(p12.export()) + + with open(os.path.join(path, basename + "-dhparam.pem"), "wb") as f: + f.write(DEFAULT_DHPARAM) + + return key, ca + + def add_cert_file(self, spec, path): + with open(path, "rb") as f: + raw = f.read() + cert = SSLCert( + OpenSSL.crypto.load_certificate( + OpenSSL.crypto.FILETYPE_PEM, + raw)) + try: + privatekey = OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, + raw) + except Exception: + privatekey = self.default_privatekey + self.add_cert( + CertStoreEntry(cert, privatekey, path), + spec + ) + + def add_cert(self, entry, *names): + """ + Adds a cert to the certstore. We register the CN in the cert plus + any SANs, and also the list of names provided as an argument. + """ + if entry.cert.cn: + self.certs[entry.cert.cn] = entry + for i in entry.cert.altnames: + self.certs[i] = entry + for i in names: + self.certs[i] = entry + + @staticmethod + def asterisk_forms(dn): + if dn is None: + return [] + parts = dn.split(b".") + parts.reverse() + curr_dn = b"" + dn_forms = [b"*"] + for part in parts[:-1]: + curr_dn = b"." + part + curr_dn # .example.com + dn_forms.append(b"*" + curr_dn) # *.example.com + if parts[-1] != b"*": + dn_forms.append(parts[-1] + curr_dn) + return dn_forms + + def get_cert(self, commonname, sans): + """ + Returns an (cert, privkey, cert_chain) tuple. + + commonname: Common name for the generated certificate. Must be a + valid, plain-ASCII, IDNA-encoded domain name. + + sans: A list of Subject Alternate Names. + """ + + potential_keys = self.asterisk_forms(commonname) + for s in sans: + potential_keys.extend(self.asterisk_forms(s)) + potential_keys.append((commonname, tuple(sans))) + + name = next( + filter(lambda key: key in self.certs, potential_keys), + None + ) + if name: + entry = self.certs[name] + else: + entry = CertStoreEntry( + cert=dummy_cert( + self.default_privatekey, + self.default_ca, + commonname, + sans), + privatekey=self.default_privatekey, + chain_file=self.default_chain_file) + self.certs[(commonname, tuple(sans))] = entry + self.expire(entry) + + return entry.cert, entry.privatekey, entry.chain_file + + +class _GeneralName(univ.Choice): + # We are only interested in dNSNames. We use a default handler to ignore + # other types. + # TODO: We should also handle iPAddresses. + componentType = namedtype.NamedTypes( + namedtype.NamedType('dNSName', char.IA5String().subtype( + implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) + ) + ), + ) + + +class _GeneralNames(univ.SequenceOf): + componentType = _GeneralName() + sizeSpec = univ.SequenceOf.sizeSpec + \ + constraint.ValueSizeConstraint(1, 1024) + + +class SSLCert(serializable.Serializable): + + def __init__(self, cert): + """ + Returns a (common name, [subject alternative names]) tuple. + """ + self.x509 = cert + + def __eq__(self, other): + return self.digest("sha256") == other.digest("sha256") + + def __ne__(self, other): + return not self.__eq__(other) + + def get_state(self): + return self.to_pem() + + def set_state(self, state): + self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state) + + @classmethod + def from_state(cls, state): + return cls.from_pem(state) + + @classmethod + def from_pem(cls, txt): + x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) + return cls(x509) + + @classmethod + def from_der(cls, der): + pem = ssl.DER_cert_to_PEM_cert(der) + return cls.from_pem(pem) + + def to_pem(self): + return OpenSSL.crypto.dump_certificate( + OpenSSL.crypto.FILETYPE_PEM, + self.x509) + + def digest(self, name): + return self.x509.digest(name) + + @property + def issuer(self): + return self.x509.get_issuer().get_components() + + @property + def notbefore(self): + t = self.x509.get_notBefore() + return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ") + + @property + def notafter(self): + t = self.x509.get_notAfter() + return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ") + + @property + def has_expired(self): + return self.x509.has_expired() + + @property + def subject(self): + return self.x509.get_subject().get_components() + + @property + def serial(self): + return self.x509.get_serial_number() + + @property + def keyinfo(self): + pk = self.x509.get_pubkey() + types = { + OpenSSL.crypto.TYPE_RSA: "RSA", + OpenSSL.crypto.TYPE_DSA: "DSA", + } + return ( + types.get(pk.type(), "UNKNOWN"), + pk.bits() + ) + + @property + def cn(self): + c = None + for i in self.subject: + if i[0] == b"CN": + c = i[1] + return c + + @property + def altnames(self): + """ + Returns: + All DNS altnames. + """ + # tcp.TCPClient.convert_to_ssl assumes that this property only contains DNS altnames for hostname verification. + altnames = [] + for i in range(self.x509.get_extension_count()): + ext = self.x509.get_extension(i) + if ext.get_short_name() == b"subjectAltName": + try: + dec = decode(ext.get_data(), asn1Spec=_GeneralNames()) + except PyAsn1Error: + continue + for i in dec[0]: + altnames.append(i[0].asOctets()) + return altnames diff --git a/mitmproxy/connections.py b/mitmproxy/connections.py index bf7a12aa..6b39ac20 100644 --- a/mitmproxy/connections.py +++ b/mitmproxy/connections.py @@ -4,7 +4,7 @@ import copy import os from mitmproxy import stateobject -from netlib import certutils +from mitmproxy import certs from netlib import tcp @@ -57,7 +57,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject): _stateobject_attributes = dict( address=tcp.Address, ssl_established=bool, - clientcert=certutils.SSLCert, + clientcert=certs.SSLCert, timestamp_start=float, timestamp_ssl_setup=float, timestamp_end=float, @@ -151,7 +151,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): ip_address=tcp.Address, source_address=tcp.Address, ssl_established=bool, - cert=certutils.SSLCert, + cert=certs.SSLCert, sni=str, timestamp_start=float, timestamp_tcp_setup=float, diff --git a/mitmproxy/proxy/config.py b/mitmproxy/proxy/config.py index a6fc739b..86b68ee5 100644 --- a/mitmproxy/proxy/config.py +++ b/mitmproxy/proxy/config.py @@ -10,7 +10,7 @@ from OpenSSL import SSL, crypto from mitmproxy import exceptions from mitmproxy import options as moptions -from netlib import certutils +from mitmproxy import certs from netlib import tcp from netlib.http import authentication from netlib.http import url @@ -106,7 +106,7 @@ class ProxyConfig: "Certificate Authority parent directory does not exist: %s" % os.path.dirname(options.cadir) ) - self.certstore = certutils.CertStore.from_store( + self.certstore = certs.CertStore.from_store( certstore_path, CONF_BASENAME ) diff --git a/netlib/certutils.py b/netlib/certutils.py deleted file mode 100644 index 9cb8a40e..00000000 --- a/netlib/certutils.py +++ /dev/null @@ -1,481 +0,0 @@ -import os -import ssl -import time -import datetime -import ipaddress - -import sys -from pyasn1.type import univ, constraint, char, namedtype, tag -from pyasn1.codec.der.decoder import decode -from pyasn1.error import PyAsn1Error -import OpenSSL - -from mitmproxy.types import serializable - -# Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815 - -DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3 -# Generated with "openssl dhparam". It's too slow to generate this on startup. -DEFAULT_DHPARAM = b""" ------BEGIN DH PARAMETERS----- -MIICCAKCAgEAyT6LzpwVFS3gryIo29J5icvgxCnCebcdSe/NHMkD8dKJf8suFCg3 -O2+dguLakSVif/t6dhImxInJk230HmfC8q93hdcg/j8rLGJYDKu3ik6H//BAHKIv -j5O9yjU3rXCfmVJQic2Nne39sg3CreAepEts2TvYHhVv3TEAzEqCtOuTjgDv0ntJ -Gwpj+BJBRQGG9NvprX1YGJ7WOFBP/hWU7d6tgvE6Xa7T/u9QIKpYHMIkcN/l3ZFB -chZEqVlyrcngtSXCROTPcDOQ6Q8QzhaBJS+Z6rcsd7X+haiQqvoFcmaJ08Ks6LQC -ZIL2EtYJw8V8z7C0igVEBIADZBI6OTbuuhDwRw//zU1uq52Oc48CIZlGxTYG/Evq -o9EWAXUYVzWkDSTeBH1r4z/qLPE2cnhtMxbFxuvK53jGB0emy2y1Ei6IhKshJ5qX -IB/aE7SSHyQ3MDHHkCmQJCsOd4Mo26YX61NZ+n501XjqpCBQ2+DfZCBh8Va2wDyv -A2Ryg9SUz8j0AXViRNMJgJrr446yro/FuJZwnQcO3WQnXeqSBnURqKjmqkeFP+d8 -6mk2tqJaY507lRNqtGlLnj7f5RNoBFJDCLBNurVgfvq9TCVWKDIFD4vZRjCrnl6I -rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI= ------END DH PARAMETERS----- -""" - - -def create_ca(o, cn, exp): - key = OpenSSL.crypto.PKey() - key.generate_key(OpenSSL.crypto.TYPE_RSA, 2048) - cert = OpenSSL.crypto.X509() - cert.set_serial_number(int(time.time() * 10000)) - cert.set_version(2) - cert.get_subject().CN = cn - cert.get_subject().O = o - cert.gmtime_adj_notBefore(-3600 * 48) - cert.gmtime_adj_notAfter(exp) - cert.set_issuer(cert.get_subject()) - cert.set_pubkey(key) - cert.add_extensions([ - OpenSSL.crypto.X509Extension( - b"basicConstraints", - True, - b"CA:TRUE" - ), - OpenSSL.crypto.X509Extension( - b"nsCertType", - False, - b"sslCA" - ), - OpenSSL.crypto.X509Extension( - b"extendedKeyUsage", - False, - b"serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" - ), - OpenSSL.crypto.X509Extension( - b"keyUsage", - True, - b"keyCertSign, cRLSign" - ), - OpenSSL.crypto.X509Extension( - b"subjectKeyIdentifier", - False, - b"hash", - subject=cert - ), - ]) - cert.sign(key, "sha256") - return key, cert - - -def dummy_cert(privkey, cacert, commonname, sans): - """ - Generates a dummy certificate. - - privkey: CA private key - cacert: CA certificate - commonname: Common name for the generated certificate. - sans: A list of Subject Alternate Names. - - Returns cert if operation succeeded, None if not. - """ - ss = [] - for i in sans: - try: - ipaddress.ip_address(i.decode("ascii")) - except ValueError: - ss.append(b"DNS: %s" % i) - else: - ss.append(b"IP: %s" % i) - ss = b", ".join(ss) - - cert = OpenSSL.crypto.X509() - cert.gmtime_adj_notBefore(-3600 * 48) - cert.gmtime_adj_notAfter(DEFAULT_EXP) - cert.set_issuer(cacert.get_subject()) - if commonname is not None: - cert.get_subject().CN = commonname - cert.set_serial_number(int(time.time() * 10000)) - if ss: - cert.set_version(2) - cert.add_extensions( - [OpenSSL.crypto.X509Extension(b"subjectAltName", False, ss)]) - cert.set_pubkey(cacert.get_pubkey()) - cert.sign(privkey, "sha256") - return SSLCert(cert) - - -# DNTree did not pass TestCertStore.test_sans_change and is temporarily replaced by a simple dict. -# -# class _Node(UserDict.UserDict): -# def __init__(self): -# UserDict.UserDict.__init__(self) -# self.value = None -# -# -# class DNTree: -# """ -# Domain store that knows about wildcards. DNS wildcards are very -# restricted - the only valid variety is an asterisk on the left-most -# domain component, i.e.: -# -# *.foo.com -# """ -# def __init__(self): -# self.d = _Node() -# -# def add(self, dn, cert): -# parts = dn.split(".") -# parts.reverse() -# current = self.d -# for i in parts: -# current = current.setdefault(i, _Node()) -# current.value = cert -# -# def get(self, dn): -# parts = dn.split(".") -# current = self.d -# for i in reversed(parts): -# if i in current: -# current = current[i] -# elif "*" in current: -# return current["*"].value -# else: -# return None -# return current.value - - -class CertStoreEntry: - - def __init__(self, cert, privatekey, chain_file): - self.cert = cert - self.privatekey = privatekey - self.chain_file = chain_file - - -class CertStore: - - """ - Implements an in-memory certificate store. - """ - STORE_CAP = 100 - - def __init__( - self, - default_privatekey, - default_ca, - default_chain_file, - dhparams): - self.default_privatekey = default_privatekey - self.default_ca = default_ca - self.default_chain_file = default_chain_file - self.dhparams = dhparams - self.certs = dict() - self.expire_queue = [] - - def expire(self, entry): - self.expire_queue.append(entry) - if len(self.expire_queue) > self.STORE_CAP: - d = self.expire_queue.pop(0) - for k, v in list(self.certs.items()): - if v == d: - del self.certs[k] - - @staticmethod - def load_dhparam(path): - - # netlib<=0.10 doesn't generate a dhparam file. - # Create it now if neccessary. - if not os.path.exists(path): - with open(path, "wb") as f: - f.write(DEFAULT_DHPARAM) - - bio = OpenSSL.SSL._lib.BIO_new_file(path.encode(sys.getfilesystemencoding()), b"r") - if bio != OpenSSL.SSL._ffi.NULL: - bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free) - dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams( - bio, - OpenSSL.SSL._ffi.NULL, - OpenSSL.SSL._ffi.NULL, - OpenSSL.SSL._ffi.NULL) - dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free) - return dh - - @classmethod - def from_store(cls, path, basename): - ca_path = os.path.join(path, basename + "-ca.pem") - if not os.path.exists(ca_path): - key, ca = cls.create_store(path, basename) - else: - with open(ca_path, "rb") as f: - raw = f.read() - ca = OpenSSL.crypto.load_certificate( - OpenSSL.crypto.FILETYPE_PEM, - raw) - key = OpenSSL.crypto.load_privatekey( - OpenSSL.crypto.FILETYPE_PEM, - raw) - dh_path = os.path.join(path, basename + "-dhparam.pem") - dh = cls.load_dhparam(dh_path) - return cls(key, ca, ca_path, dh) - - @staticmethod - def create_store(path, basename, o=None, cn=None, expiry=DEFAULT_EXP): - if not os.path.exists(path): - os.makedirs(path) - - o = o or basename - cn = cn or basename - - key, ca = create_ca(o=o, cn=cn, exp=expiry) - # Dump the CA plus private key - with open(os.path.join(path, basename + "-ca.pem"), "wb") as f: - f.write( - OpenSSL.crypto.dump_privatekey( - OpenSSL.crypto.FILETYPE_PEM, - key)) - f.write( - OpenSSL.crypto.dump_certificate( - OpenSSL.crypto.FILETYPE_PEM, - ca)) - - # Dump the certificate in PEM format - with open(os.path.join(path, basename + "-ca-cert.pem"), "wb") as f: - f.write( - OpenSSL.crypto.dump_certificate( - OpenSSL.crypto.FILETYPE_PEM, - ca)) - - # Create a .cer file with the same contents for Android - with open(os.path.join(path, basename + "-ca-cert.cer"), "wb") as f: - f.write( - OpenSSL.crypto.dump_certificate( - OpenSSL.crypto.FILETYPE_PEM, - ca)) - - # Dump the certificate in PKCS12 format for Windows devices - with open(os.path.join(path, basename + "-ca-cert.p12"), "wb") as f: - p12 = OpenSSL.crypto.PKCS12() - p12.set_certificate(ca) - p12.set_privatekey(key) - f.write(p12.export()) - - with open(os.path.join(path, basename + "-dhparam.pem"), "wb") as f: - f.write(DEFAULT_DHPARAM) - - return key, ca - - def add_cert_file(self, spec, path): - with open(path, "rb") as f: - raw = f.read() - cert = SSLCert( - OpenSSL.crypto.load_certificate( - OpenSSL.crypto.FILETYPE_PEM, - raw)) - try: - privatekey = OpenSSL.crypto.load_privatekey( - OpenSSL.crypto.FILETYPE_PEM, - raw) - except Exception: - privatekey = self.default_privatekey - self.add_cert( - CertStoreEntry(cert, privatekey, path), - spec - ) - - def add_cert(self, entry, *names): - """ - Adds a cert to the certstore. We register the CN in the cert plus - any SANs, and also the list of names provided as an argument. - """ - if entry.cert.cn: - self.certs[entry.cert.cn] = entry - for i in entry.cert.altnames: - self.certs[i] = entry - for i in names: - self.certs[i] = entry - - @staticmethod - def asterisk_forms(dn): - if dn is None: - return [] - parts = dn.split(b".") - parts.reverse() - curr_dn = b"" - dn_forms = [b"*"] - for part in parts[:-1]: - curr_dn = b"." + part + curr_dn # .example.com - dn_forms.append(b"*" + curr_dn) # *.example.com - if parts[-1] != b"*": - dn_forms.append(parts[-1] + curr_dn) - return dn_forms - - def get_cert(self, commonname, sans): - """ - Returns an (cert, privkey, cert_chain) tuple. - - commonname: Common name for the generated certificate. Must be a - valid, plain-ASCII, IDNA-encoded domain name. - - sans: A list of Subject Alternate Names. - """ - - potential_keys = self.asterisk_forms(commonname) - for s in sans: - potential_keys.extend(self.asterisk_forms(s)) - potential_keys.append((commonname, tuple(sans))) - - name = next( - filter(lambda key: key in self.certs, potential_keys), - None - ) - if name: - entry = self.certs[name] - else: - entry = CertStoreEntry( - cert=dummy_cert( - self.default_privatekey, - self.default_ca, - commonname, - sans), - privatekey=self.default_privatekey, - chain_file=self.default_chain_file) - self.certs[(commonname, tuple(sans))] = entry - self.expire(entry) - - return entry.cert, entry.privatekey, entry.chain_file - - -class _GeneralName(univ.Choice): - # We are only interested in dNSNames. We use a default handler to ignore - # other types. - # TODO: We should also handle iPAddresses. - componentType = namedtype.NamedTypes( - namedtype.NamedType('dNSName', char.IA5String().subtype( - implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) - ) - ), - ) - - -class _GeneralNames(univ.SequenceOf): - componentType = _GeneralName() - sizeSpec = univ.SequenceOf.sizeSpec + \ - constraint.ValueSizeConstraint(1, 1024) - - -class SSLCert(serializable.Serializable): - - def __init__(self, cert): - """ - Returns a (common name, [subject alternative names]) tuple. - """ - self.x509 = cert - - def __eq__(self, other): - return self.digest("sha256") == other.digest("sha256") - - def __ne__(self, other): - return not self.__eq__(other) - - def get_state(self): - return self.to_pem() - - def set_state(self, state): - self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state) - - @classmethod - def from_state(cls, state): - return cls.from_pem(state) - - @classmethod - def from_pem(cls, txt): - x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) - return cls(x509) - - @classmethod - def from_der(cls, der): - pem = ssl.DER_cert_to_PEM_cert(der) - return cls.from_pem(pem) - - def to_pem(self): - return OpenSSL.crypto.dump_certificate( - OpenSSL.crypto.FILETYPE_PEM, - self.x509) - - def digest(self, name): - return self.x509.digest(name) - - @property - def issuer(self): - return self.x509.get_issuer().get_components() - - @property - def notbefore(self): - t = self.x509.get_notBefore() - return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ") - - @property - def notafter(self): - t = self.x509.get_notAfter() - return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ") - - @property - def has_expired(self): - return self.x509.has_expired() - - @property - def subject(self): - return self.x509.get_subject().get_components() - - @property - def serial(self): - return self.x509.get_serial_number() - - @property - def keyinfo(self): - pk = self.x509.get_pubkey() - types = { - OpenSSL.crypto.TYPE_RSA: "RSA", - OpenSSL.crypto.TYPE_DSA: "DSA", - } - return ( - types.get(pk.type(), "UNKNOWN"), - pk.bits() - ) - - @property - def cn(self): - c = None - for i in self.subject: - if i[0] == b"CN": - c = i[1] - return c - - @property - def altnames(self): - """ - Returns: - All DNS altnames. - """ - # tcp.TCPClient.convert_to_ssl assumes that this property only contains DNS altnames for hostname verification. - altnames = [] - for i in range(self.x509.get_extension_count()): - ext = self.x509.get_extension(i) - if ext.get_short_name() == b"subjectAltName": - try: - dec = decode(ext.get_data(), asn1Spec=_GeneralNames()) - except PyAsn1Error: - continue - for i in dec[0]: - altnames.append(i[0].asOctets()) - return altnames diff --git a/netlib/tcp.py b/netlib/tcp.py index 4379c9b5..6e323957 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -17,7 +17,7 @@ from backports import ssl_match_hostname import OpenSSL from OpenSSL import SSL -from netlib import certutils +from mitmproxy import certs from mitmproxy.utils import version_check from mitmproxy.types import serializable from netlib import exceptions @@ -685,11 +685,11 @@ class TCPClient(_Connection): if verification_mode == SSL.VERIFY_PEER and self.ssl_verification_error: raise self.ssl_verification_error - self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) + self.cert = certs.SSLCert(self.connection.get_peer_certificate()) # Keep all server certificates in a list for i in self.connection.get_peer_cert_chain(): - self.server_certs.append(certutils.SSLCert(i)) + self.server_certs.append(certs.SSLCert(i)) # Validate TLS Hostname try: @@ -782,7 +782,7 @@ class BaseHandler(_Connection): extra_chain_certs=None, **sslctx_kwargs): """ - cert: A certutils.SSLCert object or the path to a certificate + cert: A certs.SSLCert object or the path to a certificate chain file. handle_sni: SNI handler, should take a connection object. Server @@ -810,7 +810,7 @@ class BaseHandler(_Connection): context = self._create_ssl_context(ca_pemfile=chain_file, **sslctx_kwargs) context.use_privatekey(key) - if isinstance(cert, certutils.SSLCert): + if isinstance(cert, certs.SSLCert): context.use_certificate(cert.x509) else: context.use_certificate_chain_file(cert) @@ -825,7 +825,7 @@ class BaseHandler(_Connection): if request_client_cert: def save_cert(conn_, cert, errno_, depth_, preverify_ok_): - self.clientcert = certutils.SSLCert(cert) + self.clientcert = certs.SSLCert(cert) # Return true to prevent cert verification error return True context.set_verify(SSL.VERIFY_PEER, save_cert) diff --git a/pathod/pathoc.py b/pathod/pathoc.py index caa9accb..39dedf05 100644 --- a/pathod/pathoc.py +++ b/pathod/pathoc.py @@ -13,13 +13,17 @@ import logging from netlib.tutils import treq from mitmproxy.utils import strutils -from netlib import tcp, certutils, websockets, socks +from netlib import tcp +from mitmproxy import certs +from netlib import websockets +from netlib import socks from netlib import exceptions from netlib.http import http1 from mitmproxy.types import basethread -from . import log, language -from .protocols import http2 +from pathod import log +from pathod import language +from pathod.protocols import http2 logging.getLogger("hpack").setLevel(logging.WARNING) @@ -76,7 +80,7 @@ class SSLInfo: } t = types.get(pk.type(), "Uknown") parts.append("\tPubkey: %s bit %s" % (pk.bits(), t)) - s = certutils.SSLCert(i) + s = certs.SSLCert(i) if s.altnames: parts.append("\tSANs: %s" % " ".join(strutils.native(n, "utf8") for n in s.altnames)) return "\n".join(parts) diff --git a/pathod/pathod.py b/pathod/pathod.py index 3692ceff..5d951350 100644 --- a/pathod/pathod.py +++ b/pathod/pathod.py @@ -5,15 +5,17 @@ import sys import threading from netlib import tcp -from netlib import certutils +from mitmproxy import certs as mcerts from netlib import websockets from mitmproxy import version import urllib -from netlib.exceptions import HttpException, HttpReadDisconnect, TcpTimeout, TcpDisconnect, \ - TlsException +from netlib import exceptions -from . import language, utils, log, protocols +from pathod import language +from pathod import utils +from pathod import log +from pathod import protocols DEFAULT_CERT_DOMAIN = b"pathod.net" @@ -52,7 +54,7 @@ class SSLOptions: self.ssl_options = ssl_options self.ciphers = ciphers self.alpn_select = alpn_select - self.certstore = certutils.CertStore.from_store( + self.certstore = mcerts.CertStore.from_store( os.path.expanduser(confdir), CERTSTORE_BASENAME ) @@ -128,9 +130,9 @@ class PathodHandler(tcp.BaseHandler): with logger.ctx() as lg: try: req = self.protocol.read_request(self.rfile) - except HttpReadDisconnect: + except exceptions.HttpReadDisconnect: return None, None - except HttpException as s: + except exceptions.HttpException as s: s = str(s) lg(s) return None, dict(type="error", msg=s) @@ -252,7 +254,7 @@ class PathodHandler(tcp.BaseHandler): options=self.server.ssloptions.ssl_options, alpn_select=self.server.ssloptions.alpn_select, ) - except TlsException as v: + except exceptions.TlsException as v: s = str(v) self.server.add_log( dict( @@ -384,7 +386,7 @@ class Pathod(tcp.TCPServer): try: h.handle() h.finish() - except TcpDisconnect: # pragma: no cover + except exceptions.TcpDisconnect: # pragma: no cover log.write_raw(self.logfp, "Disconnect") self.add_log( dict( @@ -393,7 +395,7 @@ class Pathod(tcp.TCPServer): ) ) return - except TcpTimeout: + except exceptions.TcpTimeout: log.write_raw(self.logfp, "Timeout") self.add_log( dict( diff --git a/test/mitmproxy/test_certs.py b/test/mitmproxy/test_certs.py new file mode 100644 index 00000000..35407fd6 --- /dev/null +++ b/test/mitmproxy/test_certs.py @@ -0,0 +1,181 @@ +import os +from mitmproxy import certs +from netlib import tutils + +# class TestDNTree: +# def test_simple(self): +# d = certs.DNTree() +# d.add("foo.com", "foo") +# d.add("bar.com", "bar") +# assert d.get("foo.com") == "foo" +# assert d.get("bar.com") == "bar" +# assert not d.get("oink.com") +# assert not d.get("oink") +# assert not d.get("") +# assert not d.get("oink.oink") +# +# d.add("*.match.org", "match") +# assert not d.get("match.org") +# assert d.get("foo.match.org") == "match" +# assert d.get("foo.foo.match.org") == "match" +# +# def test_wildcard(self): +# d = certs.DNTree() +# d.add("foo.com", "foo") +# assert not d.get("*.foo.com") +# d.add("*.foo.com", "wild") +# +# d = certs.DNTree() +# d.add("*", "foo") +# assert d.get("foo.com") == "foo" +# assert d.get("*.foo.com") == "foo" +# assert d.get("com") == "foo" + + +class TestCertStore: + + def test_create_explicit(self): + with tutils.tmpdir() as d: + ca = certs.CertStore.from_store(d, "test") + assert ca.get_cert(b"foo", []) + + ca2 = certs.CertStore.from_store(d, "test") + assert ca2.get_cert(b"foo", []) + + assert ca.default_ca.get_serial_number() == ca2.default_ca.get_serial_number() + + def test_create_no_common_name(self): + with tutils.tmpdir() as d: + ca = certs.CertStore.from_store(d, "test") + assert ca.get_cert(None, [])[0].cn is None + + def test_create_tmp(self): + with tutils.tmpdir() as d: + ca = certs.CertStore.from_store(d, "test") + assert ca.get_cert(b"foo.com", []) + assert ca.get_cert(b"foo.com", []) + assert ca.get_cert(b"*.foo.com", []) + + r = ca.get_cert(b"*.foo.com", []) + assert r[1] == ca.default_privatekey + + def test_sans(self): + with tutils.tmpdir() as d: + ca = certs.CertStore.from_store(d, "test") + c1 = ca.get_cert(b"foo.com", [b"*.bar.com"]) + ca.get_cert(b"foo.bar.com", []) + # assert c1 == c2 + c3 = ca.get_cert(b"bar.com", []) + assert not c1 == c3 + + def test_sans_change(self): + with tutils.tmpdir() as d: + ca = certs.CertStore.from_store(d, "test") + ca.get_cert(b"foo.com", [b"*.bar.com"]) + cert, key, chain_file = ca.get_cert(b"foo.bar.com", [b"*.baz.com"]) + assert b"*.baz.com" in cert.altnames + + def test_expire(self): + with tutils.tmpdir() as d: + ca = certs.CertStore.from_store(d, "test") + ca.STORE_CAP = 3 + ca.get_cert(b"one.com", []) + ca.get_cert(b"two.com", []) + ca.get_cert(b"three.com", []) + + assert (b"one.com", ()) in ca.certs + assert (b"two.com", ()) in ca.certs + assert (b"three.com", ()) in ca.certs + + ca.get_cert(b"one.com", []) + + assert (b"one.com", ()) in ca.certs + assert (b"two.com", ()) in ca.certs + assert (b"three.com", ()) in ca.certs + + ca.get_cert(b"four.com", []) + + assert (b"one.com", ()) not in ca.certs + assert (b"two.com", ()) in ca.certs + assert (b"three.com", ()) in ca.certs + assert (b"four.com", ()) in ca.certs + + def test_overrides(self): + with tutils.tmpdir() as d: + ca1 = certs.CertStore.from_store(os.path.join(d, "ca1"), "test") + ca2 = certs.CertStore.from_store(os.path.join(d, "ca2"), "test") + assert not ca1.default_ca.get_serial_number( + ) == ca2.default_ca.get_serial_number() + + dc = ca2.get_cert(b"foo.com", [b"sans.example.com"]) + dcp = os.path.join(d, "dc") + f = open(dcp, "wb") + f.write(dc[0].to_pem()) + f.close() + ca1.add_cert_file(b"foo.com", dcp) + + ret = ca1.get_cert(b"foo.com", []) + assert ret[0].serial == dc[0].serial + + +class TestDummyCert: + + def test_with_ca(self): + with tutils.tmpdir() as d: + ca = certs.CertStore.from_store(d, "test") + r = certs.dummy_cert( + ca.default_privatekey, + ca.default_ca, + b"foo.com", + [b"one.com", b"two.com", b"*.three.com"] + ) + assert r.cn == b"foo.com" + + r = certs.dummy_cert( + ca.default_privatekey, + ca.default_ca, + None, + [] + ) + assert r.cn is None + + +class TestSSLCert: + + def test_simple(self): + with open(tutils.test_data.path("data/text_cert"), "rb") as f: + d = f.read() + c1 = certs.SSLCert.from_pem(d) + assert c1.cn == b"google.com" + assert len(c1.altnames) == 436 + + with open(tutils.test_data.path("data/text_cert_2"), "rb") as f: + d = f.read() + c2 = certs.SSLCert.from_pem(d) + assert c2.cn == b"www.inode.co.nz" + assert len(c2.altnames) == 2 + assert c2.digest("sha1") + assert c2.notbefore + assert c2.notafter + assert c2.subject + assert c2.keyinfo == ("RSA", 2048) + assert c2.serial + assert c2.issuer + assert c2.to_pem() + assert c2.has_expired is not None + + assert not c1 == c2 + assert c1 != c2 + + def test_err_broken_sans(self): + with open(tutils.test_data.path("data/text_cert_weird1"), "rb") as f: + d = f.read() + c = certs.SSLCert.from_pem(d) + # This breaks unless we ignore a decoding error. + assert c.altnames is not None + + def test_der(self): + with open(tutils.test_data.path("data/dercert"), "rb") as f: + d = f.read() + s = certs.SSLCert.from_der(d) + assert s.cn diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py index 93a82954..cadc67a8 100644 --- a/test/mitmproxy/test_server.py +++ b/test/mitmproxy/test_server.py @@ -9,13 +9,16 @@ from mitmproxy.addons import script from mitmproxy import http from mitmproxy.proxy.config import HostMatcher, parse_server_spec import netlib.http -from netlib import tcp, socks -from netlib.certutils import SSLCert -from netlib.exceptions import HttpReadDisconnect, HttpException -from netlib.http import authentication, http1 +from netlib import tcp +from netlib import socks +from mitmproxy import certs +from netlib import exceptions +from netlib.http import authentication +from netlib.http import http1 from netlib.tcp import Address from netlib.tutils import raises -from pathod import pathoc, pathod +from pathod import pathoc +from pathod import pathod from . import tutils, tservers @@ -144,9 +147,9 @@ class TcpMixin: # Test that we get the original SSL cert if self.ssl: - i_cert = SSLCert(i.sslinfo.certchain[0]) - i2_cert = SSLCert(i2.sslinfo.certchain[0]) - n_cert = SSLCert(n.sslinfo.certchain[0]) + i_cert = certs.SSLCert(i.sslinfo.certchain[0]) + i2_cert = certs.SSLCert(i2.sslinfo.certchain[0]) + n_cert = certs.SSLCert(n.sslinfo.certchain[0]) assert i_cert == i2_cert assert i_cert != n_cert @@ -156,7 +159,7 @@ class TcpMixin: # mitmproxy responds with bad gateway assert self.pathod(spec).status_code == 502 self._ignore_on() - with raises(HttpException): + with raises(exceptions.HttpException): self.pathod(spec) # pathoc tries to parse answer as HTTP self._ignore_off() @@ -190,9 +193,9 @@ class TcpMixin: # Test that we get the original SSL cert if self.ssl: - i_cert = SSLCert(i.sslinfo.certchain[0]) - i2_cert = SSLCert(i2.sslinfo.certchain[0]) - n_cert = SSLCert(n.sslinfo.certchain[0]) + i_cert = certs.SSLCert(i.sslinfo.certchain[0]) + i2_cert = certs.SSLCert(i2.sslinfo.certchain[0]) + n_cert = certs.SSLCert(n.sslinfo.certchain[0]) assert i_cert == i2_cert == n_cert @@ -830,7 +833,7 @@ class TestKillRequest(tservers.HTTPProxyTest): masterclass = MasterKillRequest def test_kill(self): - with raises(HttpReadDisconnect): + with raises(exceptions.HttpReadDisconnect): self.pathod("200") # Nothing should have hit the server assert not self.server.last_log() @@ -847,7 +850,7 @@ class TestKillResponse(tservers.HTTPProxyTest): masterclass = MasterKillResponse def test_kill(self): - with raises(HttpReadDisconnect): + with raises(exceptions.HttpReadDisconnect): self.pathod("200") # The server should have seen a request assert self.server.last_log() @@ -1050,7 +1053,7 @@ class AddUpstreamCertsToClientChainMixin: def test_add_upstream_certs_to_client_chain(self): with open(self.servercert, "rb") as f: d = f.read() - upstreamCert = SSLCert.from_pem(d) + upstreamCert = certs.SSLCert.from_pem(d) p = self.pathoc() with p.connect(): upstream_cert_found_in_client_chain = False diff --git a/test/netlib/test_certutils.py b/test/netlib/test_certutils.py deleted file mode 100644 index cf9a671b..00000000 --- a/test/netlib/test_certutils.py +++ /dev/null @@ -1,180 +0,0 @@ -import os -from netlib import certutils, tutils - -# class TestDNTree: -# def test_simple(self): -# d = certutils.DNTree() -# d.add("foo.com", "foo") -# d.add("bar.com", "bar") -# assert d.get("foo.com") == "foo" -# assert d.get("bar.com") == "bar" -# assert not d.get("oink.com") -# assert not d.get("oink") -# assert not d.get("") -# assert not d.get("oink.oink") -# -# d.add("*.match.org", "match") -# assert not d.get("match.org") -# assert d.get("foo.match.org") == "match" -# assert d.get("foo.foo.match.org") == "match" -# -# def test_wildcard(self): -# d = certutils.DNTree() -# d.add("foo.com", "foo") -# assert not d.get("*.foo.com") -# d.add("*.foo.com", "wild") -# -# d = certutils.DNTree() -# d.add("*", "foo") -# assert d.get("foo.com") == "foo" -# assert d.get("*.foo.com") == "foo" -# assert d.get("com") == "foo" - - -class TestCertStore: - - def test_create_explicit(self): - with tutils.tmpdir() as d: - ca = certutils.CertStore.from_store(d, "test") - assert ca.get_cert(b"foo", []) - - ca2 = certutils.CertStore.from_store(d, "test") - assert ca2.get_cert(b"foo", []) - - assert ca.default_ca.get_serial_number() == ca2.default_ca.get_serial_number() - - def test_create_no_common_name(self): - with tutils.tmpdir() as d: - ca = certutils.CertStore.from_store(d, "test") - assert ca.get_cert(None, [])[0].cn is None - - def test_create_tmp(self): - with tutils.tmpdir() as d: - ca = certutils.CertStore.from_store(d, "test") - assert ca.get_cert(b"foo.com", []) - assert ca.get_cert(b"foo.com", []) - assert ca.get_cert(b"*.foo.com", []) - - r = ca.get_cert(b"*.foo.com", []) - assert r[1] == ca.default_privatekey - - def test_sans(self): - with tutils.tmpdir() as d: - ca = certutils.CertStore.from_store(d, "test") - c1 = ca.get_cert(b"foo.com", [b"*.bar.com"]) - ca.get_cert(b"foo.bar.com", []) - # assert c1 == c2 - c3 = ca.get_cert(b"bar.com", []) - assert not c1 == c3 - - def test_sans_change(self): - with tutils.tmpdir() as d: - ca = certutils.CertStore.from_store(d, "test") - ca.get_cert(b"foo.com", [b"*.bar.com"]) - cert, key, chain_file = ca.get_cert(b"foo.bar.com", [b"*.baz.com"]) - assert b"*.baz.com" in cert.altnames - - def test_expire(self): - with tutils.tmpdir() as d: - ca = certutils.CertStore.from_store(d, "test") - ca.STORE_CAP = 3 - ca.get_cert(b"one.com", []) - ca.get_cert(b"two.com", []) - ca.get_cert(b"three.com", []) - - assert (b"one.com", ()) in ca.certs - assert (b"two.com", ()) in ca.certs - assert (b"three.com", ()) in ca.certs - - ca.get_cert(b"one.com", []) - - assert (b"one.com", ()) in ca.certs - assert (b"two.com", ()) in ca.certs - assert (b"three.com", ()) in ca.certs - - ca.get_cert(b"four.com", []) - - assert (b"one.com", ()) not in ca.certs - assert (b"two.com", ()) in ca.certs - assert (b"three.com", ()) in ca.certs - assert (b"four.com", ()) in ca.certs - - def test_overrides(self): - with tutils.tmpdir() as d: - ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test") - ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test") - assert not ca1.default_ca.get_serial_number( - ) == ca2.default_ca.get_serial_number() - - dc = ca2.get_cert(b"foo.com", [b"sans.example.com"]) - dcp = os.path.join(d, "dc") - f = open(dcp, "wb") - f.write(dc[0].to_pem()) - f.close() - ca1.add_cert_file(b"foo.com", dcp) - - ret = ca1.get_cert(b"foo.com", []) - assert ret[0].serial == dc[0].serial - - -class TestDummyCert: - - def test_with_ca(self): - with tutils.tmpdir() as d: - ca = certutils.CertStore.from_store(d, "test") - r = certutils.dummy_cert( - ca.default_privatekey, - ca.default_ca, - b"foo.com", - [b"one.com", b"two.com", b"*.three.com"] - ) - assert r.cn == b"foo.com" - - r = certutils.dummy_cert( - ca.default_privatekey, - ca.default_ca, - None, - [] - ) - assert r.cn is None - - -class TestSSLCert: - - def test_simple(self): - with open(tutils.test_data.path("data/text_cert"), "rb") as f: - d = f.read() - c1 = certutils.SSLCert.from_pem(d) - assert c1.cn == b"google.com" - assert len(c1.altnames) == 436 - - with open(tutils.test_data.path("data/text_cert_2"), "rb") as f: - d = f.read() - c2 = certutils.SSLCert.from_pem(d) - assert c2.cn == b"www.inode.co.nz" - assert len(c2.altnames) == 2 - assert c2.digest("sha1") - assert c2.notbefore - assert c2.notafter - assert c2.subject - assert c2.keyinfo == ("RSA", 2048) - assert c2.serial - assert c2.issuer - assert c2.to_pem() - assert c2.has_expired is not None - - assert not c1 == c2 - assert c1 != c2 - - def test_err_broken_sans(self): - with open(tutils.test_data.path("data/text_cert_weird1"), "rb") as f: - d = f.read() - c = certutils.SSLCert.from_pem(d) - # This breaks unless we ignore a decoding error. - assert c.altnames is not None - - def test_der(self): - with open(tutils.test_data.path("data/dercert"), "rb") as f: - d = f.read() - s = certutils.SSLCert.from_der(d) - assert s.cn diff --git a/test/netlib/test_tcp.py b/test/netlib/test_tcp.py index 797a5a04..2c1b92dc 100644 --- a/test/netlib/test_tcp.py +++ b/test/netlib/test_tcp.py @@ -9,9 +9,10 @@ import mock from OpenSSL import SSL -from netlib import tcp, certutils, tutils -from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \ - TcpTimeout, TcpDisconnect, TcpException, NetlibException +from mitmproxy import certs +from netlib import tcp +from netlib import tutils +from netlib import exceptions from . import tservers @@ -108,7 +109,7 @@ class TestServerBind(tservers.ServerTestBase): with c.connect(): assert c.rfile.readline() == str(("127.0.0.1", random_port)).encode() return - except TcpException: # port probably already in use + except exceptions.TcpException: # port probably already in use pass @@ -155,7 +156,7 @@ class TestFinishFail(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): c.wfile.write(b"foo\n") - c.wfile.flush = mock.Mock(side_effect=TcpDisconnect) + c.wfile.flush = mock.Mock(side_effect=exceptions.TcpDisconnect) c.finish() @@ -195,7 +196,7 @@ class TestSSLv3Only(tservers.ServerTestBase): def test_failure(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - tutils.raises(TlsException, c.convert_to_ssl, sni="foo.com") + tutils.raises(exceptions.TlsException, c.convert_to_ssl, sni="foo.com") class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): @@ -236,7 +237,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): def test_mode_strict_should_fail(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - with tutils.raises(InvalidCertificateException): + with tutils.raises(exceptions.InvalidCertificateException): c.convert_to_ssl( sni="example.mitmproxy.org", verify_options=SSL.VERIFY_PEER, @@ -261,7 +262,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase): def test_should_fail_without_sni(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - with tutils.raises(TlsException): + with tutils.raises(exceptions.TlsException): c.convert_to_ssl( verify_options=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") @@ -270,7 +271,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase): def test_should_fail(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - with tutils.raises(InvalidCertificateException): + with tutils.raises(exceptions.InvalidCertificateException): c.convert_to_ssl( sni="mitmproxy.org", verify_options=SSL.VERIFY_PEER, @@ -348,7 +349,7 @@ class TestSSLClientCert(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): tutils.raises( - TlsException, + exceptions.TlsException, c.convert_to_ssl, cert=tutils.test_data.path("data/clientcert/make") ) @@ -454,7 +455,7 @@ class TestSSLDisconnect(tservers.ServerTestBase): # Excercise SSL.ZeroReturnError c.rfile.read(10) c.close() - tutils.raises(TcpDisconnect, c.wfile.write, b"foo") + tutils.raises(exceptions.TcpDisconnect, c.wfile.write, b"foo") tutils.raises(queue.Empty, self.q.get_nowait) @@ -469,7 +470,7 @@ class TestSSLHardDisconnect(tservers.ServerTestBase): # Exercise SSL.SysCallError c.rfile.read(10) c.close() - tutils.raises(TcpDisconnect, c.wfile.write, b"foo") + tutils.raises(exceptions.TcpDisconnect, c.wfile.write, b"foo") class TestDisconnect(tservers.ServerTestBase): @@ -492,7 +493,7 @@ class TestServerTimeOut(tservers.ServerTestBase): self.settimeout(0.01) try: self.rfile.read(10) - except TcpTimeout: + except exceptions.TcpTimeout: self.timeout = True def test_timeout(self): @@ -510,7 +511,7 @@ class TestTimeOut(tservers.ServerTestBase): with c.connect(): c.settimeout(0.1) assert c.gettimeout() == 0.1 - tutils.raises(TcpTimeout, c.rfile.read, 10) + tutils.raises(exceptions.TcpTimeout, c.rfile.read, 10) class TestALPNClient(tservers.ServerTestBase): @@ -562,13 +563,13 @@ class TestSSLTimeOut(tservers.ServerTestBase): with c.connect(): c.convert_to_ssl() c.settimeout(0.1) - tutils.raises(TcpTimeout, c.rfile.read, 10) + tutils.raises(exceptions.TcpTimeout, c.rfile.read, 10) class TestDHParams(tservers.ServerTestBase): handler = HangHandler ssl = dict( - dhparams=certutils.CertStore.load_dhparam( + dhparams=certs.CertStore.load_dhparam( tutils.test_data.path("data/dhparam.pem"), ), cipher_list="DHE-RSA-AES256-SHA" @@ -584,7 +585,7 @@ class TestDHParams(tservers.ServerTestBase): def test_create_dhparams(self): with tutils.tmpdir() as d: filename = os.path.join(d, "dhparam.pem") - certutils.CertStore.load_dhparam(filename) + certs.CertStore.load_dhparam(filename) assert os.path.exists(filename) @@ -592,7 +593,7 @@ class TestTCPClient: def test_conerr(self): c = tcp.TCPClient(("127.0.0.1", 0)) - tutils.raises(TcpException, c.connect) + tutils.raises(exceptions.TcpException, c.connect) class TestFileLike: @@ -661,7 +662,7 @@ class TestFileLike: o = mock.MagicMock() o.flush = mock.MagicMock(side_effect=socket.error) s.o = o - tutils.raises(TcpDisconnect, s.flush) + tutils.raises(exceptions.TcpDisconnect, s.flush) def test_reader_read_error(self): s = BytesIO(b"foobar\nfoobar") @@ -669,7 +670,7 @@ class TestFileLike: o = mock.MagicMock() o.read = mock.MagicMock(side_effect=socket.error) s.o = o - tutils.raises(TcpDisconnect, s.read, 10) + tutils.raises(exceptions.TcpDisconnect, s.read, 10) def test_reset_timestamps(self): s = BytesIO(b"foobar\nfoobar") @@ -700,24 +701,24 @@ class TestFileLike: s = mock.MagicMock() s.read = mock.MagicMock(side_effect=SSL.Error()) s = tcp.Reader(s) - tutils.raises(TlsException, s.read, 1) + tutils.raises(exceptions.TlsException, s.read, 1) def test_read_syscall_ssl_error(self): s = mock.MagicMock() s.read = mock.MagicMock(side_effect=SSL.SysCallError()) s = tcp.Reader(s) - tutils.raises(TlsException, s.read, 1) + tutils.raises(exceptions.TlsException, s.read, 1) def test_reader_readline_disconnect(self): o = mock.MagicMock() o.read = mock.MagicMock(side_effect=socket.error) s = tcp.Reader(o) - tutils.raises(TcpDisconnect, s.readline, 10) + tutils.raises(exceptions.TcpDisconnect, s.readline, 10) def test_reader_incomplete_error(self): s = BytesIO(b"foobar") s = tcp.Reader(s) - tutils.raises(TcpReadIncomplete, s.safe_read, 10) + tutils.raises(exceptions.TcpReadIncomplete, s.safe_read, 10) class TestPeek(tservers.ServerTestBase): @@ -738,11 +739,11 @@ class TestPeek(tservers.ServerTestBase): assert c.rfile.readline() == testval c.close() - with tutils.raises(NetlibException): + with tutils.raises(exceptions.NetlibException): if c.rfile.peek(1) == b"": # Workaround for Python 2 on Unix: # Peeking a closed connection does not raise an exception here. - raise NetlibException() + raise exceptions.NetlibException() class TestPeekSSL(TestPeek): -- cgit v1.2.3