diff options
| -rw-r--r-- | mitmproxy/connections.py | 18 | ||||
| -rw-r--r-- | mitmproxy/net/tls.py | 21 | ||||
| -rw-r--r-- | mitmproxy/proxy/protocol/http_replay.py | 10 | ||||
| -rw-r--r-- | mitmproxy/proxy/protocol/tls.py | 15 | ||||
| -rw-r--r-- | test/mitmproxy/test_connections.py | 9 | 
5 files changed, 46 insertions, 27 deletions
| diff --git a/mitmproxy/connections.py b/mitmproxy/connections.py index 86565b7b..9c47985c 100644 --- a/mitmproxy/connections.py +++ b/mitmproxy/connections.py @@ -253,7 +253,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):              address=address,              ip_address=address,              cert=None, -            sni=None, +            sni=address[0],              alpn_proto_negotiated=None,              tls_version=None,              source_address=('', 0), @@ -276,21 +276,21 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):          self.wfile.write(message)          self.wfile.flush() -    def establish_tls(self, clientcerts, sni, **kwargs): +    def establish_tls(self, *, sni=None, client_certs=None, **kwargs):          if sni and not isinstance(sni, str):              raise ValueError("sni must be str, not " + type(sni).__name__) -        clientcert = None -        if clientcerts: -            if os.path.isfile(clientcerts): -                clientcert = clientcerts +        client_cert = None +        if client_certs: +            if os.path.isfile(client_certs): +                client_cert = client_certs              else:                  path = os.path.join( -                    clientcerts, +                    client_certs,                      self.address[0].encode("idna").decode()) + ".pem"                  if os.path.exists(path): -                    clientcert = path +                    client_cert = path -        self.convert_to_tls(cert=clientcert, sni=sni, **kwargs) +        self.convert_to_tls(cert=client_cert, sni=sni, **kwargs)          self.sni = sni          self.alpn_proto_negotiated = self.get_alpn_proto_negotiated()          self.tls_version = self.connection.get_protocol_version_name() diff --git a/mitmproxy/net/tls.py b/mitmproxy/net/tls.py index 0e43a2ac..f8eeb44b 100644 --- a/mitmproxy/net/tls.py +++ b/mitmproxy/net/tls.py @@ -13,6 +13,7 @@ import certifi  from OpenSSL import SSL  from kaitaistruct import KaitaiStream +import mitmproxy.options  # noqa  from mitmproxy import exceptions, certs  from mitmproxy.contrib.kaitaistruct import tls_client_hello  from mitmproxy.net import check @@ -57,6 +58,26 @@ METHOD_NAMES = {  } +def client_arguments_from_options(options: "mitmproxy.options.Options") -> dict: + +    if options.ssl_insecure: +        verify = SSL.VERIFY_NONE +    else: +        verify = SSL.VERIFY_PEER + +    method, tls_options = VERSION_CHOICES[options.ssl_version_server] + +    return { +        "verify": verify, +        "method": method, +        "options": tls_options, +        "ca_path": options.ssl_verify_upstream_trusted_cadir, +        "ca_pemfile": options.ssl_verify_upstream_trusted_ca, +        "client_certs": options.client_certs, +        "cipher_list": options.ciphers_server, +    } + +  class MasterSecretLogger:      def __init__(self, filename):          self.filename = filename diff --git a/mitmproxy/proxy/protocol/http_replay.py b/mitmproxy/proxy/protocol/http_replay.py index 022e8133..0f3be1ea 100644 --- a/mitmproxy/proxy/protocol/http_replay.py +++ b/mitmproxy/proxy/protocol/http_replay.py @@ -9,7 +9,7 @@ from mitmproxy import http  from mitmproxy import flow  from mitmproxy import options  from mitmproxy import connections -from mitmproxy.net import server_spec +from mitmproxy.net import server_spec, tls  from mitmproxy.net.http import http1  from mitmproxy.coretypes import basethread  from mitmproxy.utils import human @@ -76,8 +76,8 @@ class RequestReplayThread(basethread.BaseThread):                          if resp.status_code != 200:                              raise exceptions.ReplayException("Upstream server refuses CONNECT request")                          server.establish_tls( -                            self.options.client_certs, -                            sni=self.f.server_conn.sni +                            sni=self.f.server_conn.sni, +                            **tls.client_arguments_from_options(self.options)                          )                          r.first_line_format = "relative"                      else: @@ -91,8 +91,8 @@ class RequestReplayThread(basethread.BaseThread):                      server.connect()                      if r.scheme == "https":                          server.establish_tls( -                            self.options.client_certs, -                            sni=self.f.server_conn.sni +                            sni=self.f.server_conn.sni, +                            **tls.client_arguments_from_options(self.options)                          )                      r.first_line_format = "relative" diff --git a/mitmproxy/proxy/protocol/tls.py b/mitmproxy/proxy/protocol/tls.py index d04c9801..876c1162 100644 --- a/mitmproxy/proxy/protocol/tls.py +++ b/mitmproxy/proxy/protocol/tls.py @@ -424,6 +424,9 @@ class TlsLayer(base.Layer):                  #   * which results in garbage because the layers don' match.                  alpn = [self.client_conn.get_alpn_proto_negotiated()] +            # We pass through the list of ciphers send by the client, because some HTTP/2 servers +            # will select a non-HTTP/2 compatible cipher from our default list and then hang up +            # because it's incompatible with h2. :-)              ciphers_server = self.config.options.ciphers_server              if not ciphers_server and self._client_tls:                  ciphers_server = [] @@ -432,16 +435,12 @@ class TlsLayer(base.Layer):                          ciphers_server.append(CIPHER_ID_NAME_MAP[id])                  ciphers_server = ':'.join(ciphers_server) +            args = net_tls.client_arguments_from_options(self.config.options) +            args["cipher_list"] = ciphers_server              self.server_conn.establish_tls( -                self.config.client_certs, -                self.server_sni, -                method=self.config.openssl_method_server, -                options=self.config.openssl_options_server, -                verify=self.config.openssl_verification_mode_server, -                ca_path=self.config.options.ssl_verify_upstream_trusted_cadir, -                ca_pemfile=self.config.options.ssl_verify_upstream_trusted_ca, -                cipher_list=ciphers_server, +                sni=self.server_sni,                  alpn_protos=alpn, +                **args              )              tls_cert_err = self.server_conn.ssl_verification_error              if tls_cert_err is not None: diff --git a/test/mitmproxy/test_connections.py b/test/mitmproxy/test_connections.py index 9e5d89f1..00cdbc87 100644 --- a/test/mitmproxy/test_connections.py +++ b/test/mitmproxy/test_connections.py @@ -155,7 +155,7 @@ class TestServerConnection:      def test_sni(self):          c = connections.ServerConnection(('', 1234))          with pytest.raises(ValueError, matches='sni must be str, not '): -            c.establish_tls(None, b'foobar') +            c.establish_tls(sni=b'foobar')      def test_state(self):          c = tflow.tserver_conn() @@ -222,17 +222,16 @@ class TestServerConnectionTLS(tservers.ServerTestBase):          def handle(self):              self.finish() -    @pytest.mark.parametrize("clientcert", [ +    @pytest.mark.parametrize("client_certs", [          None,          tutils.test_data.path("mitmproxy/data/clientcert"),          tutils.test_data.path("mitmproxy/data/clientcert/client.pem"),      ]) -    def test_tls(self, clientcert): +    def test_tls(self, client_certs):          c = connections.ServerConnection(("127.0.0.1", self.port))          c.connect() -        c.establish_tls(clientcert, "foo.com") +        c.establish_tls(client_certs=client_certs)          assert c.connected() -        assert c.sni == "foo.com"          assert c.tls_established          c.close()          c.finish() | 
