aboutsummaryrefslogtreecommitdiffstats
path: root/netlib/certutils.py
diff options
context:
space:
mode:
Diffstat (limited to 'netlib/certutils.py')
-rw-r--r--netlib/certutils.py247
1 files changed, 139 insertions, 108 deletions
diff --git a/netlib/certutils.py b/netlib/certutils.py
index 187abfae..af6177d8 100644
--- a/netlib/certutils.py
+++ b/netlib/certutils.py
@@ -1,10 +1,10 @@
+from __future__ import (absolute_import, print_function, division)
import os, ssl, time, datetime
+import itertools
from pyasn1.type import univ, constraint, char, namedtype, tag
from pyasn1.codec.der.decoder import decode
from pyasn1.error import PyAsn1Error
import OpenSSL
-import tcp
-import UserDict
DEFAULT_EXP = 62208000 # =24 * 60 * 60 * 720
# Generated with "openssl dhparam". It's too slow to generate this on startup.
@@ -29,12 +29,12 @@ def create_ca(o, cn, exp):
cert.add_extensions([
OpenSSL.crypto.X509Extension("basicConstraints", True,
"CA:TRUE"),
- OpenSSL.crypto.X509Extension("nsCertType", True,
+ OpenSSL.crypto.X509Extension("nsCertType", False,
"sslCA"),
- OpenSSL.crypto.X509Extension("extendedKeyUsage", True,
+ OpenSSL.crypto.X509Extension("extendedKeyUsage", False,
"serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC"
),
- OpenSSL.crypto.X509Extension("keyUsage", False,
+ OpenSSL.crypto.X509Extension("keyUsage", True,
"keyCertSign, cRLSign"),
OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash",
subject=cert),
@@ -67,62 +67,72 @@ def dummy_cert(privkey, cacert, commonname, sans):
cert.set_serial_number(int(time.time()*10000))
if ss:
cert.set_version(2)
- cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)])
+ cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", False, ss)])
cert.set_pubkey(cacert.get_pubkey())
cert.sign(privkey, "sha1")
return SSLCert(cert)
-class _Node(UserDict.UserDict):
- def __init__(self):
- UserDict.UserDict.__init__(self)
- self.value = None
-
-
-class DNTree:
- """
- Domain store that knows about wildcards. DNS wildcards are very
- restricted - the only valid variety is an asterisk on the left-most
- domain component, i.e.:
-
- *.foo.com
- """
- def __init__(self):
- self.d = _Node()
-
- def add(self, dn, cert):
- parts = dn.split(".")
- parts.reverse()
- current = self.d
- for i in parts:
- current = current.setdefault(i, _Node())
- current.value = cert
-
- def get(self, dn):
- parts = dn.split(".")
- current = self.d
- for i in reversed(parts):
- if i in current:
- current = current[i]
- elif "*" in current:
- return current["*"].value
- else:
- return None
- return current.value
-
+# DNTree did not pass TestCertStore.test_sans_change and is temporarily replaced by a simple dict.
+#
+# class _Node(UserDict.UserDict):
+# def __init__(self):
+# UserDict.UserDict.__init__(self)
+# self.value = None
+#
+#
+# class DNTree:
+# """
+# Domain store that knows about wildcards. DNS wildcards are very
+# restricted - the only valid variety is an asterisk on the left-most
+# domain component, i.e.:
+#
+# *.foo.com
+# """
+# def __init__(self):
+# self.d = _Node()
+#
+# def add(self, dn, cert):
+# parts = dn.split(".")
+# parts.reverse()
+# current = self.d
+# for i in parts:
+# current = current.setdefault(i, _Node())
+# current.value = cert
+#
+# def get(self, dn):
+# parts = dn.split(".")
+# current = self.d
+# for i in reversed(parts):
+# if i in current:
+# current = current[i]
+# elif "*" in current:
+# return current["*"].value
+# else:
+# return None
+# return current.value
+
+
+class CertStoreEntry(object):
+ def __init__(self, cert, privatekey, chain_file):
+ self.cert = cert
+ self.privatekey = privatekey
+ self.chain_file = chain_file
class CertStore:
"""
Implements an in-memory certificate store.
"""
- def __init__(self, privkey, cacert, dhparams=None):
- self.privkey, self.cacert = privkey, cacert
+ def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams=None):
+ self.default_privatekey = default_privatekey
+ self.default_ca = default_ca
+ self.default_chain_file = default_chain_file
self.dhparams = dhparams
- self.certs = DNTree()
+ self.certs = dict()
- @classmethod
- def load_dhparam(klass, path):
+ @staticmethod
+ def load_dhparam(path):
# netlib<=0.10 doesn't generate a dhparam file.
# Create it now if neccessary.
@@ -140,21 +150,21 @@ class CertStore:
return dh
@classmethod
- def from_store(klass, path, basename):
- p = os.path.join(path, basename + "-ca.pem")
- if not os.path.exists(p):
- key, ca = klass.create_store(path, basename)
+ def from_store(cls, path, basename):
+ ca_path = os.path.join(path, basename + "-ca.pem")
+ if not os.path.exists(ca_path):
+ key, ca = cls.create_store(path, basename)
else:
- p = os.path.join(path, basename + "-ca.pem")
- raw = file(p, "rb").read()
+ with open(ca_path, "rb") as f:
+ raw = f.read()
ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)
key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw)
- dhp = os.path.join(path, basename + "-dhparam.pem")
- dh = klass.load_dhparam(dhp)
- return klass(key, ca, dh)
+ dh_path = os.path.join(path, basename + "-dhparam.pem")
+ dh = cls.load_dhparam(dh_path)
+ return cls(key, ca, ca_path, dh)
- @classmethod
- def create_store(klass, path, basename, o=None, cn=None, expiry=DEFAULT_EXP):
+ @staticmethod
+ def create_store(path, basename, o=None, cn=None, expiry=DEFAULT_EXP):
if not os.path.exists(path):
os.makedirs(path)
@@ -163,58 +173,71 @@ class CertStore:
key, ca = create_ca(o=o, cn=cn, exp=expiry)
# Dump the CA plus private key
- f = open(os.path.join(path, basename + "-ca.pem"), "wb")
- f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key))
- f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca))
- f.close()
+ with open(os.path.join(path, basename + "-ca.pem"), "wb") as f:
+ f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key))
+ f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca))
# Dump the certificate in PEM format
- f = open(os.path.join(path, basename + "-ca-cert.pem"), "wb")
- f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca))
- f.close()
+ with open(os.path.join(path, basename + "-ca-cert.pem"), "wb") as f:
+ f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca))
# Create a .cer file with the same contents for Android
- f = open(os.path.join(path, basename + "-ca-cert.cer"), "wb")
- f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca))
- f.close()
+ with open(os.path.join(path, basename + "-ca-cert.cer"), "wb") as f:
+ f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca))
# Dump the certificate in PKCS12 format for Windows devices
- f = open(os.path.join(path, basename + "-ca-cert.p12"), "wb")
- p12 = OpenSSL.crypto.PKCS12()
- p12.set_certificate(ca)
- p12.set_privatekey(key)
- f.write(p12.export())
- f.close()
-
- f = open(os.path.join(path, basename + "-dhparam.pem"), "wb")
- f.write(DEFAULT_DHPARAM)
- f.close()
+ with open(os.path.join(path, basename + "-ca-cert.p12"), "wb") as f:
+ p12 = OpenSSL.crypto.PKCS12()
+ p12.set_certificate(ca)
+ p12.set_privatekey(key)
+ f.write(p12.export())
+
+ with open(os.path.join(path, basename + "-dhparam.pem"), "wb") as f:
+ f.write(DEFAULT_DHPARAM)
+
return key, ca
def add_cert_file(self, spec, path):
- raw = file(path, "rb").read()
- cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)
+ with open(path, "rb") as f:
+ raw = f.read()
+ cert = SSLCert(OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw))
try:
- privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw)
+ privatekey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw)
except Exception:
- privkey = None
- self.add_cert(SSLCert(cert), privkey, spec)
+ privatekey = self.default_privatekey
+ self.add_cert(
+ CertStoreEntry(cert, privatekey, path),
+ spec
+ )
- def add_cert(self, cert, privkey, *names):
+ def add_cert(self, entry, *names):
"""
Adds a cert to the certstore. We register the CN in the cert plus
any SANs, and also the list of names provided as an argument.
"""
- if cert.cn:
- self.certs.add(cert.cn, (cert, privkey))
- for i in cert.altnames:
- self.certs.add(i, (cert, privkey))
+ if entry.cert.cn:
+ self.certs[entry.cert.cn] = entry
+ for i in entry.cert.altnames:
+ self.certs[i] = entry
for i in names:
- self.certs.add(i, (cert, privkey))
+ self.certs[i] = entry
+
+ @staticmethod
+ def asterisk_forms(dn):
+ parts = dn.split(".")
+ parts.reverse()
+ curr_dn = ""
+ dn_forms = ["*"]
+ for part in parts[:-1]:
+ curr_dn = "." + part + curr_dn # .example.com
+ dn_forms.append("*" + curr_dn) # *.example.com
+ if parts[-1] != "*":
+ dn_forms.append(parts[-1] + curr_dn)
+ return dn_forms
def get_cert(self, commonname, sans):
"""
- Returns an (cert, privkey) tuple.
+ Returns an (cert, privkey, cert_chain) tuple.
commonname: Common name for the generated certificate. Must be a
valid, plain-ASCII, IDNA-encoded domain name.
@@ -223,17 +246,30 @@ class CertStore:
Return None if the certificate could not be found or generated.
"""
- c = self.certs.get(commonname)
- if not c:
- c = dummy_cert(self.privkey, self.cacert, commonname, sans)
- self.add_cert(c, None)
- c = (c, None)
- return (c[0], c[1] or self.privkey)
+
+ potential_keys = self.asterisk_forms(commonname)
+ for s in sans:
+ potential_keys.extend(self.asterisk_forms(s))
+ potential_keys.append((commonname, tuple(sans)))
+
+ name = next(itertools.ifilter(lambda key: key in self.certs, potential_keys), None)
+ if name:
+ entry = self.certs[name]
+ else:
+ entry = CertStoreEntry(
+ cert=dummy_cert(self.default_privatekey, self.default_ca, commonname, sans),
+ privatekey=self.default_privatekey,
+ chain_file=self.default_chain_file
+ )
+ self.certs[(commonname, tuple(sans))] = entry
+
+ return entry.cert, entry.privatekey, entry.chain_file
def gen_pkey(self, cert):
- import certffi
- certffi.set_flags(self.privkey, 1)
- return self.privkey
+ # FIXME: We should do something with cert here?
+ from . import certffi
+ certffi.set_flags(self.default_privatekey, 1)
+ return self.default_privatekey
class _GeneralName(univ.Choice):
@@ -262,6 +298,9 @@ class SSLCert:
def __eq__(self, other):
return self.digest("sha1") == other.digest("sha1")
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
@classmethod
def from_pem(klass, txt):
x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt)
@@ -337,11 +376,3 @@ class SSLCert:
for i in dec[0]:
altnames.append(i[0].asOctets())
return altnames
-
-
-
-def get_remote_cert(host, port, sni):
- c = tcp.TCPClient((host, port))
- c.connect()
- c.convert_to_ssl(sni=sni)
- return c.cert