diff options
Diffstat (limited to 'netlib/certutils.py')
-rw-r--r-- | netlib/certutils.py | 247 |
1 files changed, 139 insertions, 108 deletions
diff --git a/netlib/certutils.py b/netlib/certutils.py index 187abfae..af6177d8 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -1,10 +1,10 @@ +from __future__ import (absolute_import, print_function, division) import os, ssl, time, datetime +import itertools from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL -import tcp -import UserDict DEFAULT_EXP = 62208000 # =24 * 60 * 60 * 720 # Generated with "openssl dhparam". It's too slow to generate this on startup. @@ -29,12 +29,12 @@ def create_ca(o, cn, exp): cert.add_extensions([ OpenSSL.crypto.X509Extension("basicConstraints", True, "CA:TRUE"), - OpenSSL.crypto.X509Extension("nsCertType", True, + OpenSSL.crypto.X509Extension("nsCertType", False, "sslCA"), - OpenSSL.crypto.X509Extension("extendedKeyUsage", True, + OpenSSL.crypto.X509Extension("extendedKeyUsage", False, "serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" ), - OpenSSL.crypto.X509Extension("keyUsage", False, + OpenSSL.crypto.X509Extension("keyUsage", True, "keyCertSign, cRLSign"), OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash", subject=cert), @@ -67,62 +67,72 @@ def dummy_cert(privkey, cacert, commonname, sans): cert.set_serial_number(int(time.time()*10000)) if ss: cert.set_version(2) - cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) + cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", False, ss)]) cert.set_pubkey(cacert.get_pubkey()) cert.sign(privkey, "sha1") return SSLCert(cert) -class _Node(UserDict.UserDict): - def __init__(self): - UserDict.UserDict.__init__(self) - self.value = None - - -class DNTree: - """ - Domain store that knows about wildcards. DNS wildcards are very - restricted - the only valid variety is an asterisk on the left-most - domain component, i.e.: - - *.foo.com - """ - def __init__(self): - self.d = _Node() - - def add(self, dn, cert): - parts = dn.split(".") - parts.reverse() - current = self.d - for i in parts: - current = current.setdefault(i, _Node()) - current.value = cert - - def get(self, dn): - parts = dn.split(".") - current = self.d - for i in reversed(parts): - if i in current: - current = current[i] - elif "*" in current: - return current["*"].value - else: - return None - return current.value - +# DNTree did not pass TestCertStore.test_sans_change and is temporarily replaced by a simple dict. +# +# class _Node(UserDict.UserDict): +# def __init__(self): +# UserDict.UserDict.__init__(self) +# self.value = None +# +# +# class DNTree: +# """ +# Domain store that knows about wildcards. DNS wildcards are very +# restricted - the only valid variety is an asterisk on the left-most +# domain component, i.e.: +# +# *.foo.com +# """ +# def __init__(self): +# self.d = _Node() +# +# def add(self, dn, cert): +# parts = dn.split(".") +# parts.reverse() +# current = self.d +# for i in parts: +# current = current.setdefault(i, _Node()) +# current.value = cert +# +# def get(self, dn): +# parts = dn.split(".") +# current = self.d +# for i in reversed(parts): +# if i in current: +# current = current[i] +# elif "*" in current: +# return current["*"].value +# else: +# return None +# return current.value + + +class CertStoreEntry(object): + def __init__(self, cert, privatekey, chain_file): + self.cert = cert + self.privatekey = privatekey + self.chain_file = chain_file class CertStore: """ Implements an in-memory certificate store. """ - def __init__(self, privkey, cacert, dhparams=None): - self.privkey, self.cacert = privkey, cacert + def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams=None): + self.default_privatekey = default_privatekey + self.default_ca = default_ca + self.default_chain_file = default_chain_file self.dhparams = dhparams - self.certs = DNTree() + self.certs = dict() - @classmethod - def load_dhparam(klass, path): + @staticmethod + def load_dhparam(path): # netlib<=0.10 doesn't generate a dhparam file. # Create it now if neccessary. @@ -140,21 +150,21 @@ class CertStore: return dh @classmethod - def from_store(klass, path, basename): - p = os.path.join(path, basename + "-ca.pem") - if not os.path.exists(p): - key, ca = klass.create_store(path, basename) + def from_store(cls, path, basename): + ca_path = os.path.join(path, basename + "-ca.pem") + if not os.path.exists(ca_path): + key, ca = cls.create_store(path, basename) else: - p = os.path.join(path, basename + "-ca.pem") - raw = file(p, "rb").read() + with open(ca_path, "rb") as f: + raw = f.read() ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) - dhp = os.path.join(path, basename + "-dhparam.pem") - dh = klass.load_dhparam(dhp) - return klass(key, ca, dh) + dh_path = os.path.join(path, basename + "-dhparam.pem") + dh = cls.load_dhparam(dh_path) + return cls(key, ca, ca_path, dh) - @classmethod - def create_store(klass, path, basename, o=None, cn=None, expiry=DEFAULT_EXP): + @staticmethod + def create_store(path, basename, o=None, cn=None, expiry=DEFAULT_EXP): if not os.path.exists(path): os.makedirs(path) @@ -163,58 +173,71 @@ class CertStore: key, ca = create_ca(o=o, cn=cn, exp=expiry) # Dump the CA plus private key - f = open(os.path.join(path, basename + "-ca.pem"), "wb") - f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() + with open(os.path.join(path, basename + "-ca.pem"), "wb") as f: + f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) # Dump the certificate in PEM format - f = open(os.path.join(path, basename + "-ca-cert.pem"), "wb") - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() + with open(os.path.join(path, basename + "-ca-cert.pem"), "wb") as f: + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) # Create a .cer file with the same contents for Android - f = open(os.path.join(path, basename + "-ca-cert.cer"), "wb") - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() + with open(os.path.join(path, basename + "-ca-cert.cer"), "wb") as f: + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) # Dump the certificate in PKCS12 format for Windows devices - f = open(os.path.join(path, basename + "-ca-cert.p12"), "wb") - p12 = OpenSSL.crypto.PKCS12() - p12.set_certificate(ca) - p12.set_privatekey(key) - f.write(p12.export()) - f.close() - - f = open(os.path.join(path, basename + "-dhparam.pem"), "wb") - f.write(DEFAULT_DHPARAM) - f.close() + with open(os.path.join(path, basename + "-ca-cert.p12"), "wb") as f: + p12 = OpenSSL.crypto.PKCS12() + p12.set_certificate(ca) + p12.set_privatekey(key) + f.write(p12.export()) + + with open(os.path.join(path, basename + "-dhparam.pem"), "wb") as f: + f.write(DEFAULT_DHPARAM) + return key, ca def add_cert_file(self, spec, path): - raw = file(path, "rb").read() - cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) + with open(path, "rb") as f: + raw = f.read() + cert = SSLCert(OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)) try: - privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) + privatekey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) except Exception: - privkey = None - self.add_cert(SSLCert(cert), privkey, spec) + privatekey = self.default_privatekey + self.add_cert( + CertStoreEntry(cert, privatekey, path), + spec + ) - def add_cert(self, cert, privkey, *names): + def add_cert(self, entry, *names): """ Adds a cert to the certstore. We register the CN in the cert plus any SANs, and also the list of names provided as an argument. """ - if cert.cn: - self.certs.add(cert.cn, (cert, privkey)) - for i in cert.altnames: - self.certs.add(i, (cert, privkey)) + if entry.cert.cn: + self.certs[entry.cert.cn] = entry + for i in entry.cert.altnames: + self.certs[i] = entry for i in names: - self.certs.add(i, (cert, privkey)) + self.certs[i] = entry + + @staticmethod + def asterisk_forms(dn): + parts = dn.split(".") + parts.reverse() + curr_dn = "" + dn_forms = ["*"] + for part in parts[:-1]: + curr_dn = "." + part + curr_dn # .example.com + dn_forms.append("*" + curr_dn) # *.example.com + if parts[-1] != "*": + dn_forms.append(parts[-1] + curr_dn) + return dn_forms def get_cert(self, commonname, sans): """ - Returns an (cert, privkey) tuple. + Returns an (cert, privkey, cert_chain) tuple. commonname: Common name for the generated certificate. Must be a valid, plain-ASCII, IDNA-encoded domain name. @@ -223,17 +246,30 @@ class CertStore: Return None if the certificate could not be found or generated. """ - c = self.certs.get(commonname) - if not c: - c = dummy_cert(self.privkey, self.cacert, commonname, sans) - self.add_cert(c, None) - c = (c, None) - return (c[0], c[1] or self.privkey) + + potential_keys = self.asterisk_forms(commonname) + for s in sans: + potential_keys.extend(self.asterisk_forms(s)) + potential_keys.append((commonname, tuple(sans))) + + name = next(itertools.ifilter(lambda key: key in self.certs, potential_keys), None) + if name: + entry = self.certs[name] + else: + entry = CertStoreEntry( + cert=dummy_cert(self.default_privatekey, self.default_ca, commonname, sans), + privatekey=self.default_privatekey, + chain_file=self.default_chain_file + ) + self.certs[(commonname, tuple(sans))] = entry + + return entry.cert, entry.privatekey, entry.chain_file def gen_pkey(self, cert): - import certffi - certffi.set_flags(self.privkey, 1) - return self.privkey + # FIXME: We should do something with cert here? + from . import certffi + certffi.set_flags(self.default_privatekey, 1) + return self.default_privatekey class _GeneralName(univ.Choice): @@ -262,6 +298,9 @@ class SSLCert: def __eq__(self, other): return self.digest("sha1") == other.digest("sha1") + def __ne__(self, other): + return not self.__eq__(other) + @classmethod def from_pem(klass, txt): x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) @@ -337,11 +376,3 @@ class SSLCert: for i in dec[0]: altnames.append(i[0].asOctets()) return altnames - - - -def get_remote_cert(host, port, sni): - c = tcp.TCPClient((host, port)) - c.connect() - c.convert_to_ssl(sni=sni) - return c.cert |