aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--netlib/certutils.py87
-rw-r--r--test/test_certutils.py68
2 files changed, 140 insertions, 15 deletions
diff --git a/netlib/certutils.py b/netlib/certutils.py
index b9c291d0..fafcb5fd 100644
--- a/netlib/certutils.py
+++ b/netlib/certutils.py
@@ -4,6 +4,7 @@ from pyasn1.codec.der.decoder import decode
from pyasn1.error import PyAsn1Error
import OpenSSL
import tcp
+import UserDict
DEFAULT_EXP = 62208000 # =24 * 60 * 60 * 720
# Generated with "openssl dhparam". It's too slow to generate this on startup.
@@ -42,11 +43,11 @@ def create_ca(o, cn, exp):
return key, cert
-def dummy_cert(pkey, cacert, commonname, sans):
+def dummy_cert(privkey, cacert, commonname, sans):
"""
Generates a dummy certificate.
- pkey: CA private key
+ privkey: CA private key
cacert: CA certificate
commonname: Common name for the generated certificate.
sans: A list of Subject Alternate Names.
@@ -68,17 +69,55 @@ def dummy_cert(pkey, cacert, commonname, sans):
cert.set_version(2)
cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)])
cert.set_pubkey(cacert.get_pubkey())
- cert.sign(pkey, "sha1")
+ cert.sign(privkey, "sha1")
return SSLCert(cert)
+class _Node(UserDict.UserDict):
+ def __init__(self):
+ UserDict.UserDict.__init__(self)
+ self.value = None
+
+
+class DNTree:
+ """
+ Domain store that knows about wildcards. DNS wildcards are very
+ restricted - the only valid variety is an asterisk on the left-most
+ domain component, i.e.:
+
+ *.foo.com
+ """
+ def __init__(self):
+ self.d = _Node()
+
+ def add(self, dn, cert):
+ parts = dn.split(".")
+ parts.reverse()
+ current = self.d
+ for i in parts:
+ current = current.setdefault(i, _Node())
+ current.value = cert
+
+ def get(self, dn):
+ parts = dn.split(".")
+ current = self.d
+ for i in reversed(parts):
+ if i in current:
+ current = current[i]
+ elif "*" in current:
+ return current["*"].value
+ else:
+ return None
+ return current.value
+
+
class CertStore:
"""
Implements an in-memory certificate store.
"""
- def __init__(self, pkey, cert):
- self.pkey, self.cert = pkey, cert
- self.certs = {}
+ def __init__(self, privkey, cacert):
+ self.privkey, self.cacert = privkey, cacert
+ self.certs = DNTree()
@classmethod
def from_store(klass, path, basename):
@@ -130,9 +169,29 @@ class CertStore:
f.close()
return key, ca
+ def add_cert_file(self, commonname, path):
+ raw = file(path, "rb").read()
+ cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)
+ try:
+ privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw)
+ except Exception:
+ privkey = None
+ self.add_cert(SSLCert(cert), privkey, commonname)
+
+ def add_cert(self, cert, privkey, *names):
+ """
+ Adds a cert to the certstore. We register the CN in the cert plus
+ any SANs, and also the list of names provided as an argument.
+ """
+ self.certs.add(cert.cn, (cert, privkey))
+ for i in cert.altnames:
+ self.certs.add(i, (cert, privkey))
+ for i in names:
+ self.certs.add(i, (cert, privkey))
+
def get_cert(self, commonname, sans):
"""
- Returns an SSLCert object.
+ Returns an (cert, privkey) tuple.
commonname: Common name for the generated certificate. Must be a
valid, plain-ASCII, IDNA-encoded domain name.
@@ -141,11 +200,12 @@ class CertStore:
Return None if the certificate could not be found or generated.
"""
- if commonname in self.certs:
- return self.certs[commonname]
- c = dummy_cert(self.pkey, self.cert, commonname, sans)
- self.certs[commonname] = c
- return c
+ c = self.certs.get(commonname)
+ if not c:
+ c = dummy_cert(self.privkey, self.cacert, commonname, sans)
+ self.add_cert(c, None)
+ c = (c, None)
+ return (c[0], c[1] or self.privkey)
class _GeneralName(univ.Choice):
@@ -171,6 +231,9 @@ class SSLCert:
"""
self.x509 = cert
+ def __eq__(self, other):
+ return self.digest("sha1") == other.digest("sha1")
+
@classmethod
def from_pem(klass, txt):
x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt)
diff --git a/test/test_certutils.py b/test/test_certutils.py
index f741bdec..7f320e7e 100644
--- a/test/test_certutils.py
+++ b/test/test_certutils.py
@@ -1,7 +1,37 @@
import os
from netlib import certutils
+import OpenSSL
import tutils
+class TestDNTree:
+ def test_simple(self):
+ d = certutils.DNTree()
+ d.add("foo.com", "foo")
+ d.add("bar.com", "bar")
+ assert d.get("foo.com") == "foo"
+ assert d.get("bar.com") == "bar"
+ assert not d.get("oink.com")
+ assert not d.get("oink")
+ assert not d.get("")
+ assert not d.get("oink.oink")
+
+ d.add("*.match.org", "match")
+ assert not d.get("match.org")
+ assert d.get("foo.match.org") == "match"
+ assert d.get("foo.foo.match.org") == "match"
+
+ def test_wildcard(self):
+ d = certutils.DNTree()
+ d.add("foo.com", "foo")
+ assert not d.get("*.foo.com")
+ d.add("*.foo.com", "wild")
+
+ d = certutils.DNTree()
+ d.add("*", "foo")
+ assert d.get("foo.com") == "foo"
+ assert d.get("*.foo.com") == "foo"
+ assert d.get("com") == "foo"
+
class TestCertStore:
def test_create_explicit(self):
@@ -12,7 +42,7 @@ class TestCertStore:
ca2 = certutils.CertStore.from_store(d, "test")
assert ca2.get_cert("foo", [])
- assert ca.cert.get_serial_number() == ca2.cert.get_serial_number()
+ assert ca.cacert.get_serial_number() == ca2.cacert.get_serial_number()
def test_create_tmp(self):
with tutils.tmpdir() as d:
@@ -21,14 +51,46 @@ class TestCertStore:
assert ca.get_cert("foo.com", [])
assert ca.get_cert("*.foo.com", [])
+ r = ca.get_cert("*.foo.com", [])
+ assert r[1] == ca.privkey
+
+ def test_add_cert(self):
+ with tutils.tmpdir() as d:
+ ca = certutils.CertStore.from_store(d, "test")
+
+ def test_sans(self):
+ with tutils.tmpdir() as d:
+ ca = certutils.CertStore.from_store(d, "test")
+ c1 = ca.get_cert("foo.com", ["*.bar.com"])
+ c2 = ca.get_cert("foo.bar.com", [])
+ assert c1 == c2
+ c3 = ca.get_cert("bar.com", [])
+ assert not c1 == c3
+
+ def test_overrides(self):
+ with tutils.tmpdir() as d:
+ ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test")
+ ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test")
+ assert not ca1.cacert.get_serial_number() == ca2.cacert.get_serial_number()
+
+ dc = ca2.get_cert("foo.com", [])
+ dcp = os.path.join(d, "dc")
+ f = open(dcp, "wb")
+ f.write(dc[0].to_pem())
+ f.close()
+ ca1.add_cert_file("foo.com", dcp)
+
+ ret = ca1.get_cert("foo.com", [])
+ assert ret[0].serial == dc[0].serial
+
class TestDummyCert:
def test_with_ca(self):
with tutils.tmpdir() as d:
ca = certutils.CertStore.from_store(d, "test")
r = certutils.dummy_cert(
- ca.pkey,
- ca.cert,
+ ca.privkey,
+ ca.cacert,
"foo.com",
["one.com", "two.com", "*.three.com"]
)