aboutsummaryrefslogtreecommitdiffstats
path: root/test/netlib
diff options
context:
space:
mode:
Diffstat (limited to 'test/netlib')
-rw-r--r--test/netlib/test_certutils.py180
-rw-r--r--test/netlib/test_tcp.py53
2 files changed, 27 insertions, 206 deletions
diff --git a/test/netlib/test_certutils.py b/test/netlib/test_certutils.py
deleted file mode 100644
index cf9a671b..00000000
--- a/test/netlib/test_certutils.py
+++ /dev/null
@@ -1,180 +0,0 @@
-import os
-from netlib import certutils, tutils
-
-# class TestDNTree:
-# def test_simple(self):
-# d = certutils.DNTree()
-# d.add("foo.com", "foo")
-# d.add("bar.com", "bar")
-# assert d.get("foo.com") == "foo"
-# assert d.get("bar.com") == "bar"
-# assert not d.get("oink.com")
-# assert not d.get("oink")
-# assert not d.get("")
-# assert not d.get("oink.oink")
-#
-# d.add("*.match.org", "match")
-# assert not d.get("match.org")
-# assert d.get("foo.match.org") == "match"
-# assert d.get("foo.foo.match.org") == "match"
-#
-# def test_wildcard(self):
-# d = certutils.DNTree()
-# d.add("foo.com", "foo")
-# assert not d.get("*.foo.com")
-# d.add("*.foo.com", "wild")
-#
-# d = certutils.DNTree()
-# d.add("*", "foo")
-# assert d.get("foo.com") == "foo"
-# assert d.get("*.foo.com") == "foo"
-# assert d.get("com") == "foo"
-
-
-class TestCertStore:
-
- def test_create_explicit(self):
- with tutils.tmpdir() as d:
- ca = certutils.CertStore.from_store(d, "test")
- assert ca.get_cert(b"foo", [])
-
- ca2 = certutils.CertStore.from_store(d, "test")
- assert ca2.get_cert(b"foo", [])
-
- assert ca.default_ca.get_serial_number() == ca2.default_ca.get_serial_number()
-
- def test_create_no_common_name(self):
- with tutils.tmpdir() as d:
- ca = certutils.CertStore.from_store(d, "test")
- assert ca.get_cert(None, [])[0].cn is None
-
- def test_create_tmp(self):
- with tutils.tmpdir() as d:
- ca = certutils.CertStore.from_store(d, "test")
- assert ca.get_cert(b"foo.com", [])
- assert ca.get_cert(b"foo.com", [])
- assert ca.get_cert(b"*.foo.com", [])
-
- r = ca.get_cert(b"*.foo.com", [])
- assert r[1] == ca.default_privatekey
-
- def test_sans(self):
- with tutils.tmpdir() as d:
- ca = certutils.CertStore.from_store(d, "test")
- c1 = ca.get_cert(b"foo.com", [b"*.bar.com"])
- ca.get_cert(b"foo.bar.com", [])
- # assert c1 == c2
- c3 = ca.get_cert(b"bar.com", [])
- assert not c1 == c3
-
- def test_sans_change(self):
- with tutils.tmpdir() as d:
- ca = certutils.CertStore.from_store(d, "test")
- ca.get_cert(b"foo.com", [b"*.bar.com"])
- cert, key, chain_file = ca.get_cert(b"foo.bar.com", [b"*.baz.com"])
- assert b"*.baz.com" in cert.altnames
-
- def test_expire(self):
- with tutils.tmpdir() as d:
- ca = certutils.CertStore.from_store(d, "test")
- ca.STORE_CAP = 3
- ca.get_cert(b"one.com", [])
- ca.get_cert(b"two.com", [])
- ca.get_cert(b"three.com", [])
-
- assert (b"one.com", ()) in ca.certs
- assert (b"two.com", ()) in ca.certs
- assert (b"three.com", ()) in ca.certs
-
- ca.get_cert(b"one.com", [])
-
- assert (b"one.com", ()) in ca.certs
- assert (b"two.com", ()) in ca.certs
- assert (b"three.com", ()) in ca.certs
-
- ca.get_cert(b"four.com", [])
-
- assert (b"one.com", ()) not in ca.certs
- assert (b"two.com", ()) in ca.certs
- assert (b"three.com", ()) in ca.certs
- assert (b"four.com", ()) in ca.certs
-
- def test_overrides(self):
- with tutils.tmpdir() as d:
- ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test")
- ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test")
- assert not ca1.default_ca.get_serial_number(
- ) == ca2.default_ca.get_serial_number()
-
- dc = ca2.get_cert(b"foo.com", [b"sans.example.com"])
- dcp = os.path.join(d, "dc")
- f = open(dcp, "wb")
- f.write(dc[0].to_pem())
- f.close()
- ca1.add_cert_file(b"foo.com", dcp)
-
- ret = ca1.get_cert(b"foo.com", [])
- assert ret[0].serial == dc[0].serial
-
-
-class TestDummyCert:
-
- def test_with_ca(self):
- with tutils.tmpdir() as d:
- ca = certutils.CertStore.from_store(d, "test")
- r = certutils.dummy_cert(
- ca.default_privatekey,
- ca.default_ca,
- b"foo.com",
- [b"one.com", b"two.com", b"*.three.com"]
- )
- assert r.cn == b"foo.com"
-
- r = certutils.dummy_cert(
- ca.default_privatekey,
- ca.default_ca,
- None,
- []
- )
- assert r.cn is None
-
-
-class TestSSLCert:
-
- def test_simple(self):
- with open(tutils.test_data.path("data/text_cert"), "rb") as f:
- d = f.read()
- c1 = certutils.SSLCert.from_pem(d)
- assert c1.cn == b"google.com"
- assert len(c1.altnames) == 436
-
- with open(tutils.test_data.path("data/text_cert_2"), "rb") as f:
- d = f.read()
- c2 = certutils.SSLCert.from_pem(d)
- assert c2.cn == b"www.inode.co.nz"
- assert len(c2.altnames) == 2
- assert c2.digest("sha1")
- assert c2.notbefore
- assert c2.notafter
- assert c2.subject
- assert c2.keyinfo == ("RSA", 2048)
- assert c2.serial
- assert c2.issuer
- assert c2.to_pem()
- assert c2.has_expired is not None
-
- assert not c1 == c2
- assert c1 != c2
-
- def test_err_broken_sans(self):
- with open(tutils.test_data.path("data/text_cert_weird1"), "rb") as f:
- d = f.read()
- c = certutils.SSLCert.from_pem(d)
- # This breaks unless we ignore a decoding error.
- assert c.altnames is not None
-
- def test_der(self):
- with open(tutils.test_data.path("data/dercert"), "rb") as f:
- d = f.read()
- s = certutils.SSLCert.from_der(d)
- assert s.cn
diff --git a/test/netlib/test_tcp.py b/test/netlib/test_tcp.py
index 797a5a04..2c1b92dc 100644
--- a/test/netlib/test_tcp.py
+++ b/test/netlib/test_tcp.py
@@ -9,9 +9,10 @@ import mock
from OpenSSL import SSL
-from netlib import tcp, certutils, tutils
-from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \
- TcpTimeout, TcpDisconnect, TcpException, NetlibException
+from mitmproxy import certs
+from netlib import tcp
+from netlib import tutils
+from netlib import exceptions
from . import tservers
@@ -108,7 +109,7 @@ class TestServerBind(tservers.ServerTestBase):
with c.connect():
assert c.rfile.readline() == str(("127.0.0.1", random_port)).encode()
return
- except TcpException: # port probably already in use
+ except exceptions.TcpException: # port probably already in use
pass
@@ -155,7 +156,7 @@ class TestFinishFail(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.wfile.write(b"foo\n")
- c.wfile.flush = mock.Mock(side_effect=TcpDisconnect)
+ c.wfile.flush = mock.Mock(side_effect=exceptions.TcpDisconnect)
c.finish()
@@ -195,7 +196,7 @@ class TestSSLv3Only(tservers.ServerTestBase):
def test_failure(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- tutils.raises(TlsException, c.convert_to_ssl, sni="foo.com")
+ tutils.raises(exceptions.TlsException, c.convert_to_ssl, sni="foo.com")
class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
@@ -236,7 +237,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
def test_mode_strict_should_fail(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- with tutils.raises(InvalidCertificateException):
+ with tutils.raises(exceptions.InvalidCertificateException):
c.convert_to_ssl(
sni="example.mitmproxy.org",
verify_options=SSL.VERIFY_PEER,
@@ -261,7 +262,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase):
def test_should_fail_without_sni(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- with tutils.raises(TlsException):
+ with tutils.raises(exceptions.TlsException):
c.convert_to_ssl(
verify_options=SSL.VERIFY_PEER,
ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
@@ -270,7 +271,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase):
def test_should_fail(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- with tutils.raises(InvalidCertificateException):
+ with tutils.raises(exceptions.InvalidCertificateException):
c.convert_to_ssl(
sni="mitmproxy.org",
verify_options=SSL.VERIFY_PEER,
@@ -348,7 +349,7 @@ class TestSSLClientCert(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
tutils.raises(
- TlsException,
+ exceptions.TlsException,
c.convert_to_ssl,
cert=tutils.test_data.path("data/clientcert/make")
)
@@ -454,7 +455,7 @@ class TestSSLDisconnect(tservers.ServerTestBase):
# Excercise SSL.ZeroReturnError
c.rfile.read(10)
c.close()
- tutils.raises(TcpDisconnect, c.wfile.write, b"foo")
+ tutils.raises(exceptions.TcpDisconnect, c.wfile.write, b"foo")
tutils.raises(queue.Empty, self.q.get_nowait)
@@ -469,7 +470,7 @@ class TestSSLHardDisconnect(tservers.ServerTestBase):
# Exercise SSL.SysCallError
c.rfile.read(10)
c.close()
- tutils.raises(TcpDisconnect, c.wfile.write, b"foo")
+ tutils.raises(exceptions.TcpDisconnect, c.wfile.write, b"foo")
class TestDisconnect(tservers.ServerTestBase):
@@ -492,7 +493,7 @@ class TestServerTimeOut(tservers.ServerTestBase):
self.settimeout(0.01)
try:
self.rfile.read(10)
- except TcpTimeout:
+ except exceptions.TcpTimeout:
self.timeout = True
def test_timeout(self):
@@ -510,7 +511,7 @@ class TestTimeOut(tservers.ServerTestBase):
with c.connect():
c.settimeout(0.1)
assert c.gettimeout() == 0.1
- tutils.raises(TcpTimeout, c.rfile.read, 10)
+ tutils.raises(exceptions.TcpTimeout, c.rfile.read, 10)
class TestALPNClient(tservers.ServerTestBase):
@@ -562,13 +563,13 @@ class TestSSLTimeOut(tservers.ServerTestBase):
with c.connect():
c.convert_to_ssl()
c.settimeout(0.1)
- tutils.raises(TcpTimeout, c.rfile.read, 10)
+ tutils.raises(exceptions.TcpTimeout, c.rfile.read, 10)
class TestDHParams(tservers.ServerTestBase):
handler = HangHandler
ssl = dict(
- dhparams=certutils.CertStore.load_dhparam(
+ dhparams=certs.CertStore.load_dhparam(
tutils.test_data.path("data/dhparam.pem"),
),
cipher_list="DHE-RSA-AES256-SHA"
@@ -584,7 +585,7 @@ class TestDHParams(tservers.ServerTestBase):
def test_create_dhparams(self):
with tutils.tmpdir() as d:
filename = os.path.join(d, "dhparam.pem")
- certutils.CertStore.load_dhparam(filename)
+ certs.CertStore.load_dhparam(filename)
assert os.path.exists(filename)
@@ -592,7 +593,7 @@ class TestTCPClient:
def test_conerr(self):
c = tcp.TCPClient(("127.0.0.1", 0))
- tutils.raises(TcpException, c.connect)
+ tutils.raises(exceptions.TcpException, c.connect)
class TestFileLike:
@@ -661,7 +662,7 @@ class TestFileLike:
o = mock.MagicMock()
o.flush = mock.MagicMock(side_effect=socket.error)
s.o = o
- tutils.raises(TcpDisconnect, s.flush)
+ tutils.raises(exceptions.TcpDisconnect, s.flush)
def test_reader_read_error(self):
s = BytesIO(b"foobar\nfoobar")
@@ -669,7 +670,7 @@ class TestFileLike:
o = mock.MagicMock()
o.read = mock.MagicMock(side_effect=socket.error)
s.o = o
- tutils.raises(TcpDisconnect, s.read, 10)
+ tutils.raises(exceptions.TcpDisconnect, s.read, 10)
def test_reset_timestamps(self):
s = BytesIO(b"foobar\nfoobar")
@@ -700,24 +701,24 @@ class TestFileLike:
s = mock.MagicMock()
s.read = mock.MagicMock(side_effect=SSL.Error())
s = tcp.Reader(s)
- tutils.raises(TlsException, s.read, 1)
+ tutils.raises(exceptions.TlsException, s.read, 1)
def test_read_syscall_ssl_error(self):
s = mock.MagicMock()
s.read = mock.MagicMock(side_effect=SSL.SysCallError())
s = tcp.Reader(s)
- tutils.raises(TlsException, s.read, 1)
+ tutils.raises(exceptions.TlsException, s.read, 1)
def test_reader_readline_disconnect(self):
o = mock.MagicMock()
o.read = mock.MagicMock(side_effect=socket.error)
s = tcp.Reader(o)
- tutils.raises(TcpDisconnect, s.readline, 10)
+ tutils.raises(exceptions.TcpDisconnect, s.readline, 10)
def test_reader_incomplete_error(self):
s = BytesIO(b"foobar")
s = tcp.Reader(s)
- tutils.raises(TcpReadIncomplete, s.safe_read, 10)
+ tutils.raises(exceptions.TcpReadIncomplete, s.safe_read, 10)
class TestPeek(tservers.ServerTestBase):
@@ -738,11 +739,11 @@ class TestPeek(tservers.ServerTestBase):
assert c.rfile.readline() == testval
c.close()
- with tutils.raises(NetlibException):
+ with tutils.raises(exceptions.NetlibException):
if c.rfile.peek(1) == b"":
# Workaround for Python 2 on Unix:
# Peeking a closed connection does not raise an exception here.
- raise NetlibException()
+ raise exceptions.NetlibException()
class TestPeekSSL(TestPeek):