From 1f3fec2a3e1d6913ea7fe3480bc85f141737bd96 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 24 Oct 2017 22:44:37 +0200 Subject: remove old dntree implementation --- mitmproxy/certs.py | 40 ---------------------------------------- 1 file changed, 40 deletions(-) diff --git a/mitmproxy/certs.py b/mitmproxy/certs.py index c5f930e1..5a737b61 100644 --- a/mitmproxy/certs.py +++ b/mitmproxy/certs.py @@ -114,46 +114,6 @@ def dummy_cert(privkey, cacert, commonname, sans): return SSLCert(cert) -# DNTree did not pass TestCertStore.test_sans_change and is temporarily replaced by a simple dict. -# -# class _Node(UserDict.UserDict): -# def __init__(self): -# UserDict.UserDict.__init__(self) -# self.value = None -# -# -# class DNTree: -# """ -# Domain store that knows about wildcards. DNS wildcards are very -# restricted - the only valid variety is an asterisk on the left-most -# domain component, i.e.: -# -# *.foo.com -# """ -# def __init__(self): -# self.d = _Node() -# -# def add(self, dn, cert): -# parts = dn.split(".") -# parts.reverse() -# current = self.d -# for i in parts: -# current = current.setdefault(i, _Node()) -# current.value = cert -# -# def get(self, dn): -# parts = dn.split(".") -# current = self.d -# for i in reversed(parts): -# if i in current: -# current = current[i] -# elif "*" in current: -# return current["*"].value -# else: -# return None -# return current.value - - class CertStoreEntry: def __init__(self, cert, privatekey, chain_file): -- cgit v1.2.3 From 4a6d838ecc18388afa2f551799c679be752bbbf8 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 24 Oct 2017 21:12:39 +0200 Subject: fix #2563 --- mitmproxy/certs.py | 44 +++++++++++++++++++++---------------- test/mitmproxy/proxy/test_server.py | 4 ++-- test/mitmproxy/test_certs.py | 2 +- test/pathod/test_pathod.py | 2 +- 4 files changed, 29 insertions(+), 23 deletions(-) diff --git a/mitmproxy/certs.py b/mitmproxy/certs.py index 5a737b61..572a12d0 100644 --- a/mitmproxy/certs.py +++ b/mitmproxy/certs.py @@ -4,6 +4,7 @@ import time import datetime import ipaddress import sys +import typing from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode @@ -122,6 +123,11 @@ class CertStoreEntry: self.chain_file = chain_file +TCustomCertId = bytes # manually provided certs (e.g. mitmproxy's --certs) +TGeneratedCertId = typing.Tuple[typing.Optional[bytes], typing.Tuple[bytes, ...]] # (common_name, sans) +TCertId = typing.Union[TCustomCertId, TGeneratedCertId] + + class CertStore: """ @@ -139,7 +145,7 @@ class CertStore: self.default_ca = default_ca self.default_chain_file = default_chain_file self.dhparams = dhparams - self.certs = dict() + self.certs = {} # type: typing.Dict[TCertId, CertStoreEntry] self.expire_queue = [] def expire(self, entry): @@ -240,7 +246,7 @@ class CertStore: return key, ca - def add_cert_file(self, spec, path): + def add_cert_file(self, spec: str, path: str) -> None: with open(path, "rb") as f: raw = f.read() cert = SSLCert( @@ -255,10 +261,10 @@ class CertStore: privatekey = self.default_privatekey self.add_cert( CertStoreEntry(cert, privatekey, path), - spec + spec.encode("idna") ) - def add_cert(self, entry, *names): + def add_cert(self, entry: CertStoreEntry, *names: bytes): """ Adds a cert to the certstore. We register the CN in the cert plus any SANs, and also the list of names provided as an argument. @@ -271,21 +277,18 @@ class CertStore: self.certs[i] = entry @staticmethod - def asterisk_forms(dn): - if dn is None: - return [] + def asterisk_forms(dn: bytes) -> typing.List[bytes]: + """ + Return all asterisk forms for a domain. For example, for www.example.com this will return + [b"www.example.com", b"*.example.com", b"*.com"]. The single wildcard "*" is omitted. + """ parts = dn.split(b".") - parts.reverse() - curr_dn = b"" - dn_forms = [b"*"] - for part in parts[:-1]: - curr_dn = b"." + part + curr_dn # .example.com - dn_forms.append(b"*" + curr_dn) # *.example.com - if parts[-1] != b"*": - dn_forms.append(parts[-1] + curr_dn) - return dn_forms - - def get_cert(self, commonname, sans): + ret = [dn] + for i in range(1, len(parts)): + ret.append(b"*." + b".".join(parts[i:])) + return ret + + def get_cert(self, commonname: typing.Optional[bytes], sans: typing.List[bytes]): """ Returns an (cert, privkey, cert_chain) tuple. @@ -295,9 +298,12 @@ class CertStore: sans: A list of Subject Alternate Names. """ - potential_keys = self.asterisk_forms(commonname) + potential_keys = [] # type: typing.List[TCertId] + if commonname: + potential_keys.extend(self.asterisk_forms(commonname)) for s in sans: potential_keys.extend(self.asterisk_forms(s)) + potential_keys.append(b"*") potential_keys.append((commonname, tuple(sans))) name = next( diff --git a/test/mitmproxy/proxy/test_server.py b/test/mitmproxy/proxy/test_server.py index affdf221..8dce9bcd 100644 --- a/test/mitmproxy/proxy/test_server.py +++ b/test/mitmproxy/proxy/test_server.py @@ -479,7 +479,7 @@ class TestHTTPSNoCommonName(tservers.HTTPProxyTest): ssl = True ssloptions = pathod.SSLOptions( certs=[ - (b"*", tutils.test_data.path("mitmproxy/data/no_common_name.pem")) + ("*", tutils.test_data.path("mitmproxy/data/no_common_name.pem")) ] ) @@ -1142,7 +1142,7 @@ class AddUpstreamCertsToClientChainMixin: ssloptions = pathod.SSLOptions( cn=b"example.mitmproxy.org", certs=[ - (b"example.mitmproxy.org", servercert) + ("example.mitmproxy.org", servercert) ] ) diff --git a/test/mitmproxy/test_certs.py b/test/mitmproxy/test_certs.py index 88c49561..693bebc6 100644 --- a/test/mitmproxy/test_certs.py +++ b/test/mitmproxy/test_certs.py @@ -102,7 +102,7 @@ class TestCertStore: dc = ca2.get_cert(b"foo.com", [b"sans.example.com"]) dcp = tmpdir.join("dc") dcp.write(dc[0].to_pem()) - ca1.add_cert_file(b"foo.com", str(dcp)) + ca1.add_cert_file("foo.com", str(dcp)) ret = ca1.get_cert(b"foo.com", []) assert ret[0].serial == dc[0].serial diff --git a/test/pathod/test_pathod.py b/test/pathod/test_pathod.py index 5f191c0d..c0011952 100644 --- a/test/pathod/test_pathod.py +++ b/test/pathod/test_pathod.py @@ -57,7 +57,7 @@ class TestNotAfterConnect(tservers.DaemonTests): class TestCustomCert(tservers.DaemonTests): ssl = True ssloptions = dict( - certs=[(b"*", tutils.test_data.path("pathod/data/testkey.pem"))], + certs=[("*", tutils.test_data.path("pathod/data/testkey.pem"))], ) def test_connect(self): -- cgit v1.2.3