aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--netlib/certutils.py45
-rw-r--r--netlib/tcp.py3
-rw-r--r--netlib/test.py7
-rw-r--r--test/test_certutils.py14
4 files changed, 23 insertions, 46 deletions
diff --git a/netlib/certutils.py b/netlib/certutils.py
index 4c06eb8f..7dcb5450 100644
--- a/netlib/certutils.py
+++ b/netlib/certutils.py
@@ -73,7 +73,7 @@ def dummy_ca(path):
return True
-def dummy_cert(fp, ca, commonname, sans):
+def dummy_cert(ca, commonname, sans):
"""
Generates and writes a certificate to fp.
@@ -111,27 +111,15 @@ def dummy_cert(fp, ca, commonname, sans):
cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)])
cert.set_pubkey(req.get_pubkey())
cert.sign(key, "sha1")
-
- fp.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert))
- fp.close()
+ return SSLCert(cert)
class CertStore:
"""
- Implements an on-disk certificate store.
+ Implements an in-memory certificate store.
"""
- def __init__(self, certdir=None):
- """
- certdir: The certificate store directory. If None, a temporary
- directory will be created, and destroyed when the .cleanup() method
- is called.
- """
- if certdir:
- self.remove = False
- self.certdir = certdir
- else:
- self.remove = True
- self.certdir = tempfile.mkdtemp(prefix="certstore")
+ def __init__(self):
+ self.certs = {}
def check_domain(self, commonname):
try:
@@ -145,33 +133,26 @@ class CertStore:
return False
return True
- def get_cert(self, commonname, sans, cacert=False):
+ def get_cert(self, commonname, sans, cacert):
"""
- Returns the path to a certificate.
+ Returns an SSLCert object.
commonname: Common name for the generated certificate. Must be a
valid, plain-ASCII, IDNA-encoded domain name.
sans: A list of Subject Alternate Names.
- cacert: An optional path to a CA certificate. If specified, the
- cert is created if it does not exist, else return None.
+ cacert: The path to a CA certificate.
Return None if the certificate could not be found or generated.
"""
if not self.check_domain(commonname):
return None
- certpath = os.path.join(self.certdir, commonname + ".pem")
- if os.path.exists(certpath):
- return certpath
- elif cacert:
- f = open(certpath, "wb")
- dummy_cert(f, cacert, commonname, sans)
- return certpath
-
- def cleanup(self):
- if self.remove:
- shutil.rmtree(self.certdir)
+ if commonname in self.certs:
+ return self.certs[commonname]
+ c = dummy_cert(cacert, commonname, sans)
+ self.certs[commonname] = c
+ return c
class _GeneralName(univ.Choice):
diff --git a/netlib/tcp.py b/netlib/tcp.py
index f4a8acf9..31e9a398 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -268,6 +268,7 @@ class BaseHandler:
def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False):
"""
+ cert: A certutils.SSLCert object.
method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD
handle_sni: SNI handler, should take a connection object. Server
name can be retrieved like this:
@@ -297,7 +298,7 @@ class BaseHandler:
# SNI callback happens during do_handshake()
ctx.set_tlsext_servername_callback(handle_sni)
ctx.use_privatekey_file(key)
- ctx.use_certificate_file(cert)
+ ctx.use_certificate(cert.x509)
if request_client_cert:
def ver(*args):
self.clientcert = certutils.SSLCert(args[1])
diff --git a/netlib/test.py b/netlib/test.py
index deaef64e..661395c5 100644
--- a/netlib/test.py
+++ b/netlib/test.py
@@ -1,5 +1,5 @@
import threading, Queue, cStringIO
-import tcp
+import tcp, certutils
class ServerThread(threading.Thread):
def __init__(self, server):
@@ -51,6 +51,9 @@ class TServer(tcp.TCPServer):
h = self.handler_klass(request, client_address, self)
self.last_handler = h
if self.ssl:
+ cert = certutils.SSLCert.from_pem(
+ file(self.ssl["cert"], "r").read()
+ )
if self.ssl["v3_only"]:
method = tcp.SSLv3_METHOD
options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1
@@ -58,7 +61,7 @@ class TServer(tcp.TCPServer):
method = tcp.SSLv23_METHOD
options = None
h.convert_to_ssl(
- self.ssl["cert"],
+ cert,
self.ssl["key"],
method = method,
options = options,
diff --git a/test/test_certutils.py b/test/test_certutils.py
index b335e946..0b4baf75 100644
--- a/test/test_certutils.py
+++ b/test/test_certutils.py
@@ -21,21 +21,16 @@ class TestCertStore:
with tutils.tmpdir() as d:
ca = os.path.join(d, "ca")
assert certutils.dummy_ca(ca)
- c = certutils.CertStore(d)
- c.cleanup()
- assert os.path.exists(d)
+ c = certutils.CertStore()
def test_create_tmp(self):
with tutils.tmpdir() as d:
ca = os.path.join(d, "ca")
assert certutils.dummy_ca(ca)
c = certutils.CertStore()
- assert not c.get_cert("../foo.com", [])
- assert not c.get_cert("foo.com", [])
assert c.get_cert("foo.com", [], ca)
assert c.get_cert("foo.com", [], ca)
assert c.get_cert("*.foo.com", [], ca)
- c.cleanup()
def test_check_domain(self):
c = certutils.CertStore()
@@ -52,15 +47,12 @@ class TestDummyCert:
with tutils.tmpdir() as d:
cacert = os.path.join(d, "cacert")
assert certutils.dummy_ca(cacert)
- p = os.path.join(d, "foo")
- certutils.dummy_cert(
- file(p, "wb"),
+ r = certutils.dummy_cert(
cacert,
"foo.com",
["one.com", "two.com", "*.three.com"]
)
- assert file(p,"rb").read()
-
+ assert r.cn == "foo.com"
class TestSSLCert: