From a7837846a2c20f3fc48406fc63845aec1a7efae0 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 18 Jul 2014 22:55:25 +0200 Subject: temporarily replace DNTree with a simpler cert lookup mechanism, fix mitmproxy/mitmproxy#295 --- netlib/certutils.py | 99 ++++++++++++++++++++++++++++------------------------- 1 file changed, 53 insertions(+), 46 deletions(-) (limited to 'netlib/certutils.py') diff --git a/netlib/certutils.py b/netlib/certutils.py index 8aec5e82..87fb99c3 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -1,4 +1,5 @@ import os, ssl, time, datetime +import itertools from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error @@ -73,42 +74,44 @@ def dummy_cert(privkey, cacert, commonname, sans): return SSLCert(cert) -class _Node(UserDict.UserDict): - def __init__(self): - UserDict.UserDict.__init__(self) - self.value = None - - -class DNTree: - """ - Domain store that knows about wildcards. DNS wildcards are very - restricted - the only valid variety is an asterisk on the left-most - domain component, i.e.: - - *.foo.com - """ - def __init__(self): - self.d = _Node() - - def add(self, dn, cert): - parts = dn.split(".") - parts.reverse() - current = self.d - for i in parts: - current = current.setdefault(i, _Node()) - current.value = cert - - def get(self, dn): - parts = dn.split(".") - current = self.d - for i in reversed(parts): - if i in current: - current = current[i] - elif "*" in current: - return current["*"].value - else: - return None - return current.value +# DNTree did not pass TestCertStore.test_sans_change and is temporarily replaced by a simple dict. +# +# class _Node(UserDict.UserDict): +# def __init__(self): +# UserDict.UserDict.__init__(self) +# self.value = None +# +# +# class DNTree: +# """ +# Domain store that knows about wildcards. DNS wildcards are very +# restricted - the only valid variety is an asterisk on the left-most +# domain component, i.e.: +# +# *.foo.com +# """ +# def __init__(self): +# self.d = _Node() +# +# def add(self, dn, cert): +# parts = dn.split(".") +# parts.reverse() +# current = self.d +# for i in parts: +# current = current.setdefault(i, _Node()) +# current.value = cert +# +# def get(self, dn): +# parts = dn.split(".") +# current = self.d +# for i in reversed(parts): +# if i in current: +# current = current[i] +# elif "*" in current: +# return current["*"].value +# else: +# return None +# return current.value @@ -119,7 +122,7 @@ class CertStore: def __init__(self, privkey, cacert, dhparams=None): self.privkey, self.cacert = privkey, cacert self.dhparams = dhparams - self.certs = DNTree() + self.certs = dict() @classmethod def load_dhparam(klass, path): @@ -206,11 +209,11 @@ class CertStore: any SANs, and also the list of names provided as an argument. """ if cert.cn: - self.certs.add(cert.cn, (cert, privkey)) + self.certs[cert.cn] = (cert, privkey) for i in cert.altnames: - self.certs.add(i, (cert, privkey)) + self.certs[i] = (cert, privkey) for i in names: - self.certs.add(i, (cert, privkey)) + self.certs[i] = (cert, privkey) def get_cert(self, commonname, sans): """ @@ -223,12 +226,16 @@ class CertStore: Return None if the certificate could not be found or generated. """ - c = self.certs.get(commonname) - if not c: - c = dummy_cert(self.privkey, self.cacert, commonname, sans) - self.add_cert(c, None) - c = (c, None) - return (c[0], c[1] or self.privkey) + + potential_keys = [commonname] + sans + [(commonname, tuple(sans))] + name = next(itertools.ifilter(lambda key: key in self.certs, potential_keys), None) + if name: + c = self.certs[name] + else: + c = dummy_cert(self.privkey, self.cacert, commonname, sans), None + self.certs[(commonname, tuple(sans))] = c + + return c[0], (c[1] or self.privkey) def gen_pkey(self, cert): import certffi -- cgit v1.2.3