aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mitmproxy/protocol/tls.py4
-rw-r--r--mitmproxy/proxy/root_context.py2
-rw-r--r--test/mitmproxy/test_server.py8
3 files changed, 10 insertions, 4 deletions
diff --git a/mitmproxy/protocol/tls.py b/mitmproxy/protocol/tls.py
index e41a9af0..943fe837 100644
--- a/mitmproxy/protocol/tls.py
+++ b/mitmproxy/protocol/tls.py
@@ -326,12 +326,12 @@ class TlsLayer(base.Layer):
the server connection.
"""
- def __init__(self, ctx, client_tls, server_tls):
+ def __init__(self, ctx, client_tls, server_tls, custom_server_sni = None):
super(TlsLayer, self).__init__(ctx)
self._client_tls = client_tls
self._server_tls = server_tls
- self._custom_server_sni = None
+ self._custom_server_sni = custom_server_sni
self._client_hello = None # type: TlsClientHello
def __call__(self):
diff --git a/mitmproxy/proxy/root_context.py b/mitmproxy/proxy/root_context.py
index 81dd625c..000ebb13 100644
--- a/mitmproxy/proxy/root_context.py
+++ b/mitmproxy/proxy/root_context.py
@@ -69,7 +69,7 @@ class RootContext(object):
# An inline script may upgrade from http to https,
# in which case we need some form of TLS layer.
if isinstance(top_layer, modes.ReverseProxy):
- return protocol.TlsLayer(top_layer, client_tls, top_layer.server_tls)
+ return protocol.TlsLayer(top_layer, client_tls, top_layer.server_tls, top_layer.server_conn.address.host)
if isinstance(top_layer, protocol.ServerConnectionMixin) or isinstance(top_layer, protocol.UpstreamConnectLayer):
return protocol.TlsLayer(top_layer, client_tls, client_tls)
diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py
index d1886b97..059ea856 100644
--- a/test/mitmproxy/test_server.py
+++ b/test/mitmproxy/test_server.py
@@ -101,10 +101,16 @@ class CommonMixin:
if not self.ssl:
return
+ if getattr(self, 'reverse', False):
+ # In reverse proxy mode, we expect to use the upstream host as our SNI value
+ expected_sni = "127.0.0.1"
+ else:
+ expected_sni = "testserver.com"
+
f = self.pathod("304", sni="testserver.com")
assert f.status_code == 304
log = self.server.last_log()
- assert log["request"]["sni"] == "testserver.com"
+ assert log["request"]["sni"] == expected_sni
class TcpMixin: