aboutsummaryrefslogtreecommitdiffstats
path: root/netlib
diff options
context:
space:
mode:
Diffstat (limited to 'netlib')
-rw-r--r--netlib/certutils.py481
-rw-r--r--netlib/tcp.py12
2 files changed, 6 insertions, 487 deletions
diff --git a/netlib/certutils.py b/netlib/certutils.py
deleted file mode 100644
index 9cb8a40e..00000000
--- a/netlib/certutils.py
+++ /dev/null
@@ -1,481 +0,0 @@
-import os
-import ssl
-import time
-import datetime
-import ipaddress
-
-import sys
-from pyasn1.type import univ, constraint, char, namedtype, tag
-from pyasn1.codec.der.decoder import decode
-from pyasn1.error import PyAsn1Error
-import OpenSSL
-
-from mitmproxy.types import serializable
-
-# Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815
-
-DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3
-# Generated with "openssl dhparam". It's too slow to generate this on startup.
-DEFAULT_DHPARAM = b"""
------BEGIN DH PARAMETERS-----
-MIICCAKCAgEAyT6LzpwVFS3gryIo29J5icvgxCnCebcdSe/NHMkD8dKJf8suFCg3
-O2+dguLakSVif/t6dhImxInJk230HmfC8q93hdcg/j8rLGJYDKu3ik6H//BAHKIv
-j5O9yjU3rXCfmVJQic2Nne39sg3CreAepEts2TvYHhVv3TEAzEqCtOuTjgDv0ntJ
-Gwpj+BJBRQGG9NvprX1YGJ7WOFBP/hWU7d6tgvE6Xa7T/u9QIKpYHMIkcN/l3ZFB
-chZEqVlyrcngtSXCROTPcDOQ6Q8QzhaBJS+Z6rcsd7X+haiQqvoFcmaJ08Ks6LQC
-ZIL2EtYJw8V8z7C0igVEBIADZBI6OTbuuhDwRw//zU1uq52Oc48CIZlGxTYG/Evq
-o9EWAXUYVzWkDSTeBH1r4z/qLPE2cnhtMxbFxuvK53jGB0emy2y1Ei6IhKshJ5qX
-IB/aE7SSHyQ3MDHHkCmQJCsOd4Mo26YX61NZ+n501XjqpCBQ2+DfZCBh8Va2wDyv
-A2Ryg9SUz8j0AXViRNMJgJrr446yro/FuJZwnQcO3WQnXeqSBnURqKjmqkeFP+d8
-6mk2tqJaY507lRNqtGlLnj7f5RNoBFJDCLBNurVgfvq9TCVWKDIFD4vZRjCrnl6I
-rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI=
------END DH PARAMETERS-----
-"""
-
-
-def create_ca(o, cn, exp):
- key = OpenSSL.crypto.PKey()
- key.generate_key(OpenSSL.crypto.TYPE_RSA, 2048)
- cert = OpenSSL.crypto.X509()
- cert.set_serial_number(int(time.time() * 10000))
- cert.set_version(2)
- cert.get_subject().CN = cn
- cert.get_subject().O = o
- cert.gmtime_adj_notBefore(-3600 * 48)
- cert.gmtime_adj_notAfter(exp)
- cert.set_issuer(cert.get_subject())
- cert.set_pubkey(key)
- cert.add_extensions([
- OpenSSL.crypto.X509Extension(
- b"basicConstraints",
- True,
- b"CA:TRUE"
- ),
- OpenSSL.crypto.X509Extension(
- b"nsCertType",
- False,
- b"sslCA"
- ),
- OpenSSL.crypto.X509Extension(
- b"extendedKeyUsage",
- False,
- b"serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC"
- ),
- OpenSSL.crypto.X509Extension(
- b"keyUsage",
- True,
- b"keyCertSign, cRLSign"
- ),
- OpenSSL.crypto.X509Extension(
- b"subjectKeyIdentifier",
- False,
- b"hash",
- subject=cert
- ),
- ])
- cert.sign(key, "sha256")
- return key, cert
-
-
-def dummy_cert(privkey, cacert, commonname, sans):
- """
- Generates a dummy certificate.
-
- privkey: CA private key
- cacert: CA certificate
- commonname: Common name for the generated certificate.
- sans: A list of Subject Alternate Names.
-
- Returns cert if operation succeeded, None if not.
- """
- ss = []
- for i in sans:
- try:
- ipaddress.ip_address(i.decode("ascii"))
- except ValueError:
- ss.append(b"DNS: %s" % i)
- else:
- ss.append(b"IP: %s" % i)
- ss = b", ".join(ss)
-
- cert = OpenSSL.crypto.X509()
- cert.gmtime_adj_notBefore(-3600 * 48)
- cert.gmtime_adj_notAfter(DEFAULT_EXP)
- cert.set_issuer(cacert.get_subject())
- if commonname is not None:
- cert.get_subject().CN = commonname
- cert.set_serial_number(int(time.time() * 10000))
- if ss:
- cert.set_version(2)
- cert.add_extensions(
- [OpenSSL.crypto.X509Extension(b"subjectAltName", False, ss)])
- cert.set_pubkey(cacert.get_pubkey())
- cert.sign(privkey, "sha256")
- return SSLCert(cert)
-
-
-# DNTree did not pass TestCertStore.test_sans_change and is temporarily replaced by a simple dict.
-#
-# class _Node(UserDict.UserDict):
-# def __init__(self):
-# UserDict.UserDict.__init__(self)
-# self.value = None
-#
-#
-# class DNTree:
-# """
-# Domain store that knows about wildcards. DNS wildcards are very
-# restricted - the only valid variety is an asterisk on the left-most
-# domain component, i.e.:
-#
-# *.foo.com
-# """
-# def __init__(self):
-# self.d = _Node()
-#
-# def add(self, dn, cert):
-# parts = dn.split(".")
-# parts.reverse()
-# current = self.d
-# for i in parts:
-# current = current.setdefault(i, _Node())
-# current.value = cert
-#
-# def get(self, dn):
-# parts = dn.split(".")
-# current = self.d
-# for i in reversed(parts):
-# if i in current:
-# current = current[i]
-# elif "*" in current:
-# return current["*"].value
-# else:
-# return None
-# return current.value
-
-
-class CertStoreEntry:
-
- def __init__(self, cert, privatekey, chain_file):
- self.cert = cert
- self.privatekey = privatekey
- self.chain_file = chain_file
-
-
-class CertStore:
-
- """
- Implements an in-memory certificate store.
- """
- STORE_CAP = 100
-
- def __init__(
- self,
- default_privatekey,
- default_ca,
- default_chain_file,
- dhparams):
- self.default_privatekey = default_privatekey
- self.default_ca = default_ca
- self.default_chain_file = default_chain_file
- self.dhparams = dhparams
- self.certs = dict()
- self.expire_queue = []
-
- def expire(self, entry):
- self.expire_queue.append(entry)
- if len(self.expire_queue) > self.STORE_CAP:
- d = self.expire_queue.pop(0)
- for k, v in list(self.certs.items()):
- if v == d:
- del self.certs[k]
-
- @staticmethod
- def load_dhparam(path):
-
- # netlib<=0.10 doesn't generate a dhparam file.
- # Create it now if neccessary.
- if not os.path.exists(path):
- with open(path, "wb") as f:
- f.write(DEFAULT_DHPARAM)
-
- bio = OpenSSL.SSL._lib.BIO_new_file(path.encode(sys.getfilesystemencoding()), b"r")
- if bio != OpenSSL.SSL._ffi.NULL:
- bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free)
- dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams(
- bio,
- OpenSSL.SSL._ffi.NULL,
- OpenSSL.SSL._ffi.NULL,
- OpenSSL.SSL._ffi.NULL)
- dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free)
- return dh
-
- @classmethod
- def from_store(cls, path, basename):
- ca_path = os.path.join(path, basename + "-ca.pem")
- if not os.path.exists(ca_path):
- key, ca = cls.create_store(path, basename)
- else:
- with open(ca_path, "rb") as f:
- raw = f.read()
- ca = OpenSSL.crypto.load_certificate(
- OpenSSL.crypto.FILETYPE_PEM,
- raw)
- key = OpenSSL.crypto.load_privatekey(
- OpenSSL.crypto.FILETYPE_PEM,
- raw)
- dh_path = os.path.join(path, basename + "-dhparam.pem")
- dh = cls.load_dhparam(dh_path)
- return cls(key, ca, ca_path, dh)
-
- @staticmethod
- def create_store(path, basename, o=None, cn=None, expiry=DEFAULT_EXP):
- if not os.path.exists(path):
- os.makedirs(path)
-
- o = o or basename
- cn = cn or basename
-
- key, ca = create_ca(o=o, cn=cn, exp=expiry)
- # Dump the CA plus private key
- with open(os.path.join(path, basename + "-ca.pem"), "wb") as f:
- f.write(
- OpenSSL.crypto.dump_privatekey(
- OpenSSL.crypto.FILETYPE_PEM,
- key))
- f.write(
- OpenSSL.crypto.dump_certificate(
- OpenSSL.crypto.FILETYPE_PEM,
- ca))
-
- # Dump the certificate in PEM format
- with open(os.path.join(path, basename + "-ca-cert.pem"), "wb") as f:
- f.write(
- OpenSSL.crypto.dump_certificate(
- OpenSSL.crypto.FILETYPE_PEM,
- ca))
-
- # Create a .cer file with the same contents for Android
- with open(os.path.join(path, basename + "-ca-cert.cer"), "wb") as f:
- f.write(
- OpenSSL.crypto.dump_certificate(
- OpenSSL.crypto.FILETYPE_PEM,
- ca))
-
- # Dump the certificate in PKCS12 format for Windows devices
- with open(os.path.join(path, basename + "-ca-cert.p12"), "wb") as f:
- p12 = OpenSSL.crypto.PKCS12()
- p12.set_certificate(ca)
- p12.set_privatekey(key)
- f.write(p12.export())
-
- with open(os.path.join(path, basename + "-dhparam.pem"), "wb") as f:
- f.write(DEFAULT_DHPARAM)
-
- return key, ca
-
- def add_cert_file(self, spec, path):
- with open(path, "rb") as f:
- raw = f.read()
- cert = SSLCert(
- OpenSSL.crypto.load_certificate(
- OpenSSL.crypto.FILETYPE_PEM,
- raw))
- try:
- privatekey = OpenSSL.crypto.load_privatekey(
- OpenSSL.crypto.FILETYPE_PEM,
- raw)
- except Exception:
- privatekey = self.default_privatekey
- self.add_cert(
- CertStoreEntry(cert, privatekey, path),
- spec
- )
-
- def add_cert(self, entry, *names):
- """
- Adds a cert to the certstore. We register the CN in the cert plus
- any SANs, and also the list of names provided as an argument.
- """
- if entry.cert.cn:
- self.certs[entry.cert.cn] = entry
- for i in entry.cert.altnames:
- self.certs[i] = entry
- for i in names:
- self.certs[i] = entry
-
- @staticmethod
- def asterisk_forms(dn):
- if dn is None:
- return []
- parts = dn.split(b".")
- parts.reverse()
- curr_dn = b""
- dn_forms = [b"*"]
- for part in parts[:-1]:
- curr_dn = b"." + part + curr_dn # .example.com
- dn_forms.append(b"*" + curr_dn) # *.example.com
- if parts[-1] != b"*":
- dn_forms.append(parts[-1] + curr_dn)
- return dn_forms
-
- def get_cert(self, commonname, sans):
- """
- Returns an (cert, privkey, cert_chain) tuple.
-
- commonname: Common name for the generated certificate. Must be a
- valid, plain-ASCII, IDNA-encoded domain name.
-
- sans: A list of Subject Alternate Names.
- """
-
- potential_keys = self.asterisk_forms(commonname)
- for s in sans:
- potential_keys.extend(self.asterisk_forms(s))
- potential_keys.append((commonname, tuple(sans)))
-
- name = next(
- filter(lambda key: key in self.certs, potential_keys),
- None
- )
- if name:
- entry = self.certs[name]
- else:
- entry = CertStoreEntry(
- cert=dummy_cert(
- self.default_privatekey,
- self.default_ca,
- commonname,
- sans),
- privatekey=self.default_privatekey,
- chain_file=self.default_chain_file)
- self.certs[(commonname, tuple(sans))] = entry
- self.expire(entry)
-
- return entry.cert, entry.privatekey, entry.chain_file
-
-
-class _GeneralName(univ.Choice):
- # We are only interested in dNSNames. We use a default handler to ignore
- # other types.
- # TODO: We should also handle iPAddresses.
- componentType = namedtype.NamedTypes(
- namedtype.NamedType('dNSName', char.IA5String().subtype(
- implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2)
- )
- ),
- )
-
-
-class _GeneralNames(univ.SequenceOf):
- componentType = _GeneralName()
- sizeSpec = univ.SequenceOf.sizeSpec + \
- constraint.ValueSizeConstraint(1, 1024)
-
-
-class SSLCert(serializable.Serializable):
-
- def __init__(self, cert):
- """
- Returns a (common name, [subject alternative names]) tuple.
- """
- self.x509 = cert
-
- def __eq__(self, other):
- return self.digest("sha256") == other.digest("sha256")
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def get_state(self):
- return self.to_pem()
-
- def set_state(self, state):
- self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state)
-
- @classmethod
- def from_state(cls, state):
- return cls.from_pem(state)
-
- @classmethod
- def from_pem(cls, txt):
- x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt)
- return cls(x509)
-
- @classmethod
- def from_der(cls, der):
- pem = ssl.DER_cert_to_PEM_cert(der)
- return cls.from_pem(pem)
-
- def to_pem(self):
- return OpenSSL.crypto.dump_certificate(
- OpenSSL.crypto.FILETYPE_PEM,
- self.x509)
-
- def digest(self, name):
- return self.x509.digest(name)
-
- @property
- def issuer(self):
- return self.x509.get_issuer().get_components()
-
- @property
- def notbefore(self):
- t = self.x509.get_notBefore()
- return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ")
-
- @property
- def notafter(self):
- t = self.x509.get_notAfter()
- return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ")
-
- @property
- def has_expired(self):
- return self.x509.has_expired()
-
- @property
- def subject(self):
- return self.x509.get_subject().get_components()
-
- @property
- def serial(self):
- return self.x509.get_serial_number()
-
- @property
- def keyinfo(self):
- pk = self.x509.get_pubkey()
- types = {
- OpenSSL.crypto.TYPE_RSA: "RSA",
- OpenSSL.crypto.TYPE_DSA: "DSA",
- }
- return (
- types.get(pk.type(), "UNKNOWN"),
- pk.bits()
- )
-
- @property
- def cn(self):
- c = None
- for i in self.subject:
- if i[0] == b"CN":
- c = i[1]
- return c
-
- @property
- def altnames(self):
- """
- Returns:
- All DNS altnames.
- """
- # tcp.TCPClient.convert_to_ssl assumes that this property only contains DNS altnames for hostname verification.
- altnames = []
- for i in range(self.x509.get_extension_count()):
- ext = self.x509.get_extension(i)
- if ext.get_short_name() == b"subjectAltName":
- try:
- dec = decode(ext.get_data(), asn1Spec=_GeneralNames())
- except PyAsn1Error:
- continue
- for i in dec[0]:
- altnames.append(i[0].asOctets())
- return altnames
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 4379c9b5..6e323957 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -17,7 +17,7 @@ from backports import ssl_match_hostname
import OpenSSL
from OpenSSL import SSL
-from netlib import certutils
+from mitmproxy import certs
from mitmproxy.utils import version_check
from mitmproxy.types import serializable
from netlib import exceptions
@@ -685,11 +685,11 @@ class TCPClient(_Connection):
if verification_mode == SSL.VERIFY_PEER and self.ssl_verification_error:
raise self.ssl_verification_error
- self.cert = certutils.SSLCert(self.connection.get_peer_certificate())
+ self.cert = certs.SSLCert(self.connection.get_peer_certificate())
# Keep all server certificates in a list
for i in self.connection.get_peer_cert_chain():
- self.server_certs.append(certutils.SSLCert(i))
+ self.server_certs.append(certs.SSLCert(i))
# Validate TLS Hostname
try:
@@ -782,7 +782,7 @@ class BaseHandler(_Connection):
extra_chain_certs=None,
**sslctx_kwargs):
"""
- cert: A certutils.SSLCert object or the path to a certificate
+ cert: A certs.SSLCert object or the path to a certificate
chain file.
handle_sni: SNI handler, should take a connection object. Server
@@ -810,7 +810,7 @@ class BaseHandler(_Connection):
context = self._create_ssl_context(ca_pemfile=chain_file, **sslctx_kwargs)
context.use_privatekey(key)
- if isinstance(cert, certutils.SSLCert):
+ if isinstance(cert, certs.SSLCert):
context.use_certificate(cert.x509)
else:
context.use_certificate_chain_file(cert)
@@ -825,7 +825,7 @@ class BaseHandler(_Connection):
if request_client_cert:
def save_cert(conn_, cert, errno_, depth_, preverify_ok_):
- self.clientcert = certutils.SSLCert(cert)
+ self.clientcert = certs.SSLCert(cert)
# Return true to prevent cert verification error
return True
context.set_verify(SSL.VERIFY_PEER, save_cert)