diff options
-rw-r--r-- | mitmproxy/protocol/tls.py | 4 | ||||
-rw-r--r-- | mitmproxy/proxy/root_context.py | 2 | ||||
-rw-r--r-- | test/mitmproxy/test_server.py | 8 |
3 files changed, 10 insertions, 4 deletions
diff --git a/mitmproxy/protocol/tls.py b/mitmproxy/protocol/tls.py index e41a9af0..943fe837 100644 --- a/mitmproxy/protocol/tls.py +++ b/mitmproxy/protocol/tls.py @@ -326,12 +326,12 @@ class TlsLayer(base.Layer): the server connection. """ - def __init__(self, ctx, client_tls, server_tls): + def __init__(self, ctx, client_tls, server_tls, custom_server_sni = None): super(TlsLayer, self).__init__(ctx) self._client_tls = client_tls self._server_tls = server_tls - self._custom_server_sni = None + self._custom_server_sni = custom_server_sni self._client_hello = None # type: TlsClientHello def __call__(self): diff --git a/mitmproxy/proxy/root_context.py b/mitmproxy/proxy/root_context.py index 81dd625c..000ebb13 100644 --- a/mitmproxy/proxy/root_context.py +++ b/mitmproxy/proxy/root_context.py @@ -69,7 +69,7 @@ class RootContext(object): # An inline script may upgrade from http to https, # in which case we need some form of TLS layer. if isinstance(top_layer, modes.ReverseProxy): - return protocol.TlsLayer(top_layer, client_tls, top_layer.server_tls) + return protocol.TlsLayer(top_layer, client_tls, top_layer.server_tls, top_layer.server_conn.address.host) if isinstance(top_layer, protocol.ServerConnectionMixin) or isinstance(top_layer, protocol.UpstreamConnectLayer): return protocol.TlsLayer(top_layer, client_tls, client_tls) diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py index d1886b97..059ea856 100644 --- a/test/mitmproxy/test_server.py +++ b/test/mitmproxy/test_server.py @@ -101,10 +101,16 @@ class CommonMixin: if not self.ssl: return + if getattr(self, 'reverse', False): + # In reverse proxy mode, we expect to use the upstream host as our SNI value + expected_sni = "127.0.0.1" + else: + expected_sni = "testserver.com" + f = self.pathod("304", sni="testserver.com") assert f.status_code == 304 log = self.server.last_log() - assert log["request"]["sni"] == "testserver.com" + assert log["request"]["sni"] == expected_sni class TcpMixin: |