diff options
Diffstat (limited to 'test')
24 files changed, 343 insertions, 184 deletions
diff --git a/test/mitmproxy/addons/test_clientplayback.py b/test/mitmproxy/addons/test_clientplayback.py index 2dc7eb92..3f990668 100644 --- a/test/mitmproxy/addons/test_clientplayback.py +++ b/test/mitmproxy/addons/test_clientplayback.py @@ -52,6 +52,10 @@ class TestClientPlayback:              cp.stop_replay()              assert not cp.flows +            df = tflow.DummyFlow(tflow.tclient_conn(), tflow.tserver_conn(), True) +            with pytest.raises(exceptions.CommandError, match="Can't replay live flow."): +                cp.start_replay([df]) +      def test_load_file(self, tmpdir):          cp = clientplayback.ClientPlayback()          with taddons.context(): diff --git a/test/mitmproxy/addons/test_cut.py b/test/mitmproxy/addons/test_cut.py index 71e699db..c444b8ee 100644 --- a/test/mitmproxy/addons/test_cut.py +++ b/test/mitmproxy/addons/test_cut.py @@ -23,8 +23,8 @@ def test_extract():          ["request.text", "content"],          ["request.content", b"content"],          ["request.raw_content", b"content"], -        ["request.timestamp_start", "1"], -        ["request.timestamp_end", "2"], +        ["request.timestamp_start", "946681200"], +        ["request.timestamp_end", "946681201"],          ["request.header[header]", "qvalue"],          ["response.status_code", "200"], @@ -33,30 +33,29 @@ def test_extract():          ["response.content", b"message"],          ["response.raw_content", b"message"],          ["response.header[header-response]", "svalue"], -        ["response.timestamp_start", "1"], -        ["response.timestamp_end", "2"], +        ["response.timestamp_start", "946681202"], +        ["response.timestamp_end", "946681203"],          ["client_conn.address.port", "22"],          ["client_conn.address.host", "127.0.0.1"],          ["client_conn.tls_version", "TLSv1.2"],          ["client_conn.sni", "address"], -        ["client_conn.ssl_established", "false"], +        ["client_conn.tls_established", "false"],          ["server_conn.address.port", "22"],          ["server_conn.address.host", "address"],          ["server_conn.ip_address.host", "192.168.0.1"],          ["server_conn.tls_version", "TLSv1.2"],          ["server_conn.sni", "address"], -        ["server_conn.ssl_established", "false"], +        ["server_conn.tls_established", "false"],      ] -    for t in tests: -        ret = cut.extract(t[0], tf) -        if ret != t[1]: -            raise AssertionError("%s: Expected %s, got %s" % (t[0], t[1], ret)) +    for spec, expected in tests: +        ret = cut.extract(spec, tf) +        assert spec and ret == expected      with open(tutils.test_data.path("mitmproxy/net/data/text_cert"), "rb") as f:          d = f.read() -    c1 = certs.SSLCert.from_pem(d) +    c1 = certs.Cert.from_pem(d)      tf.server_conn.cert = c1      assert "CERTIFICATE" in cut.extract("server_conn.cert", tf) diff --git a/test/mitmproxy/addons/test_proxyauth.py b/test/mitmproxy/addons/test_proxyauth.py index 1d05e137..97259d1c 100644 --- a/test/mitmproxy/addons/test_proxyauth.py +++ b/test/mitmproxy/addons/test_proxyauth.py @@ -190,7 +190,7 @@ class TestProxyAuth:              with pytest.raises(exceptions.OptionsError):                  ctx.configure(up, proxyauth="ldap:test:test:test") -            with pytest.raises(IndexError): +            with pytest.raises(exceptions.OptionsError):                  ctx.configure(up, proxyauth="ldap:fake_serveruid=?dc=example,dc=com:person")              with pytest.raises(exceptions.OptionsError): diff --git a/test/mitmproxy/addons/test_view.py b/test/mitmproxy/addons/test_view.py index 1c76eb21..6f2a9ca5 100644 --- a/test/mitmproxy/addons/test_view.py +++ b/test/mitmproxy/addons/test_view.py @@ -41,7 +41,7 @@ def test_order_generators():      tf = tflow.tflow(resp=True)      rs = view.OrderRequestStart(v) -    assert rs.generate(tf) == 1 +    assert rs.generate(tf) == 946681200      rm = view.OrderRequestMethod(v)      assert rm.generate(tf) == tf.request.method @@ -147,6 +147,10 @@ def test_create():          assert v[0].request.url == "http://foo.com/"          v.create("get", "http://foo.com")          assert len(v) == 2 +        with pytest.raises(exceptions.CommandError, match="Invalid URL"): +            v.create("get", "http://foo.com\\") +        with pytest.raises(exceptions.CommandError, match="Invalid URL"): +            v.create("get", "http://")  def test_orders(): @@ -175,6 +179,10 @@ def test_load(tmpdir):              v.load_file("nonexistent_file_path")          except IOError:              assert False +        with open(path, "wb") as f: +            f.write(b"invalidflows") +        v.load_file(path) +        assert tctx.master.has_log("Invalid data format.")  def test_resolve(): diff --git a/test/mitmproxy/net/http/test_response.py b/test/mitmproxy/net/http/test_response.py index a77435c9..af35bab3 100644 --- a/test/mitmproxy/net/http/test_response.py +++ b/test/mitmproxy/net/http/test_response.py @@ -150,10 +150,10 @@ class TestResponseUtils:          n = time.time()          r.headers["date"] = email.utils.formatdate(n)          pre = r.headers["date"] -        r.refresh(1) +        r.refresh(946681202)          assert pre == r.headers["date"] -        r.refresh(61) +        r.refresh(946681262)          d = email.utils.parsedate_tz(r.headers["date"])          d = email.utils.mktime_tz(d)          # Weird that this is not exact... diff --git a/test/mitmproxy/net/http/test_url.py b/test/mitmproxy/net/http/test_url.py index 2064aab8..c9f61faf 100644 --- a/test/mitmproxy/net/http/test_url.py +++ b/test/mitmproxy/net/http/test_url.py @@ -108,6 +108,7 @@ def test_empty_key_trailing_equal_sign():  def test_encode():      assert url.encode([('foo', 'bar')])      assert url.encode([('foo', surrogates)]) +    assert not url.encode([], similar_to="justatext")  def test_decode(): diff --git a/test/mitmproxy/net/test_tcp.py b/test/mitmproxy/net/test_tcp.py index e9084be4..8c012e42 100644 --- a/test/mitmproxy/net/test_tcp.py +++ b/test/mitmproxy/net/test_tcp.py @@ -178,7 +178,7 @@ class TestServerSSL(tservers.ServerTestBase):      def test_echo(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl(sni="foo.com", options=SSL.OP_ALL) +            c.convert_to_tls(sni="foo.com", options=SSL.OP_ALL)              testval = b"echo!\n"              c.wfile.write(testval)              c.wfile.flush() @@ -188,7 +188,7 @@ class TestServerSSL(tservers.ServerTestBase):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect():              assert not c.get_current_cipher() -            c.convert_to_ssl(sni="foo.com") +            c.convert_to_tls(sni="foo.com")              ret = c.get_current_cipher()              assert ret              assert "AES" in ret[0] @@ -205,7 +205,7 @@ class TestSSLv3Only(tservers.ServerTestBase):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect():              with pytest.raises(exceptions.TlsException): -                c.convert_to_ssl(sni="foo.com") +                c.convert_to_tls(sni="foo.com")  class TestInvalidTrustFile(tservers.ServerTestBase): @@ -213,7 +213,7 @@ class TestInvalidTrustFile(tservers.ServerTestBase):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect():              with pytest.raises(exceptions.TlsException): -                c.convert_to_ssl( +                c.convert_to_tls(                      sni="example.mitmproxy.org",                      verify=SSL.VERIFY_PEER,                      ca_pemfile=tutils.test_data.path("mitmproxy/net/data/verificationcerts/generate.py") @@ -231,7 +231,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):      def test_mode_default_should_pass(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl() +            c.convert_to_tls()              # Verification errors should be saved even if connection isn't aborted              # aborted @@ -245,7 +245,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):      def test_mode_none_should_pass(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl(verify=SSL.VERIFY_NONE) +            c.convert_to_tls(verify=SSL.VERIFY_NONE)              # Verification errors should be saved even if connection isn't aborted              assert c.ssl_verification_error @@ -259,7 +259,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect():              with pytest.raises(exceptions.InvalidCertificateException): -                c.convert_to_ssl( +                c.convert_to_tls(                      sni="example.mitmproxy.org",                      verify=SSL.VERIFY_PEER,                      ca_pemfile=tutils.test_data.path("mitmproxy/net/data/verificationcerts/trusted-root.crt") @@ -284,7 +284,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect():              with pytest.raises(exceptions.TlsException): -                c.convert_to_ssl( +                c.convert_to_tls(                      verify=SSL.VERIFY_PEER,                      ca_pemfile=tutils.test_data.path("mitmproxy/net/data/verificationcerts/trusted-root.crt")                  ) @@ -292,7 +292,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase):      def test_mode_none_should_pass_without_sni(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl( +            c.convert_to_tls(                  verify=SSL.VERIFY_NONE,                  ca_path=tutils.test_data.path("mitmproxy/net/data/verificationcerts/")              ) @@ -303,7 +303,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect():              with pytest.raises(exceptions.InvalidCertificateException): -                c.convert_to_ssl( +                c.convert_to_tls(                      sni="mitmproxy.org",                      verify=SSL.VERIFY_PEER,                      ca_pemfile=tutils.test_data.path("mitmproxy/net/data/verificationcerts/trusted-root.crt") @@ -322,7 +322,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase):      def test_mode_strict_w_pemfile_should_pass(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl( +            c.convert_to_tls(                  sni="example.mitmproxy.org",                  verify=SSL.VERIFY_PEER,                  ca_pemfile=tutils.test_data.path("mitmproxy/net/data/verificationcerts/trusted-root.crt") @@ -338,7 +338,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase):      def test_mode_strict_w_cadir_should_pass(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl( +            c.convert_to_tls(                  sni="example.mitmproxy.org",                  verify=SSL.VERIFY_PEER,                  ca_path=tutils.test_data.path("mitmproxy/net/data/verificationcerts/") @@ -372,7 +372,7 @@ class TestSSLClientCert(tservers.ServerTestBase):      def test_clientcert(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl( +            c.convert_to_tls(                  cert=tutils.test_data.path("mitmproxy/net/data/clientcert/client.pem"))              assert c.rfile.readline().strip() == b"1" @@ -380,7 +380,7 @@ class TestSSLClientCert(tservers.ServerTestBase):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect():              with pytest.raises(exceptions.TlsException): -                c.convert_to_ssl(cert=tutils.test_data.path("mitmproxy/net/data/clientcert/make")) +                c.convert_to_tls(cert=tutils.test_data.path("mitmproxy/net/data/clientcert/make"))  class TestSNI(tservers.ServerTestBase): @@ -400,15 +400,15 @@ class TestSNI(tservers.ServerTestBase):      def test_echo(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl(sni="foo.com") +            c.convert_to_tls(sni="foo.com")              assert c.sni == "foo.com"              assert c.rfile.readline() == b"foo.com"      def test_idn(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl(sni="mitmproxyäöüß.example.com") -            assert c.ssl_established +            c.convert_to_tls(sni="mitmproxyäöüß.example.com") +            assert c.tls_established              assert "doesn't match" not in str(c.ssl_verification_error) @@ -421,7 +421,7 @@ class TestServerCipherList(tservers.ServerTestBase):      def test_echo(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl(sni="foo.com") +            c.convert_to_tls(sni="foo.com")              expected = b"['AES256-GCM-SHA384']"              assert c.rfile.read(len(expected) + 2) == expected @@ -442,7 +442,7 @@ class TestServerCurrentCipher(tservers.ServerTestBase):      def test_echo(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl(sni="foo.com") +            c.convert_to_tls(sni="foo.com")              assert b'AES256-GCM-SHA384' in c.rfile.readline() @@ -456,7 +456,7 @@ class TestServerCipherListError(tservers.ServerTestBase):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect():              with pytest.raises(Exception, match="handshake error"): -                c.convert_to_ssl(sni="foo.com") +                c.convert_to_tls(sni="foo.com")  class TestClientCipherListError(tservers.ServerTestBase): @@ -469,7 +469,7 @@ class TestClientCipherListError(tservers.ServerTestBase):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect():              with pytest.raises(Exception, match="cipher specification"): -                c.convert_to_ssl(sni="foo.com", cipher_list="bogus") +                c.convert_to_tls(sni="foo.com", cipher_list="bogus")  class TestSSLDisconnect(tservers.ServerTestBase): @@ -484,7 +484,7 @@ class TestSSLDisconnect(tservers.ServerTestBase):      def test_echo(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl() +            c.convert_to_tls()              # Excercise SSL.ZeroReturnError              c.rfile.read(10)              c.close() @@ -501,7 +501,7 @@ class TestSSLHardDisconnect(tservers.ServerTestBase):      def test_echo(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl() +            c.convert_to_tls()              # Exercise SSL.SysCallError              c.rfile.read(10)              c.close() @@ -565,7 +565,7 @@ class TestALPNClient(tservers.ServerTestBase):      def test_alpn(self, monkeypatch, alpn_protos, expected_negotiated, expected_response):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl(alpn_protos=alpn_protos) +            c.convert_to_tls(alpn_protos=alpn_protos)              assert c.get_alpn_proto_negotiated() == expected_negotiated              assert c.rfile.readline().strip() == expected_response @@ -587,7 +587,7 @@ class TestSSLTimeOut(tservers.ServerTestBase):      def test_timeout_client(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl() +            c.convert_to_tls()              c.settimeout(0.1)              with pytest.raises(exceptions.TcpTimeout):                  c.rfile.read(10) @@ -605,7 +605,7 @@ class TestDHParams(tservers.ServerTestBase):      def test_dhparams(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl() +            c.convert_to_tls()              ret = c.get_current_cipher()              assert ret[0] == "DHE-RSA-AES256-SHA" @@ -801,5 +801,5 @@ class TestPeekSSL(TestPeek):      def _connect(self, c):          with c.connect() as conn: -            c.convert_to_ssl() +            c.convert_to_tls()              return conn.pop() diff --git a/test/mitmproxy/net/test_tls.py b/test/mitmproxy/net/test_tls.py index d0583d34..489bf89f 100644 --- a/test/mitmproxy/net/test_tls.py +++ b/test/mitmproxy/net/test_tls.py @@ -1,3 +1,5 @@ +import io +  import pytest  from mitmproxy import exceptions @@ -6,6 +8,17 @@ from mitmproxy.net.tcp import TCPClient  from test.mitmproxy.net.test_tcp import EchoHandler  from . import tservers +CLIENT_HELLO_NO_EXTENSIONS = bytes.fromhex( +    "03015658a756ab2c2bff55f636814deac086b7ca56b65058c7893ffc6074f5245f70205658a75475103a152637" +    "78e1bb6d22e8bbd5b6b0a3a59760ad354e91ba20d353001a0035002f000a000500040009000300060008006000" +    "61006200640100" +) +FULL_CLIENT_HELLO_NO_EXTENSIONS = ( +    b"\x16\x03\x03\x00\x65"  # record layer +    b"\x01\x00\x00\x61" +  # handshake header +    CLIENT_HELLO_NO_EXTENSIONS +) +  class TestMasterSecretLogger(tservers.ServerTestBase):      handler = EchoHandler @@ -22,7 +35,7 @@ class TestMasterSecretLogger(tservers.ServerTestBase):          c = TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl() +            c.convert_to_tls()              c.wfile.write(testval)              c.wfile.flush()              assert c.rfile.readline() == testval @@ -53,3 +66,92 @@ class TestTLSInvalid:          with pytest.raises(exceptions.TlsException, match="ALPN error"):              tls.create_client_context(alpn_select="foo", alpn_select_callback="bar") + + +def test_is_record_magic(): +    assert not tls.is_tls_record_magic(b"POST /") +    assert not tls.is_tls_record_magic(b"\x16\x03") +    assert not tls.is_tls_record_magic(b"\x16\x03\x04") +    assert tls.is_tls_record_magic(b"\x16\x03\x00") +    assert tls.is_tls_record_magic(b"\x16\x03\x01") +    assert tls.is_tls_record_magic(b"\x16\x03\x02") +    assert tls.is_tls_record_magic(b"\x16\x03\x03") + + +def test_get_client_hello(): +    rfile = io.BufferedReader(io.BytesIO( +        FULL_CLIENT_HELLO_NO_EXTENSIONS +    )) +    assert tls.get_client_hello(rfile) + +    rfile = io.BufferedReader(io.BytesIO( +        FULL_CLIENT_HELLO_NO_EXTENSIONS[:30] +    )) +    with pytest.raises(exceptions.TlsProtocolException, message="Unexpected EOF"): +        tls.get_client_hello(rfile) + +    rfile = io.BufferedReader(io.BytesIO( +        b"GET /" +    )) +    with pytest.raises(exceptions.TlsProtocolException, message="Expected TLS record"): +        tls.get_client_hello(rfile) + + +class TestClientHello: +    def test_no_extensions(self): +        c = tls.ClientHello(CLIENT_HELLO_NO_EXTENSIONS) +        assert repr(c) +        assert c.sni is None +        assert c.cipher_suites == [53, 47, 10, 5, 4, 9, 3, 6, 8, 96, 97, 98, 100] +        assert c.alpn_protocols == [] +        assert c.extensions == [] + +    def test_extensions(self): +        data = bytes.fromhex( +            "03033b70638d2523e1cba15f8364868295305e9c52aceabda4b5147210abc783e6e1000022c02bc02fc02cc030" +            "cca9cca8cc14cc13c009c013c00ac014009c009d002f0035000a0100006cff0100010000000010000e00000b65" +            "78616d706c652e636f6d0017000000230000000d00120010060106030501050304010403020102030005000501" +            "00000000001200000010000e000c02683208687474702f312e3175500000000b00020100000a00080006001d00" +            "170018" +        ) +        c = tls.ClientHello(data) +        assert repr(c) +        assert c.sni == 'example.com' +        assert c.cipher_suites == [ +            49195, 49199, 49196, 49200, 52393, 52392, 52244, 52243, 49161, +            49171, 49162, 49172, 156, 157, 47, 53, 10 +        ] +        assert c.alpn_protocols == [b'h2', b'http/1.1'] +        assert c.extensions == [ +            (65281, b'\x00'), +            (0, b'\x00\x0e\x00\x00\x0bexample.com'), +            (23, b''), +            (35, b''), +            (13, b'\x00\x10\x06\x01\x06\x03\x05\x01\x05\x03\x04\x01\x04\x03\x02\x01\x02\x03'), +            (5, b'\x01\x00\x00\x00\x00'), +            (18, b''), +            (16, b'\x00\x0c\x02h2\x08http/1.1'), +            (30032, b''), +            (11, b'\x01\x00'), +            (10, b'\x00\x06\x00\x1d\x00\x17\x00\x18') +        ] + +    def test_from_file(self): +        rfile = io.BufferedReader(io.BytesIO( +            FULL_CLIENT_HELLO_NO_EXTENSIONS +        )) +        assert tls.ClientHello.from_file(rfile) + +        rfile = io.BufferedReader(io.BytesIO( +            b"" +        )) +        with pytest.raises(exceptions.TlsProtocolException): +            tls.ClientHello.from_file(rfile) + +        rfile = io.BufferedReader(io.BytesIO( +            b"\x16\x03\x03\x00\x07"  # record layer +            b"\x01\x00\x00\x03" +  # handshake header +            b"foo" +        )) +        with pytest.raises(exceptions.TlsProtocolException, message='Cannot parse Client Hello'): +            tls.ClientHello.from_file(rfile) diff --git a/test/mitmproxy/net/tools/getcertnames b/test/mitmproxy/net/tools/getcertnames index d64e5ff5..9349415f 100644 --- a/test/mitmproxy/net/tools/getcertnames +++ b/test/mitmproxy/net/tools/getcertnames @@ -7,7 +7,7 @@ from mitmproxy.net import tcp  def get_remote_cert(host, port, sni):      c = tcp.TCPClient((host, port))      c.connect() -    c.convert_to_ssl(sni=sni) +    c.convert_to_tls(sni=sni)      return c.cert  if len(sys.argv) > 2: diff --git a/test/mitmproxy/net/tservers.py b/test/mitmproxy/net/tservers.py index 44701aa5..22e195e3 100644 --- a/test/mitmproxy/net/tservers.py +++ b/test/mitmproxy/net/tservers.py @@ -60,7 +60,7 @@ class _TServer(tcp.TCPServer):              else:                  method = OpenSSL.SSL.SSLv23_METHOD                  options = None -            h.convert_to_ssl( +            h.convert_to_tls(                  cert,                  key,                  method=method, diff --git a/test/mitmproxy/proxy/protocol/test_http2.py b/test/mitmproxy/proxy/protocol/test_http2.py index 4f161ef5..194a57c9 100644 --- a/test/mitmproxy/proxy/protocol/test_http2.py +++ b/test/mitmproxy/proxy/protocol/test_http2.py @@ -141,7 +141,7 @@ class _Http2TestBase:          while self.client.rfile.readline() != b"\r\n":              pass -        self.client.convert_to_ssl(alpn_protos=[b'h2']) +        self.client.convert_to_tls(alpn_protos=[b'h2'])          config = h2.config.H2Configuration(              client_side=True, diff --git a/test/mitmproxy/proxy/protocol/test_tls.py b/test/mitmproxy/proxy/protocol/test_tls.py index e17ee46f..e69de29b 100644 --- a/test/mitmproxy/proxy/protocol/test_tls.py +++ b/test/mitmproxy/proxy/protocol/test_tls.py @@ -1,26 +0,0 @@ -from mitmproxy.proxy.protocol.tls import TlsClientHello - - -class TestClientHello: - -    def test_no_extensions(self): -        data = bytes.fromhex( -            "03015658a756ab2c2bff55f636814deac086b7ca56b65058c7893ffc6074f5245f70205658a75475103a152637" -            "78e1bb6d22e8bbd5b6b0a3a59760ad354e91ba20d353001a0035002f000a000500040009000300060008006000" -            "61006200640100" -        ) -        c = TlsClientHello(data) -        assert c.sni is None -        assert c.alpn_protocols == [] - -    def test_extensions(self): -        data = bytes.fromhex( -            "03033b70638d2523e1cba15f8364868295305e9c52aceabda4b5147210abc783e6e1000022c02bc02fc02cc030" -            "cca9cca8cc14cc13c009c013c00ac014009c009d002f0035000a0100006cff0100010000000010000e00000b65" -            "78616d706c652e636f6d0017000000230000000d00120010060106030501050304010403020102030005000501" -            "00000000001200000010000e000c02683208687474702f312e3175500000000b00020100000a00080006001d00" -            "170018" -        ) -        c = TlsClientHello(data) -        assert c.sni == 'example.com' -        assert c.alpn_protocols == [b'h2', b'http/1.1'] diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py index d9389faf..5cd9601c 100644 --- a/test/mitmproxy/proxy/protocol/test_websocket.py +++ b/test/mitmproxy/proxy/protocol/test_websocket.py @@ -101,8 +101,8 @@ class _WebSocketTestBase:          response = http.http1.read_response(self.client.rfile, request)          if self.ssl: -            self.client.convert_to_ssl() -            assert self.client.ssl_established +            self.client.convert_to_tls() +            assert self.client.tls_established          request = http.Request(              "relative", diff --git a/test/mitmproxy/proxy/test_server.py b/test/mitmproxy/proxy/test_server.py index 8dce9bcd..87ec443a 100644 --- a/test/mitmproxy/proxy/test_server.py +++ b/test/mitmproxy/proxy/test_server.py @@ -143,9 +143,9 @@ class TcpMixin:          # Test that we get the original SSL cert          if self.ssl: -            i_cert = certs.SSLCert(i.sslinfo.certchain[0]) -            i2_cert = certs.SSLCert(i2.sslinfo.certchain[0]) -            n_cert = certs.SSLCert(n.sslinfo.certchain[0]) +            i_cert = certs.Cert(i.sslinfo.certchain[0]) +            i2_cert = certs.Cert(i2.sslinfo.certchain[0]) +            n_cert = certs.Cert(n.sslinfo.certchain[0])              assert i_cert == i2_cert              assert i_cert != n_cert @@ -188,9 +188,9 @@ class TcpMixin:          # Test that we get the original SSL cert          if self.ssl: -            i_cert = certs.SSLCert(i.sslinfo.certchain[0]) -            i2_cert = certs.SSLCert(i2.sslinfo.certchain[0]) -            n_cert = certs.SSLCert(n.sslinfo.certchain[0]) +            i_cert = certs.Cert(i.sslinfo.certchain[0]) +            i2_cert = certs.Cert(i2.sslinfo.certchain[0]) +            n_cert = certs.Cert(n.sslinfo.certchain[0])              assert i_cert == i2_cert              assert i_cert != n_cert @@ -511,6 +511,14 @@ class TestReverse(tservers.ReverseProxyTest, CommonMixin, TcpMixin):          req = self.master.state.flows[0].request          assert req.host_header == "127.0.0.1" +    def test_selfconnection(self): +        self.options.mode = "reverse:http://127.0.0.1:0" + +        p = self.pathoc() +        with p.connect(): +            p.request("get:/") +        assert self.master.has_log("The proxy shall not connect to itself.") +  class TestReverseSSL(tservers.ReverseProxyTest, CommonMixin, TcpMixin):      reverse = True @@ -579,7 +587,7 @@ class TestSocks5SSL(tservers.SocksModeTest):          p = self.pathoc_raw()          with p.connect():              p.socks_connect(("localhost", self.server.port)) -            p.convert_to_ssl() +            p.convert_to_tls()              f = p.request("get:/p/200")          assert f.status_code == 200 @@ -709,7 +717,7 @@ class TestProxy(tservers.HTTPProxyTest):          first_flow = self.master.state.flows[0]          second_flow = self.master.state.flows[1]          assert first_flow.server_conn.timestamp_tcp_setup -        assert first_flow.server_conn.timestamp_ssl_setup is None +        assert first_flow.server_conn.timestamp_tls_setup is None          assert second_flow.server_conn.timestamp_tcp_setup          assert first_flow.server_conn.timestamp_tcp_setup == second_flow.server_conn.timestamp_tcp_setup @@ -723,12 +731,13 @@ class TestProxy(tservers.HTTPProxyTest):  class TestProxySSL(tservers.HTTPProxyTest):      ssl = True -    def test_request_ssl_setup_timestamp_presence(self): +    def test_request_tls_attribute_presence(self):          # tests that the ssl timestamp is present when ssl is used          f = self.pathod("304:b@10k")          assert f.status_code == 304          first_flow = self.master.state.flows[0] -        assert first_flow.server_conn.timestamp_ssl_setup +        assert first_flow.server_conn.timestamp_tls_setup +        assert first_flow.client_conn.tls_extensions      def test_via(self):          # tests that the ssl timestamp is present when ssl is used @@ -1149,7 +1158,7 @@ class AddUpstreamCertsToClientChainMixin:      def test_add_upstream_certs_to_client_chain(self):          with open(self.servercert, "rb") as f:              d = f.read() -        upstreamCert = certs.SSLCert.from_pem(d) +        upstreamCert = certs.Cert.from_pem(d)          p = self.pathoc()          with p.connect():              upstream_cert_found_in_client_chain = False diff --git a/test/mitmproxy/test_certs.py b/test/mitmproxy/test_certs.py index 693bebc6..dcc185c0 100644 --- a/test/mitmproxy/test_certs.py +++ b/test/mitmproxy/test_certs.py @@ -136,18 +136,18 @@ class TestDummyCert:          assert r.altnames == [] -class TestSSLCert: +class TestCert:      def test_simple(self):          with open(tutils.test_data.path("mitmproxy/net/data/text_cert"), "rb") as f:              d = f.read() -        c1 = certs.SSLCert.from_pem(d) +        c1 = certs.Cert.from_pem(d)          assert c1.cn == b"google.com"          assert len(c1.altnames) == 436          with open(tutils.test_data.path("mitmproxy/net/data/text_cert_2"), "rb") as f:              d = f.read() -        c2 = certs.SSLCert.from_pem(d) +        c2 = certs.Cert.from_pem(d)          assert c2.cn == b"www.inode.co.nz"          assert len(c2.altnames) == 2          assert c2.digest("sha1") @@ -165,20 +165,20 @@ class TestSSLCert:      def test_err_broken_sans(self):          with open(tutils.test_data.path("mitmproxy/net/data/text_cert_weird1"), "rb") as f:              d = f.read() -        c = certs.SSLCert.from_pem(d) +        c = certs.Cert.from_pem(d)          # This breaks unless we ignore a decoding error.          assert c.altnames is not None      def test_der(self):          with open(tutils.test_data.path("mitmproxy/net/data/dercert"), "rb") as f:              d = f.read() -        s = certs.SSLCert.from_der(d) +        s = certs.Cert.from_der(d)          assert s.cn      def test_state(self):          with open(tutils.test_data.path("mitmproxy/net/data/text_cert"), "rb") as f:              d = f.read() -        c = certs.SSLCert.from_pem(d) +        c = certs.Cert.from_pem(d)          c.get_state()          c2 = c.copy() @@ -188,6 +188,6 @@ class TestSSLCert:          assert c == c2          assert c is not c2 -        x = certs.SSLCert('') +        x = certs.Cert('')          x.set_state(a)          assert x == c diff --git a/test/mitmproxy/test_connections.py b/test/mitmproxy/test_connections.py index 83f0bd34..9e5d89f1 100644 --- a/test/mitmproxy/test_connections.py +++ b/test/mitmproxy/test_connections.py @@ -41,10 +41,10 @@ class TestClientConnection:      def test_tls_established_property(self):          c = tflow.tclient_conn()          c.tls_established = True -        assert c.ssl_established +        assert c.tls_established          assert c.tls_established          c.tls_established = False -        assert not c.ssl_established +        assert not c.tls_established          assert not c.tls_established      def test_make_dummy(self): @@ -113,10 +113,10 @@ class TestServerConnection:      def test_tls_established_property(self):          c = tflow.tserver_conn()          c.tls_established = True -        assert c.ssl_established +        assert c.tls_established          assert c.tls_established          c.tls_established = False -        assert not c.ssl_established +        assert not c.tls_established          assert not c.tls_established      def test_make_dummy(self): @@ -155,7 +155,7 @@ class TestServerConnection:      def test_sni(self):          c = connections.ServerConnection(('', 1234))          with pytest.raises(ValueError, matches='sni must be str, not '): -            c.establish_ssl(None, b'foobar') +            c.establish_tls(None, b'foobar')      def test_state(self):          c = tflow.tserver_conn() @@ -206,7 +206,7 @@ class TestClientConnectionTLS:          key = OpenSSL.crypto.load_privatekey(              OpenSSL.crypto.FILETYPE_PEM,              raw_key) -        c.convert_to_ssl(cert, key) +        c.convert_to_tls(cert, key)          assert c.connected()          assert c.sni == sni          assert c.tls_established @@ -230,7 +230,7 @@ class TestServerConnectionTLS(tservers.ServerTestBase):      def test_tls(self, clientcert):          c = connections.ServerConnection(("127.0.0.1", self.port))          c.connect() -        c.establish_ssl(clientcert, "foo.com") +        c.establish_tls(clientcert, "foo.com")          assert c.connected()          assert c.sni == "foo.com"          assert c.tls_established diff --git a/test/mitmproxy/test_stateobject.py b/test/mitmproxy/test_stateobject.py index d8c7a8e9..bd5d1792 100644 --- a/test/mitmproxy/test_stateobject.py +++ b/test/mitmproxy/test_stateobject.py @@ -1,101 +1,146 @@ -from typing import List +import typing +  import pytest  from mitmproxy.stateobject import StateObject -class Child(StateObject): +class TObject(StateObject):      def __init__(self, x):          self.x = x -    _stateobject_attributes = dict( -        x=int -    ) -      @classmethod      def from_state(cls, state):          obj = cls(None)          obj.set_state(state)          return obj + +class Child(TObject): +    _stateobject_attributes = dict( +        x=int +    ) +      def __eq__(self, other):          return isinstance(other, Child) and self.x == other.x -class Container(StateObject): -    def __init__(self): -        self.child = None -        self.children = None -        self.dictionary = None +class TTuple(TObject): +    _stateobject_attributes = dict( +        x=typing.Tuple[int, Child] +    ) + + +class TList(TObject): +    _stateobject_attributes = dict( +        x=typing.List[Child] +    ) + +class TDict(TObject):      _stateobject_attributes = dict( -        child=Child, -        children=List[Child], -        dictionary=dict, +        x=typing.Dict[str, Child]      ) -    @classmethod -    def from_state(cls, state): -        obj = cls() -        obj.set_state(state) -        return obj + +class TAny(TObject): +    _stateobject_attributes = dict( +        x=typing.Any +    ) + + +class TSerializableChild(TObject): +    _stateobject_attributes = dict( +        x=Child +    )  def test_simple():      a = Child(42) +    assert a.get_state() == {"x": 42}      b = a.copy() -    assert b.get_state() == {"x": 42}      a.set_state({"x": 44})      assert a.x == 44      assert b.x == 42 -def test_container(): -    a = Container() -    a.child = Child(42) +def test_serializable_child(): +    child = Child(42) +    a = TSerializableChild(child) +    assert a.get_state() == { +        "x": {"x": 42} +    } +    a.set_state({ +        "x": {"x": 43} +    }) +    assert a.x.x == 43 +    assert a.x is child      b = a.copy() -    assert a.child.x == b.child.x -    b.child.x = 44 -    assert a.child.x != b.child.x +    assert a.x == b.x +    assert a.x is not b.x -def test_container_list(): -    a = Container() -    a.children = [Child(42), Child(44)] +def test_tuple(): +    a = TTuple((42, Child(43)))      assert a.get_state() == { -        "child": None, -        "children": [{"x": 42}, {"x": 44}], -        "dictionary": None, +        "x": (42, {"x": 43})      } -    copy = a.copy() -    assert len(copy.children) == 2 -    assert copy.children is not a.children -    assert copy.children[0] is not a.children[0] -    assert Container.from_state(a.get_state()) +    b = a.copy() +    a.set_state({"x": (44, {"x": 45})}) +    assert a.x == (44, Child(45)) +    assert b.x == (42, Child(43)) + +def test_tuple_err(): +    a = TTuple(None) +    with pytest.raises(ValueError, msg="Invalid data"): +        a.set_state({"x": (42,)}) -def test_container_dict(): -    a = Container() -    a.dictionary = dict() -    a.dictionary['foo'] = 'bar' -    a.dictionary['bar'] = Child(44) + +def test_list(): +    a = TList([Child(1), Child(2)])      assert a.get_state() == { -        "child": None, -        "children": None, -        "dictionary": {'bar': {'x': 44}, 'foo': 'bar'}, +        "x": [{"x": 1}, {"x": 2}],      }      copy = a.copy() -    assert len(copy.dictionary) == 2 -    assert copy.dictionary is not a.dictionary -    assert copy.dictionary['bar'] is not a.dictionary['bar'] +    assert len(copy.x) == 2 +    assert copy.x is not a.x +    assert copy.x[0] is not a.x[0] + + +def test_dict(): +    a = TDict({"foo": Child(42)}) +    assert a.get_state() == { +        "x": {"foo": {"x": 42}} +    } +    b = a.copy() +    assert list(a.x.items()) == list(b.x.items()) +    assert a.x is not b.x +    assert a.x["foo"] is not b.x["foo"] + + +def test_any(): +    a = TAny(42) +    b = a.copy() +    assert a.x == b.x + +    a = TAny(object()) +    with pytest.raises(AssertionError): +        a.get_state()  def test_too_much_state(): -    a = Container() -    a.child = Child(42) +    a = Child(42)      s = a.get_state()      s['foo'] = 'bar' -    b = Container()      with pytest.raises(RuntimeWarning): -        b.set_state(s) +        a.set_state(s) + + +def test_none(): +    a = Child(None) +    assert a.get_state() == {"x": None} +    a = Child(42) +    a.set_state({"x": None}) +    assert a.x is None diff --git a/test/mitmproxy/test_version.py b/test/mitmproxy/test_version.py index f8d646dc..8c176542 100644 --- a/test/mitmproxy/test_version.py +++ b/test/mitmproxy/test_version.py @@ -1,3 +1,4 @@ +import pathlib  import runpy  import subprocess  from unittest import mock @@ -6,7 +7,9 @@ from mitmproxy import version  def test_version(capsys): -    runpy.run_module('mitmproxy.version', run_name='__main__') +    here = pathlib.Path(__file__).absolute().parent +    version_file = here / ".." / ".." / "mitmproxy" / "version.py" +    runpy.run_path(str(version_file), run_name='__main__')      stdout, stderr = capsys.readouterr()      assert len(stdout) > 0      assert stdout.strip() == version.VERSION @@ -27,7 +30,7 @@ def test_get_version():          assert version.get_version(True, True) == "3.0.0"          m.return_value = b"tag-2-cafecafe" -        assert version.get_version(True, True) == "3.0.0.dev0002-0xcafecaf" +        assert version.get_version(True, True) == "3.0.0.dev2-0xcafecaf" -        m.side_effect = subprocess.CalledProcessError(-1, 'git describe --tags --long') +        m.side_effect = subprocess.CalledProcessError(-1, 'git describe --long')          assert version.get_version(True, True) == "3.0.0" diff --git a/test/mitmproxy/tools/console/test_common.py b/test/mitmproxy/tools/console/test_common.py index 3ab4fd67..72438c49 100644 --- a/test/mitmproxy/tools/console/test_common.py +++ b/test/mitmproxy/tools/console/test_common.py @@ -1,12 +1,34 @@ +import urwid +  from mitmproxy.test import tflow  from mitmproxy.tools.console import common -from ....conftest import skip_appveyor - -@skip_appveyor  def test_format_flow():      f = tflow.tflow(resp=True)      assert common.format_flow(f, True)      assert common.format_flow(f, True, hostheader=True)      assert common.format_flow(f, True, extended=True) + + +def test_format_keyvals(): +    assert common.format_keyvals( +        [ +            ("aa", "bb"), +            ("cc", "dd"), +            ("ee", None), +        ] +    ) +    wrapped = urwid.BoxAdapter( +        urwid.ListBox( +            urwid.SimpleFocusListWalker( +                common.format_keyvals([("foo", "bar")]) +            ) +        ), 1 +    ) +    assert wrapped.render((30, )) +    assert common.format_keyvals( +        [ +            ("aa", wrapped) +        ] +    ) diff --git a/test/mitmproxy/tools/console/test_master.py b/test/mitmproxy/tools/console/test_master.py index 3aa0dc54..9779a482 100644 --- a/test/mitmproxy/tools/console/test_master.py +++ b/test/mitmproxy/tools/console/test_master.py @@ -4,22 +4,9 @@ from mitmproxy import options  from mitmproxy.test import tflow  from mitmproxy.test import tutils  from mitmproxy.tools import console -from mitmproxy.tools.console import common  from ... import tservers -def test_format_keyvals(): -    assert common.format_keyvals( -        [ -            ("aa", "bb"), -            None, -            ("cc", "dd"), -            (None, "dd"), -            (None, "dd"), -        ] -    ) - -  def test_options():      assert options.Options(replay_kill_extra=True) diff --git a/test/mitmproxy/utils/test_typecheck.py b/test/mitmproxy/utils/test_typecheck.py index 5295fff5..9cb4334e 100644 --- a/test/mitmproxy/utils/test_typecheck.py +++ b/test/mitmproxy/utils/test_typecheck.py @@ -93,3 +93,8 @@ def test_typesec_to_str():      assert(typecheck.typespec_to_str(typing.Optional[str])) == "optional str"      with pytest.raises(NotImplementedError):          typecheck.typespec_to_str(dict) + + +def test_mapping_types(): +    # this is not covered by check_option_type, but still belongs in this module +    assert (str, int) == typecheck.mapping_types(typing.Mapping[str, int]) diff --git a/test/pathod/protocols/test_http2.py b/test/pathod/protocols/test_http2.py index b1eebc73..95965cee 100644 --- a/test/pathod/protocols/test_http2.py +++ b/test/pathod/protocols/test_http2.py @@ -75,7 +75,7 @@ class TestCheckALPNMatch(net_tservers.ServerTestBase):      def test_check_alpn(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl(alpn_protos=[b'h2']) +            c.convert_to_tls(alpn_protos=[b'h2'])              protocol = HTTP2StateProtocol(c)              assert protocol.check_alpn() @@ -89,7 +89,7 @@ class TestCheckALPNMismatch(net_tservers.ServerTestBase):      def test_check_alpn(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl(alpn_protos=[b'h2']) +            c.convert_to_tls(alpn_protos=[b'h2'])              protocol = HTTP2StateProtocol(c)              with pytest.raises(NotImplementedError):                  protocol.check_alpn() @@ -207,7 +207,7 @@ class TestApplySettings(net_tservers.ServerTestBase):      def test_apply_settings(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl() +            c.convert_to_tls()              protocol = HTTP2StateProtocol(c)              protocol._apply_settings({ @@ -302,7 +302,7 @@ class TestReadRequest(net_tservers.ServerTestBase):      def test_read_request(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl() +            c.convert_to_tls()              protocol = HTTP2StateProtocol(c, is_server=True)              protocol.connection_preface_performed = True @@ -328,7 +328,7 @@ class TestReadRequestRelative(net_tservers.ServerTestBase):      def test_asterisk_form(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl() +            c.convert_to_tls()              protocol = HTTP2StateProtocol(c, is_server=True)              protocol.connection_preface_performed = True @@ -351,7 +351,7 @@ class TestReadRequestAbsolute(net_tservers.ServerTestBase):      def test_absolute_form(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl() +            c.convert_to_tls()              protocol = HTTP2StateProtocol(c, is_server=True)              protocol.connection_preface_performed = True @@ -378,7 +378,7 @@ class TestReadResponse(net_tservers.ServerTestBase):      def test_read_response(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl() +            c.convert_to_tls()              protocol = HTTP2StateProtocol(c)              protocol.connection_preface_performed = True @@ -404,7 +404,7 @@ class TestReadEmptyResponse(net_tservers.ServerTestBase):      def test_read_empty_response(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          with c.connect(): -            c.convert_to_ssl() +            c.convert_to_tls()              protocol = HTTP2StateProtocol(c)              protocol.connection_preface_performed = True diff --git a/test/pathod/test_pathoc.py b/test/pathod/test_pathoc.py index 4b50e2a7..297b54d4 100644 --- a/test/pathod/test_pathoc.py +++ b/test/pathod/test_pathoc.py @@ -238,11 +238,11 @@ class TestDaemonHTTP2(PathocTestDaemon):              http2_skip_connection_preface=True,          ) -        tmp_convert_to_ssl = c.convert_to_ssl -        c.convert_to_ssl = Mock() -        c.convert_to_ssl.side_effect = tmp_convert_to_ssl +        tmp_convert_to_tls = c.convert_to_tls +        c.convert_to_tls = Mock() +        c.convert_to_tls.side_effect = tmp_convert_to_tls          with c.connect(): -            _, kwargs = c.convert_to_ssl.call_args +            _, kwargs = c.convert_to_tls.call_args              assert set(kwargs['alpn_protos']) == set([b'http/1.1', b'h2'])      def test_request(self): diff --git a/test/pathod/test_pathod.py b/test/pathod/test_pathod.py index c0011952..d6522cb6 100644 --- a/test/pathod/test_pathod.py +++ b/test/pathod/test_pathod.py @@ -153,7 +153,7 @@ class CommonTests(tservers.DaemonTests):          c = tcp.TCPClient(("localhost", self.d.port))          with c.connect():              if self.ssl: -                c.convert_to_ssl() +                c.convert_to_tls()              c.wfile.write(b"foo\n\n\n")              c.wfile.flush()              l = self.d.last_log() @@ -241,7 +241,7 @@ class TestDaemonSSL(CommonTests):          with c.connect():              c.wfile.write(b"\0\0\0\0")              with pytest.raises(exceptions.TlsException): -                c.convert_to_ssl() +                c.convert_to_tls()              l = self.d.last_log()              assert l["type"] == "error"              assert "SSL" in l["msg"]  | 
