aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMaximilian Hils <git@maximilianhils.com>2017-10-25 10:20:09 +0200
committerGitHub <noreply@github.com>2017-10-25 10:20:09 +0200
commitfdd6bd8277265563888d676dcff42b1bca007363 (patch)
treec067140d9477c6eab3501af55109652d149cdfd2
parent45145ed08b7623add9a96bfcdcd02a746e44124a (diff)
parent4a6d838ecc18388afa2f551799c679be752bbbf8 (diff)
downloadmitmproxy-fdd6bd8277265563888d676dcff42b1bca007363.tar.gz
mitmproxy-fdd6bd8277265563888d676dcff42b1bca007363.tar.bz2
mitmproxy-fdd6bd8277265563888d676dcff42b1bca007363.zip
Merge pull request #2606 from mhils/issue-2563
Fix #2563
-rw-r--r--mitmproxy/certs.py84
-rw-r--r--test/mitmproxy/proxy/test_server.py4
-rw-r--r--test/mitmproxy/test_certs.py2
-rw-r--r--test/pathod/test_pathod.py2
4 files changed, 29 insertions, 63 deletions
diff --git a/mitmproxy/certs.py b/mitmproxy/certs.py
index c5f930e1..572a12d0 100644
--- a/mitmproxy/certs.py
+++ b/mitmproxy/certs.py
@@ -4,6 +4,7 @@ import time
import datetime
import ipaddress
import sys
+import typing
from pyasn1.type import univ, constraint, char, namedtype, tag
from pyasn1.codec.der.decoder import decode
@@ -114,46 +115,6 @@ def dummy_cert(privkey, cacert, commonname, sans):
return SSLCert(cert)
-# DNTree did not pass TestCertStore.test_sans_change and is temporarily replaced by a simple dict.
-#
-# class _Node(UserDict.UserDict):
-# def __init__(self):
-# UserDict.UserDict.__init__(self)
-# self.value = None
-#
-#
-# class DNTree:
-# """
-# Domain store that knows about wildcards. DNS wildcards are very
-# restricted - the only valid variety is an asterisk on the left-most
-# domain component, i.e.:
-#
-# *.foo.com
-# """
-# def __init__(self):
-# self.d = _Node()
-#
-# def add(self, dn, cert):
-# parts = dn.split(".")
-# parts.reverse()
-# current = self.d
-# for i in parts:
-# current = current.setdefault(i, _Node())
-# current.value = cert
-#
-# def get(self, dn):
-# parts = dn.split(".")
-# current = self.d
-# for i in reversed(parts):
-# if i in current:
-# current = current[i]
-# elif "*" in current:
-# return current["*"].value
-# else:
-# return None
-# return current.value
-
-
class CertStoreEntry:
def __init__(self, cert, privatekey, chain_file):
@@ -162,6 +123,11 @@ class CertStoreEntry:
self.chain_file = chain_file
+TCustomCertId = bytes # manually provided certs (e.g. mitmproxy's --certs)
+TGeneratedCertId = typing.Tuple[typing.Optional[bytes], typing.Tuple[bytes, ...]] # (common_name, sans)
+TCertId = typing.Union[TCustomCertId, TGeneratedCertId]
+
+
class CertStore:
"""
@@ -179,7 +145,7 @@ class CertStore:
self.default_ca = default_ca
self.default_chain_file = default_chain_file
self.dhparams = dhparams
- self.certs = dict()
+ self.certs = {} # type: typing.Dict[TCertId, CertStoreEntry]
self.expire_queue = []
def expire(self, entry):
@@ -280,7 +246,7 @@ class CertStore:
return key, ca
- def add_cert_file(self, spec, path):
+ def add_cert_file(self, spec: str, path: str) -> None:
with open(path, "rb") as f:
raw = f.read()
cert = SSLCert(
@@ -295,10 +261,10 @@ class CertStore:
privatekey = self.default_privatekey
self.add_cert(
CertStoreEntry(cert, privatekey, path),
- spec
+ spec.encode("idna")
)
- def add_cert(self, entry, *names):
+ def add_cert(self, entry: CertStoreEntry, *names: bytes):
"""
Adds a cert to the certstore. We register the CN in the cert plus
any SANs, and also the list of names provided as an argument.
@@ -311,21 +277,18 @@ class CertStore:
self.certs[i] = entry
@staticmethod
- def asterisk_forms(dn):
- if dn is None:
- return []
+ def asterisk_forms(dn: bytes) -> typing.List[bytes]:
+ """
+ Return all asterisk forms for a domain. For example, for www.example.com this will return
+ [b"www.example.com", b"*.example.com", b"*.com"]. The single wildcard "*" is omitted.
+ """
parts = dn.split(b".")
- parts.reverse()
- curr_dn = b""
- dn_forms = [b"*"]
- for part in parts[:-1]:
- curr_dn = b"." + part + curr_dn # .example.com
- dn_forms.append(b"*" + curr_dn) # *.example.com
- if parts[-1] != b"*":
- dn_forms.append(parts[-1] + curr_dn)
- return dn_forms
-
- def get_cert(self, commonname, sans):
+ ret = [dn]
+ for i in range(1, len(parts)):
+ ret.append(b"*." + b".".join(parts[i:]))
+ return ret
+
+ def get_cert(self, commonname: typing.Optional[bytes], sans: typing.List[bytes]):
"""
Returns an (cert, privkey, cert_chain) tuple.
@@ -335,9 +298,12 @@ class CertStore:
sans: A list of Subject Alternate Names.
"""
- potential_keys = self.asterisk_forms(commonname)
+ potential_keys = [] # type: typing.List[TCertId]
+ if commonname:
+ potential_keys.extend(self.asterisk_forms(commonname))
for s in sans:
potential_keys.extend(self.asterisk_forms(s))
+ potential_keys.append(b"*")
potential_keys.append((commonname, tuple(sans)))
name = next(
diff --git a/test/mitmproxy/proxy/test_server.py b/test/mitmproxy/proxy/test_server.py
index affdf221..8dce9bcd 100644
--- a/test/mitmproxy/proxy/test_server.py
+++ b/test/mitmproxy/proxy/test_server.py
@@ -479,7 +479,7 @@ class TestHTTPSNoCommonName(tservers.HTTPProxyTest):
ssl = True
ssloptions = pathod.SSLOptions(
certs=[
- (b"*", tutils.test_data.path("mitmproxy/data/no_common_name.pem"))
+ ("*", tutils.test_data.path("mitmproxy/data/no_common_name.pem"))
]
)
@@ -1142,7 +1142,7 @@ class AddUpstreamCertsToClientChainMixin:
ssloptions = pathod.SSLOptions(
cn=b"example.mitmproxy.org",
certs=[
- (b"example.mitmproxy.org", servercert)
+ ("example.mitmproxy.org", servercert)
]
)
diff --git a/test/mitmproxy/test_certs.py b/test/mitmproxy/test_certs.py
index 88c49561..693bebc6 100644
--- a/test/mitmproxy/test_certs.py
+++ b/test/mitmproxy/test_certs.py
@@ -102,7 +102,7 @@ class TestCertStore:
dc = ca2.get_cert(b"foo.com", [b"sans.example.com"])
dcp = tmpdir.join("dc")
dcp.write(dc[0].to_pem())
- ca1.add_cert_file(b"foo.com", str(dcp))
+ ca1.add_cert_file("foo.com", str(dcp))
ret = ca1.get_cert(b"foo.com", [])
assert ret[0].serial == dc[0].serial
diff --git a/test/pathod/test_pathod.py b/test/pathod/test_pathod.py
index 5f191c0d..c0011952 100644
--- a/test/pathod/test_pathod.py
+++ b/test/pathod/test_pathod.py
@@ -57,7 +57,7 @@ class TestNotAfterConnect(tservers.DaemonTests):
class TestCustomCert(tservers.DaemonTests):
ssl = True
ssloptions = dict(
- certs=[(b"*", tutils.test_data.path("pathod/data/testkey.pem"))],
+ certs=[("*", tutils.test_data.path("pathod/data/testkey.pem"))],
)
def test_connect(self):