aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMaximilian Hils <git@maximilianhils.com>2016-07-06 21:03:17 -0700
committerMaximilian Hils <git@maximilianhils.com>2016-07-06 21:03:17 -0700
commit64a867973d5bac136c2e1c3c11c457d6b04d6649 (patch)
tree89b124368f8ed0c7973b4fa067239d505815b130
parent8287ce7e6dcf31e65519629bb064044a44de46d1 (diff)
downloadmitmproxy-64a867973d5bac136c2e1c3c11c457d6b04d6649.tar.gz
mitmproxy-64a867973d5bac136c2e1c3c11c457d6b04d6649.tar.bz2
mitmproxy-64a867973d5bac136c2e1c3c11c457d6b04d6649.zip
sni is now str, not bytes
-rw-r--r--mitmproxy/models/connections.py7
-rw-r--r--mitmproxy/models/flow.py16
-rw-r--r--mitmproxy/protocol/tls.py13
-rw-r--r--netlib/tcp.py4
-rw-r--r--netlib/utils.py4
-rw-r--r--pathod/pathod.py5
-rw-r--r--test/mitmproxy/test_server.py6
-rw-r--r--test/mitmproxy/tutils.py2
-rw-r--r--test/netlib/test_tcp.py26
-rw-r--r--test/pathod/test_pathoc.py4
10 files changed, 44 insertions, 43 deletions
diff --git a/mitmproxy/models/connections.py b/mitmproxy/models/connections.py
index 3e1a0928..570e89a9 100644
--- a/mitmproxy/models/connections.py
+++ b/mitmproxy/models/connections.py
@@ -8,7 +8,6 @@ import six
from mitmproxy import stateobject
from netlib import certutils
-from netlib import strutils
from netlib import tcp
@@ -162,7 +161,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
source_address=tcp.Address,
ssl_established=bool,
cert=certutils.SSLCert,
- sni=bytes,
+ sni=str,
timestamp_start=float,
timestamp_tcp_setup=float,
timestamp_ssl_setup=float,
@@ -206,6 +205,8 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
self.wfile.flush()
def establish_ssl(self, clientcerts, sni, **kwargs):
+ if sni and not isinstance(sni, six.string_types):
+ raise ValueError("sni must be str, not " + type(sni).__name__)
clientcert = None
if clientcerts:
if os.path.isfile(clientcerts):
@@ -217,7 +218,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
if os.path.exists(path):
clientcert = path
- self.convert_to_ssl(cert=clientcert, sni=strutils.always_bytes(sni), **kwargs)
+ self.convert_to_ssl(cert=clientcert, sni=sni, **kwargs)
self.sni = sni
self.timestamp_ssl_setup = time.time()
diff --git a/mitmproxy/models/flow.py b/mitmproxy/models/flow.py
index 0e4f80cb..f4993b7a 100644
--- a/mitmproxy/models/flow.py
+++ b/mitmproxy/models/flow.py
@@ -9,6 +9,7 @@ from mitmproxy.models.connections import ClientConnection
from mitmproxy.models.connections import ServerConnection
from netlib import version
+from typing import Optional # noqa
class Error(stateobject.StateObject):
@@ -70,18 +71,13 @@ class Flow(stateobject.StateObject):
def __init__(self, type, client_conn, server_conn, live=None):
self.type = type
self.id = str(uuid.uuid4())
- self.client_conn = client_conn
- """@type: ClientConnection"""
- self.server_conn = server_conn
- """@type: ServerConnection"""
+ self.client_conn = client_conn # type: ClientConnection
+ self.server_conn = server_conn # type: ServerConnection
self.live = live
- """@type: LiveConnection"""
- self.error = None
- """@type: Error"""
- self.intercepted = False
- """@type: bool"""
- self._backup = None
+ self.error = None # type: Error
+ self.intercepted = False # type: bool
+ self._backup = None # type: Optional[Flow]
self.reply = None
_stateobject_attributes = dict(
diff --git a/mitmproxy/protocol/tls.py b/mitmproxy/protocol/tls.py
index 9f883b2b..8ef34493 100644
--- a/mitmproxy/protocol/tls.py
+++ b/mitmproxy/protocol/tls.py
@@ -10,6 +10,7 @@ import netlib.exceptions
from mitmproxy import exceptions
from mitmproxy.contrib.tls import _constructs
from mitmproxy.protocol import base
+from netlib import utils
# taken from https://testssl.sh/openssl-rfc.mappping.html
@@ -274,10 +275,11 @@ class TlsClientHello(object):
is_valid_sni_extension = (
extension.type == 0x00 and
len(extension.server_names) == 1 and
- extension.server_names[0].type == 0
+ extension.server_names[0].type == 0 and
+ utils.is_valid_host(extension.server_names[0].name)
)
if is_valid_sni_extension:
- return extension.server_names[0].name
+ return extension.server_names[0].name.decode("idna")
@property
def alpn_protocols(self):
@@ -403,13 +405,14 @@ class TlsLayer(base.Layer):
self._establish_tls_with_server()
def set_server_tls(self, server_tls, sni=None):
+ # type: (bool, Union[six.text_type, None, False]) -> None
"""
Set the TLS settings for the next server connection that will be established.
This function will not alter an existing connection.
Args:
server_tls: Shall we establish TLS with the server?
- sni: ``bytes`` for a custom SNI value,
+ sni: ``str`` for a custom SNI value,
``None`` for the client SNI value,
``False`` if no SNI value should be sent.
"""
@@ -602,9 +605,9 @@ class TlsLayer(base.Layer):
host = upstream_cert.cn.decode("utf8").encode("idna")
# Also add SNI values.
if self._client_hello.sni:
- sans.add(self._client_hello.sni)
+ sans.add(self._client_hello.sni.encode("idna"))
if self._custom_server_sni:
- sans.add(self._custom_server_sni)
+ sans.add(self._custom_server_sni.encode("idna"))
# RFC 2818: If a subjectAltName extension of type dNSName is present, that MUST be used as the identity.
# In other words, the Common Name is irrelevant then.
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 69dafc1f..cf099edd 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -676,7 +676,7 @@ class TCPClient(_Connection):
self.connection = SSL.Connection(context, self.connection)
if sni:
self.sni = sni
- self.connection.set_tlsext_host_name(sni)
+ self.connection.set_tlsext_host_name(sni.encode("idna"))
self.connection.set_connect_state()
try:
self.connection.do_handshake()
@@ -705,7 +705,7 @@ class TCPClient(_Connection):
if self.cert.cn:
crt["subject"] = [[["commonName", self.cert.cn.decode("ascii", "strict")]]]
if sni:
- hostname = sni.decode("ascii", "strict")
+ hostname = sni
else:
hostname = "no-hostname"
ssl_match_hostname.match_hostname(crt, hostname)
diff --git a/netlib/utils.py b/netlib/utils.py
index 79340cbd..23c16dc3 100644
--- a/netlib/utils.py
+++ b/netlib/utils.py
@@ -73,11 +73,9 @@ _label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE)
def is_valid_host(host):
+ # type: (bytes) -> bool
"""
Checks if a hostname is valid.
-
- Args:
- host (bytes): The hostname
"""
try:
host.decode("idna")
diff --git a/pathod/pathod.py b/pathod/pathod.py
index 3df86aae..7087cba6 100644
--- a/pathod/pathod.py
+++ b/pathod/pathod.py
@@ -89,7 +89,10 @@ class PathodHandler(tcp.BaseHandler):
self.http2_framedump = http2_framedump
def handle_sni(self, connection):
- self.sni = connection.get_servername()
+ sni = connection.get_servername()
+ if sni:
+ sni = sni.decode("idna")
+ self.sni = sni
def http_serve_crafted(self, crafted, logctx):
error, crafted = self.server.check_policy(
diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py
index 1bbef975..0ab7624e 100644
--- a/test/mitmproxy/test_server.py
+++ b/test/mitmproxy/test_server.py
@@ -100,10 +100,10 @@ class CommonMixin:
if not self.ssl:
return
- f = self.pathod("304", sni=b"testserver.com")
+ f = self.pathod("304", sni="testserver.com")
assert f.status_code == 304
log = self.server.last_log()
- assert log["request"]["sni"] == b"testserver.com"
+ assert log["request"]["sni"] == "testserver.com"
class TcpMixin:
@@ -498,7 +498,7 @@ class TestHttps2Http(tservers.ReverseProxyTest):
assert p.request("get:'/p/200'").status_code == 200
def test_sni(self):
- p = self.pathoc(ssl=True, sni=b"example.com")
+ p = self.pathoc(ssl=True, sni="example.com")
assert p.request("get:'/p/200'").status_code == 200
assert all("Error in handle_sni" not in msg for msg in self.proxy.tlog)
diff --git a/test/mitmproxy/tutils.py b/test/mitmproxy/tutils.py
index 5aade60c..d0a09035 100644
--- a/test/mitmproxy/tutils.py
+++ b/test/mitmproxy/tutils.py
@@ -130,7 +130,7 @@ def tserver_conn():
timestamp_ssl_setup=3,
timestamp_end=4,
ssl_established=False,
- sni=b"address",
+ sni="address",
via=None
))
c.reply = controller.DummyReply()
diff --git a/test/netlib/test_tcp.py b/test/netlib/test_tcp.py
index 590bcc01..273427d5 100644
--- a/test/netlib/test_tcp.py
+++ b/test/netlib/test_tcp.py
@@ -169,7 +169,7 @@ class TestServerSSL(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- c.convert_to_ssl(sni=b"foo.com", options=SSL.OP_ALL)
+ c.convert_to_ssl(sni="foo.com", options=SSL.OP_ALL)
testval = b"echo!\n"
c.wfile.write(testval)
c.wfile.flush()
@@ -179,7 +179,7 @@ class TestServerSSL(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
assert not c.get_current_cipher()
- c.convert_to_ssl(sni=b"foo.com")
+ c.convert_to_ssl(sni="foo.com")
ret = c.get_current_cipher()
assert ret
assert "AES" in ret[0]
@@ -195,7 +195,7 @@ class TestSSLv3Only(tservers.ServerTestBase):
def test_failure(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- tutils.raises(TlsException, c.convert_to_ssl, sni=b"foo.com")
+ tutils.raises(TlsException, c.convert_to_ssl, sni="foo.com")
class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
@@ -238,7 +238,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
with c.connect():
with tutils.raises(InvalidCertificateException):
c.convert_to_ssl(
- sni=b"example.mitmproxy.org",
+ sni="example.mitmproxy.org",
verify_options=SSL.VERIFY_PEER,
ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
)
@@ -272,7 +272,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase):
with c.connect():
with tutils.raises(InvalidCertificateException):
c.convert_to_ssl(
- sni=b"mitmproxy.org",
+ sni="mitmproxy.org",
verify_options=SSL.VERIFY_PEER,
ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
)
@@ -291,7 +291,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_ssl(
- sni=b"example.mitmproxy.org",
+ sni="example.mitmproxy.org",
verify_options=SSL.VERIFY_PEER,
ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
)
@@ -307,7 +307,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_ssl(
- sni=b"example.mitmproxy.org",
+ sni="example.mitmproxy.org",
verify_options=SSL.VERIFY_PEER,
ca_path=tutils.test_data.path("data/verificationcerts/")
)
@@ -371,8 +371,8 @@ class TestSNI(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- c.convert_to_ssl(sni=b"foo.com")
- assert c.sni == b"foo.com"
+ c.convert_to_ssl(sni="foo.com")
+ assert c.sni == "foo.com"
assert c.rfile.readline() == b"foo.com"
@@ -385,7 +385,7 @@ class TestServerCipherList(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- c.convert_to_ssl(sni=b"foo.com")
+ c.convert_to_ssl(sni="foo.com")
assert c.rfile.readline() == b"['RC4-SHA']"
@@ -405,7 +405,7 @@ class TestServerCurrentCipher(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- c.convert_to_ssl(sni=b"foo.com")
+ c.convert_to_ssl(sni="foo.com")
assert b"RC4-SHA" in c.rfile.readline()
@@ -418,7 +418,7 @@ class TestServerCipherListError(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- tutils.raises("handshake error", c.convert_to_ssl, sni=b"foo.com")
+ tutils.raises("handshake error", c.convert_to_ssl, sni="foo.com")
class TestClientCipherListError(tservers.ServerTestBase):
@@ -433,7 +433,7 @@ class TestClientCipherListError(tservers.ServerTestBase):
tutils.raises(
"cipher specification",
c.convert_to_ssl,
- sni=b"foo.com",
+ sni="foo.com",
cipher_list="bogus"
)
diff --git a/test/pathod/test_pathoc.py b/test/pathod/test_pathoc.py
index 28f9f0f8..361a863b 100644
--- a/test/pathod/test_pathoc.py
+++ b/test/pathod/test_pathoc.py
@@ -54,10 +54,10 @@ class TestDaemonSSL(PathocTestDaemon):
def test_sni(self):
self.tval(
["get:/p/200"],
- sni=b"foobar.com"
+ sni="foobar.com"
)
log = self.d.log()
- assert log[0]["request"]["sni"] == b"foobar.com"
+ assert log[0]["request"]["sni"] == "foobar.com"
def test_showssl(self):
assert "certificate chain" in self.tval(["get:/p/200"], showssl=True)