aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAldo Cortesi <aldo@nullcube.com>2014-03-10 17:29:27 +1300
committerAldo Cortesi <aldo@nullcube.com>2014-03-10 17:29:27 +1300
commitf5cc63d653b27210d9c3d7646c01c3a9d540d9c7 (patch)
treec52924dd1e31bd465751491166a4774d1e9ea49d
parent2a12aa3c47d57cc2d3a36f6726a5f081ca493457 (diff)
downloadmitmproxy-f5cc63d653b27210d9c3d7646c01c3a9d540d9c7.tar.gz
mitmproxy-f5cc63d653b27210d9c3d7646c01c3a9d540d9c7.tar.bz2
mitmproxy-f5cc63d653b27210d9c3d7646c01c3a9d540d9c7.zip
Certificate flags
-rw-r--r--.gitignore3
-rw-r--r--netlib/certffi.py36
-rw-r--r--netlib/certutils.py7
-rw-r--r--test/test_certutils.py14
-rw-r--r--test/test_tcp.py127
5 files changed, 130 insertions, 57 deletions
diff --git a/.gitignore b/.gitignore
index e66d51fe..26c449d1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,4 +7,5 @@ MANIFEST
*.swp
*.swo
.coverage
-.idea \ No newline at end of file
+.idea
+__pycache__
diff --git a/netlib/certffi.py b/netlib/certffi.py
new file mode 100644
index 00000000..c5d7c95e
--- /dev/null
+++ b/netlib/certffi.py
@@ -0,0 +1,36 @@
+import cffi
+import OpenSSL
+xffi = cffi.FFI()
+xffi.cdef ("""
+ struct rsa_meth_st {
+ int flags;
+ ...;
+ };
+ struct rsa_st {
+ int pad;
+ long version;
+ struct rsa_meth_st *meth;
+ ...;
+ };
+""")
+xffi.verify(
+ """#include <openssl/rsa.h>""",
+ extra_compile_args=['-w']
+)
+
+def handle(privkey):
+ new = xffi.new("struct rsa_st*")
+ newbuf = xffi.buffer(new)
+ rsa = OpenSSL.SSL._lib.EVP_PKEY_get1_RSA(privkey._pkey)
+ oldbuf = OpenSSL.SSL._ffi.buffer(rsa)
+ newbuf[:] = oldbuf[:]
+ return new
+
+def set_flags(privkey, val):
+ hdl = handle(privkey)
+ hdl.meth.flags = val
+ return privkey
+
+def get_flags(privkey):
+ hdl = handle(privkey)
+ return hdl.meth.flags
diff --git a/netlib/certutils.py b/netlib/certutils.py
index 19148382..92b219ee 100644
--- a/netlib/certutils.py
+++ b/netlib/certutils.py
@@ -111,6 +111,7 @@ class DNTree:
return current.value
+
class CertStore:
"""
Implements an in-memory certificate store.
@@ -222,6 +223,11 @@ class CertStore:
c = (c, None)
return (c[0], c[1] or self.privkey)
+ def gen_pkey(self, cert):
+ import certffi
+ certffi.set_flags(self.privkey, 1)
+ return self.privkey
+
class _GeneralName(univ.Choice):
# We are only interested in dNSNames. We use a default handler to ignore
@@ -326,6 +332,7 @@ class SSLCert:
return altnames
+
def get_remote_cert(host, port, sni):
c = tcp.TCPClient((host, port))
c.connect()
diff --git a/test/test_certutils.py b/test/test_certutils.py
index 7f320e7e..176575ea 100644
--- a/test/test_certutils.py
+++ b/test/test_certutils.py
@@ -1,5 +1,5 @@
import os
-from netlib import certutils
+from netlib import certutils, certffi
import OpenSSL
import tutils
@@ -83,6 +83,16 @@ class TestCertStore:
ret = ca1.get_cert("foo.com", [])
assert ret[0].serial == dc[0].serial
+ def test_gen_pkey(self):
+ try:
+ with tutils.tmpdir() as d:
+ ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test")
+ ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test")
+ cert = ca1.get_cert("foo.com", [])
+ assert certffi.get_flags(ca2.gen_pkey(cert[0])) == 1
+ finally:
+ certffi.set_flags(ca2.privkey, 0)
+
class TestDummyCert:
def test_with_ca(self):
@@ -125,3 +135,5 @@ class TestSSLCert:
d = file(tutils.test_data.path("data/dercert"),"rb").read()
s = certutils.SSLCert.from_der(d)
assert s.cn
+
+
diff --git a/test/test_tcp.py b/test/test_tcp.py
index 814754cd..ec995702 100644
--- a/test/test_tcp.py
+++ b/test/test_tcp.py
@@ -4,16 +4,6 @@ import mock
import tutils
from OpenSSL import SSL
-class SNIHandler(tcp.BaseHandler):
- sni = None
- def handle_sni(self, connection):
- self.sni = connection.get_servername()
-
- def handle(self):
- self.wfile.write(self.sni)
- self.wfile.flush()
-
-
class EchoHandler(tcp.BaseHandler):
sni = None
def handle_sni(self, connection):
@@ -25,58 +15,19 @@ class EchoHandler(tcp.BaseHandler):
self.wfile.flush()
-class ClientPeernameHandler(tcp.BaseHandler):
- def handle(self):
- self.wfile.write(str(self.connection.getpeername()))
- self.wfile.flush()
-
-
-class CertHandler(tcp.BaseHandler):
- sni = None
- def handle_sni(self, connection):
- self.sni = connection.get_servername()
-
- def handle(self):
- self.wfile.write("%s\n"%self.clientcert.serial)
- self.wfile.flush()
-
-
class ClientCipherListHandler(tcp.BaseHandler):
sni = None
-
def handle(self):
self.wfile.write("%s"%self.connection.get_cipher_list())
self.wfile.flush()
-class CurrentCipherHandler(tcp.BaseHandler):
- sni = None
- def handle(self):
- self.wfile.write("%s"%str(self.get_current_cipher()))
- self.wfile.flush()
-
-
-class DisconnectHandler(tcp.BaseHandler):
- def handle(self):
- self.close()
-
-
class HangHandler(tcp.BaseHandler):
def handle(self):
while 1:
time.sleep(1)
-class TimeoutHandler(tcp.BaseHandler):
- def handle(self):
- self.timeout = False
- self.settimeout(0.01)
- try:
- self.rfile.read(10)
- except tcp.NetLibTimeout:
- self.timeout = True
-
-
class TestServer(test.ServerTestBase):
handler = EchoHandler
def test_echo(self):
@@ -89,7 +40,10 @@ class TestServer(test.ServerTestBase):
class TestServerBind(test.ServerTestBase):
- handler = ClientPeernameHandler
+ class handler(tcp.BaseHandler):
+ def handle(self):
+ self.wfile.write(str(self.connection.getpeername()))
+ self.wfile.flush()
def test_bind(self):
""" Test to bind to a given random port. Try again if the random port turned out to be blocked. """
@@ -198,7 +152,14 @@ class TestSSLv3Only(test.ServerTestBase):
class TestSSLClientCert(test.ServerTestBase):
- handler = CertHandler
+ class handler(tcp.BaseHandler):
+ sni = None
+ def handle_sni(self, connection):
+ self.sni = connection.get_servername()
+
+ def handle(self):
+ self.wfile.write("%s\n"%self.clientcert.serial)
+ self.wfile.flush()
ssl = dict(
cert = tutils.test_data.path("data/server.crt"),
key = tutils.test_data.path("data/server.key"),
@@ -222,7 +183,15 @@ class TestSSLClientCert(test.ServerTestBase):
class TestSNI(test.ServerTestBase):
- handler = SNIHandler
+ class handler(tcp.BaseHandler):
+ sni = None
+ def handle_sni(self, connection):
+ self.sni = connection.get_servername()
+
+ def handle(self):
+ self.wfile.write(self.sni)
+ self.wfile.flush()
+
ssl = dict(
cert = tutils.test_data.path("data/server.crt"),
key = tutils.test_data.path("data/server.key"),
@@ -254,7 +223,11 @@ class TestServerCipherList(test.ServerTestBase):
class TestServerCurrentCipher(test.ServerTestBase):
- handler = CurrentCipherHandler
+ class handler(tcp.BaseHandler):
+ sni = None
+ def handle(self):
+ self.wfile.write("%s"%str(self.get_current_cipher()))
+ self.wfile.flush()
ssl = dict(
cert = tutils.test_data.path("data/server.crt"),
key = tutils.test_data.path("data/server.key"),
@@ -300,7 +273,9 @@ class TestClientCipherListError(test.ServerTestBase):
class TestSSLDisconnect(test.ServerTestBase):
- handler = DisconnectHandler
+ class handler(tcp.BaseHandler):
+ def handle(self):
+ self.close()
ssl = dict(
cert = tutils.test_data.path("data/server.crt"),
key = tutils.test_data.path("data/server.key"),
@@ -329,7 +304,15 @@ class TestDisconnect(test.ServerTestBase):
class TestServerTimeOut(test.ServerTestBase):
- handler = TimeoutHandler
+ class handler(tcp.BaseHandler):
+ def handle(self):
+ self.timeout = False
+ self.settimeout(0.01)
+ try:
+ self.rfile.read(10)
+ except tcp.NetLibTimeout:
+ self.timeout = True
+
def test_timeout(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
@@ -383,6 +366,40 @@ class TestDHParams(test.ServerTestBase):
assert ret[0] == "DHE-RSA-AES256-SHA"
+
+class TestPrivkeyGen(test.ServerTestBase):
+ class handler(tcp.BaseHandler):
+ def handle(self):
+ with tutils.tmpdir() as d:
+ ca1 = certutils.CertStore.from_store(d, "test2")
+ ca2 = certutils.CertStore.from_store(d, "test3")
+ cert, _ = ca1.get_cert("foo.com", [])
+ key = ca2.gen_pkey(cert)
+ self.convert_to_ssl(cert, key)
+
+ def test_privkey(self):
+ c = tcp.TCPClient(("127.0.0.1", self.port))
+ c.connect()
+ tutils.raises("bad record mac", c.convert_to_ssl)
+
+
+class TestPrivkeyGenNoFlags(test.ServerTestBase):
+ class handler(tcp.BaseHandler):
+ def handle(self):
+ with tutils.tmpdir() as d:
+ ca1 = certutils.CertStore.from_store(d, "test2")
+ ca2 = certutils.CertStore.from_store(d, "test3")
+ cert, _ = ca1.get_cert("foo.com", [])
+ certffi.set_flags(ca2.privkey, 0)
+ self.convert_to_ssl(cert, ca2.privkey)
+
+ def test_privkey(self):
+ c = tcp.TCPClient(("127.0.0.1", self.port))
+ c.connect()
+ tutils.raises("unexpected eof", c.convert_to_ssl)
+
+
+
class TestTCPClient:
def test_conerr(self):
c = tcp.TCPClient(("127.0.0.1", 0))