diff options
author | Nikhil Soni <krsoninikhil@gmail.com> | 2017-03-03 12:58:44 +0530 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-03-03 12:58:44 +0530 |
commit | 0081d9b82807b178bd6d00ca250d38aeeeed2d33 (patch) | |
tree | 8d244a7b9ade13f0e1836f8eb9fbc534a1cc662d /test | |
parent | 3da8532bed3305b01e3f3ab556f9dbc652177c6b (diff) | |
parent | bae4cdf8d5cc434938c74a041f762075513dd8e4 (diff) | |
download | mitmproxy-0081d9b82807b178bd6d00ca250d38aeeeed2d33.tar.gz mitmproxy-0081d9b82807b178bd6d00ca250d38aeeeed2d33.tar.bz2 mitmproxy-0081d9b82807b178bd6d00ca250d38aeeeed2d33.zip |
Merge branch 'master' into on-issues
Diffstat (limited to 'test')
-rw-r--r-- | test/mitmproxy/addons/test_dumper.py | 4 | ||||
-rwxr-xr-x | test/mitmproxy/examples/test_xss_scanner.py | 368 | ||||
-rw-r--r-- | test/mitmproxy/net/test_check.py | 1 | ||||
-rw-r--r-- | test/mitmproxy/net/test_socks.py | 5 | ||||
-rw-r--r-- | test/mitmproxy/net/test_tcp.py | 24 | ||||
-rw-r--r-- | test/mitmproxy/net/tservers.py | 4 | ||||
-rw-r--r-- | test/mitmproxy/proxy/protocol/test_http2.py | 93 | ||||
-rw-r--r-- | test/mitmproxy/proxy/protocol/test_websocket.py | 8 | ||||
-rw-r--r-- | test/mitmproxy/proxy/test_server.py | 9 | ||||
-rw-r--r-- | test/mitmproxy/test_certs.py | 27 | ||||
-rw-r--r-- | test/mitmproxy/test_connections.py | 211 | ||||
-rw-r--r-- | test/mitmproxy/test_eventsequence.py | 2 | ||||
-rw-r--r-- | test/mitmproxy/test_flow.py | 247 | ||||
-rw-r--r-- | test/mitmproxy/test_flowfilter.py | 109 | ||||
-rw-r--r-- | test/mitmproxy/test_http.py | 257 | ||||
-rw-r--r-- | test/mitmproxy/test_optmanager.py | 22 | ||||
-rw-r--r-- | test/mitmproxy/test_proxy.py | 47 | ||||
-rw-r--r-- | test/mitmproxy/tservers.py | 5 |
18 files changed, 1022 insertions, 421 deletions
diff --git a/test/mitmproxy/addons/test_dumper.py b/test/mitmproxy/addons/test_dumper.py index 6a66d0c9..22d2c2c6 100644 --- a/test/mitmproxy/addons/test_dumper.py +++ b/test/mitmproxy/addons/test_dumper.py @@ -70,7 +70,7 @@ def test_simple(): flow.request = tutils.treq() flow.request.stickycookie = True flow.client_conn = mock.MagicMock() - flow.client_conn.address.host = "foo" + flow.client_conn.address[0] = "foo" flow.response = tutils.tresp(content=None) flow.response.is_replay = True flow.response.status_code = 300 @@ -176,7 +176,7 @@ def test_websocket(): ctx.configure(d, flow_detail=3, showhost=True) f = tflow.twebsocketflow() d.websocket_message(f) - assert "hello text" in sio.getvalue() + assert "it's me" in sio.getvalue() sio.truncate(0) d.websocket_end(f) diff --git a/test/mitmproxy/examples/test_xss_scanner.py b/test/mitmproxy/examples/test_xss_scanner.py new file mode 100755 index 00000000..14ee6902 --- /dev/null +++ b/test/mitmproxy/examples/test_xss_scanner.py @@ -0,0 +1,368 @@ +import pytest +import requests +from examples.complex import xss_scanner as xss +from mitmproxy.test import tflow, tutils + + +class TestXSSScanner(): + def test_get_XSS_info(self): + # First type of exploit: <script>PAYLOAD</script> + # Exploitable: + xss_info = xss.get_XSS_data(b"<html><script>%s</script><html>" % + xss.FULL_PAYLOAD, + "https://example.com", + "End of URL") + expected_xss_info = xss.XSSData('https://example.com', + "End of URL", + '</script><script>alert(0)</script><script>', + xss.FULL_PAYLOAD.decode('utf-8')) + assert xss_info == expected_xss_info + xss_info = xss.get_XSS_data(b"<html><script>%s</script><html>" % + xss.FULL_PAYLOAD.replace(b"'", b"%27").replace(b'"', b"%22"), + "https://example.com", + "End of URL") + expected_xss_info = xss.XSSData("https://example.com", + "End of URL", + '</script><script>alert(0)</script><script>', + xss.FULL_PAYLOAD.replace(b"'", b"%27").replace(b'"', b"%22").decode('utf-8')) + assert xss_info == expected_xss_info + # Non-Exploitable: + xss_info = xss.get_XSS_data(b"<html><script>%s</script><html>" % + xss.FULL_PAYLOAD.replace(b"'", b"%27").replace(b'"', b"%22").replace(b"/", b"%2F"), + "https://example.com", + "End of URL") + assert xss_info is None + # Second type of exploit: <script>t='PAYLOAD'</script> + # Exploitable: + xss_info = xss.get_XSS_data(b"<html><script>t='%s';</script></html>" % + xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").replace(b"\"", b"%22"), + "https://example.com", + "End of URL") + expected_xss_info = xss.XSSData("https://example.com", + "End of URL", + "';alert(0);g='", + xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") + .replace(b"\"", b"%22").decode('utf-8')) + assert xss_info == expected_xss_info + # Non-Exploitable: + xss_info = xss.get_XSS_data(b"<html><script>t='%s';</script></html>" % + xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b"\"", b"%22").replace(b"'", b"%22"), + "https://example.com", + "End of URL") + assert xss_info is None + # Third type of exploit: <script>t="PAYLOAD"</script> + # Exploitable: + xss_info = xss.get_XSS_data(b"<html><script>t=\"%s\";</script></html>" % + xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").replace(b"'", b"%27"), + "https://example.com", + "End of URL") + expected_xss_info = xss.XSSData("https://example.com", + "End of URL", + '";alert(0);g="', + xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") + .replace(b"'", b"%27").decode('utf-8')) + assert xss_info == expected_xss_info + # Non-Exploitable: + xss_info = xss.get_XSS_data(b"<html><script>t=\"%s\";</script></html>" % + xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b"'", b"%27").replace(b"\"", b"%22"), + "https://example.com", + "End of URL") + assert xss_info is None + # Fourth type of exploit: <a href='PAYLOAD'>Test</a> + # Exploitable: + xss_info = xss.get_XSS_data(b"<html><a href='%s'>Test</a></html>" % + xss.FULL_PAYLOAD, + "https://example.com", + "End of URL") + expected_xss_info = xss.XSSData("https://example.com", + "End of URL", + "'><script>alert(0)</script>", + xss.FULL_PAYLOAD.decode('utf-8')) + assert xss_info == expected_xss_info + # Non-Exploitable: + xss_info = xss.get_XSS_data(b"<html><a href='OtherStuff%s'>Test</a></html>" % + xss.FULL_PAYLOAD.replace(b"'", b"%27"), + "https://example.com", + "End of URL") + assert xss_info is None + # Fifth type of exploit: <a href="PAYLOAD">Test</a> + # Exploitable: + xss_info = xss.get_XSS_data(b"<html><a href=\"%s\">Test</a></html>" % + xss.FULL_PAYLOAD.replace(b"'", b"%27"), + "https://example.com", + "End of URL") + expected_xss_info = xss.XSSData("https://example.com", + "End of URL", + "\"><script>alert(0)</script>", + xss.FULL_PAYLOAD.replace(b"'", b"%27").decode('utf-8')) + assert xss_info == expected_xss_info + # Non-Exploitable: + xss_info = xss.get_XSS_data(b"<html><a href=\"OtherStuff%s\">Test</a></html>" % + xss.FULL_PAYLOAD.replace(b"'", b"%27").replace(b"\"", b"%22"), + "https://example.com", + "End of URL") + assert xss_info is None + # Sixth type of exploit: <a href=PAYLOAD>Test</a> + # Exploitable: + xss_info = xss.get_XSS_data(b"<html><a href=%s>Test</a></html>" % + xss.FULL_PAYLOAD, + "https://example.com", + "End of URL") + expected_xss_info = xss.XSSData("https://example.com", + "End of URL", + "><script>alert(0)</script>", + xss.FULL_PAYLOAD.decode('utf-8')) + assert xss_info == expected_xss_info + # Non-Exploitable + xss_info = xss.get_XSS_data(b"<html><a href=OtherStuff%s>Test</a></html>" % + xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") + .replace(b"=", b"%3D"), + "https://example.com", + "End of URL") + assert xss_info is None + # Seventh type of exploit: <html>PAYLOAD</html> + # Exploitable: + xss_info = xss.get_XSS_data(b"<html><b>%s</b></html>" % + xss.FULL_PAYLOAD, + "https://example.com", + "End of URL") + expected_xss_info = xss.XSSData("https://example.com", + "End of URL", + "<script>alert(0)</script>", + xss.FULL_PAYLOAD.decode('utf-8')) + assert xss_info == expected_xss_info + # Non-Exploitable + xss_info = xss.get_XSS_data(b"<html><b>%s</b></html>" % + xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").replace(b"/", b"%2F"), + "https://example.com", + "End of URL") + assert xss_info is None + # Eighth type of exploit: <a href=PAYLOAD>Test</a> + # Exploitable: + xss_info = xss.get_XSS_data(b"<html><a href=%s>Test</a></html>" % + xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E"), + "https://example.com", + "End of URL") + expected_xss_info = xss.XSSData("https://example.com", + "End of URL", + "Javascript:alert(0)", + xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").decode('utf-8')) + assert xss_info == expected_xss_info + # Non-Exploitable: + xss_info = xss.get_XSS_data(b"<html><a href=OtherStuff%s>Test</a></html>" % + xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") + .replace(b"=", b"%3D"), + "https://example.com", + "End of URL") + assert xss_info is None + # Ninth type of exploit: <a href="STUFF PAYLOAD">Test</a> + # Exploitable: + xss_info = xss.get_XSS_data(b"<html><a href=\"STUFF %s\">Test</a></html>" % + xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E"), + "https://example.com", + "End of URL") + expected_xss_info = xss.XSSData("https://example.com", + "End of URL", + '" onmouseover="alert(0)" t="', + xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").decode('utf-8')) + assert xss_info == expected_xss_info + # Non-Exploitable: + xss_info = xss.get_XSS_data(b"<html><a href=\"STUFF %s\">Test</a></html>" % + xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") + .replace(b'"', b"%22"), + "https://example.com", + "End of URL") + assert xss_info is None + # Tenth type of exploit: <a href='STUFF PAYLOAD'>Test</a> + # Exploitable: + xss_info = xss.get_XSS_data(b"<html><a href='STUFF %s'>Test</a></html>" % + xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E"), + "https://example.com", + "End of URL") + expected_xss_info = xss.XSSData("https://example.com", + "End of URL", + "' onmouseover='alert(0)' t='", + xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").decode('utf-8')) + assert xss_info == expected_xss_info + # Non-Exploitable: + xss_info = xss.get_XSS_data(b"<html><a href='STUFF %s'>Test</a></html>" % + xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") + .replace(b"'", b"%22"), + "https://example.com", + "End of URL") + assert xss_info is None + # Eleventh type of exploit: <a href=STUFF_PAYLOAD>Test</a> + # Exploitable: + xss_info = xss.get_XSS_data(b"<html><a href=STUFF%s>Test</a></html>" % + xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E"), + "https://example.com", + "End of URL") + expected_xss_info = xss.XSSData("https://example.com", + "End of URL", + " onmouseover=alert(0) t=", + xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").decode('utf-8')) + assert xss_info == expected_xss_info + # Non-Exploitable: + xss_info = xss.get_XSS_data(b"<html><a href=STUFF_%s>Test</a></html>" % + xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") + .replace(b"=", b"%3D"), + "https://example.com", + "End of URL") + assert xss_info is None + + def test_get_SQLi_data(self): + sqli_data = xss.get_SQLi_data("<html>SQL syntax MySQL</html>", + "<html></html>", + "https://example.com", + "End of URL") + expected_sqli_data = xss.SQLiData("https://example.com", + "End of URL", + "SQL syntax.*MySQL", + "MySQL") + assert sqli_data == expected_sqli_data + sqli_data = xss.get_SQLi_data("<html>SQL syntax MySQL</html>", + "<html>SQL syntax MySQL</html>", + "https://example.com", + "End of URL") + assert sqli_data is None + + def test_inside_quote(self): + assert not xss.inside_quote("'", b"no", 0, b"no") + assert xss.inside_quote("'", b"yes", 0, b"'yes'") + assert xss.inside_quote("'", b"yes", 1, b"'yes'otherJunk'yes'more") + assert not xss.inside_quote("'", b"longStringNotInIt", 1, b"short") + + def test_paths_to_text(self): + text = xss.paths_to_text("""<html><head><h1>STRING</h1></head> + <script>STRING</script> + <a href=STRING></a></html>""", "STRING") + expected_text = ["/html/head/h1", "/html/script"] + assert text == expected_text + assert xss.paths_to_text("""<html></html>""", "STRING") == [] + + def mocked_requests_vuln(*args, headers=None, cookies=None): + class MockResponse: + def __init__(self, html, headers=None, cookies=None): + self.text = html + return MockResponse("<html>%s</html>" % xss.FULL_PAYLOAD) + + def mocked_requests_invuln(*args, headers=None, cookies=None): + class MockResponse: + def __init__(self, html, headers=None, cookies=None): + self.text = html + return MockResponse("<html></html>") + + def test_test_end_of_url_injection(self, monkeypatch): + monkeypatch.setattr(requests, 'get', self.mocked_requests_vuln) + xss_info = xss.test_end_of_URL_injection("<html></html>", "https://example.com/index.html", {})[0] + expected_xss_info = xss.XSSData('https://example.com/index.html/1029zxcs\'d"ao<ac>so[sb]po(pc)se;sl/bsl\\eq=3847asd', + 'End of URL', + '<script>alert(0)</script>', + '1029zxcs\\\'d"ao<ac>so[sb]po(pc)se;sl/bsl\\\\eq=3847asd') + sqli_info = xss.test_end_of_URL_injection("<html></html>", "https://example.com/", {})[1] + assert xss_info == expected_xss_info + assert sqli_info is None + + def test_test_referer_injection(self, monkeypatch): + monkeypatch.setattr(requests, 'get', self.mocked_requests_vuln) + xss_info = xss.test_referer_injection("<html></html>", "https://example.com/", {})[0] + expected_xss_info = xss.XSSData('https://example.com/', + 'Referer', + '<script>alert(0)</script>', + '1029zxcs\\\'d"ao<ac>so[sb]po(pc)se;sl/bsl\\\\eq=3847asd') + sqli_info = xss.test_referer_injection("<html></html>", "https://example.com/", {})[1] + assert xss_info == expected_xss_info + assert sqli_info is None + + def test_test_user_agent_injection(self, monkeypatch): + monkeypatch.setattr(requests, 'get', self.mocked_requests_vuln) + xss_info = xss.test_user_agent_injection("<html></html>", "https://example.com/", {})[0] + expected_xss_info = xss.XSSData('https://example.com/', + 'User Agent', + '<script>alert(0)</script>', + '1029zxcs\\\'d"ao<ac>so[sb]po(pc)se;sl/bsl\\\\eq=3847asd') + sqli_info = xss.test_user_agent_injection("<html></html>", "https://example.com/", {})[1] + assert xss_info == expected_xss_info + assert sqli_info is None + + def test_test_query_injection(self, monkeypatch): + monkeypatch.setattr(requests, 'get', self.mocked_requests_vuln) + xss_info = xss.test_query_injection("<html></html>", "https://example.com/vuln.php?cmd=ls", {})[0] + expected_xss_info = xss.XSSData('https://example.com/vuln.php?cmd=1029zxcs\'d"ao<ac>so[sb]po(pc)se;sl/bsl\\eq=3847asd', + 'Query', + '<script>alert(0)</script>', + '1029zxcs\\\'d"ao<ac>so[sb]po(pc)se;sl/bsl\\\\eq=3847asd') + sqli_info = xss.test_query_injection("<html></html>", "https://example.com/vuln.php?cmd=ls", {})[1] + assert xss_info == expected_xss_info + assert sqli_info is None + + @pytest.fixture + def logger(self): + class Logger(): + def __init__(self): + self.args = [] + + def error(self, str): + self.args.append(str) + return Logger() + + def test_find_unclaimed_URLs(self, monkeypatch, logger): + logger.args = [] + monkeypatch.setattr("mitmproxy.ctx.log", logger) + xss.find_unclaimed_URLs("<html><script src=\"http://google.com\"></script></html>", + "https://example.com") + assert logger.args == [] + xss.find_unclaimed_URLs("<html><script src=\"http://unclaimedDomainName.com\"></script></html>", + "https://example.com") + assert logger.args[0] == 'XSS found in https://example.com due to unclaimed URL "http://unclaimedDomainName.com" in script tag.' + + def test_log_XSS_data(self, monkeypatch, logger): + logger.args = [] + monkeypatch.setattr("mitmproxy.ctx.log", logger) + xss.log_XSS_data(None) + assert logger.args == [] + # self, url: str, injection_point: str, exploit: str, line: str + xss.log_XSS_data(xss.XSSData('https://example.com', + 'Location', + 'String', + 'Line of HTML')) + assert logger.args[0] == '===== XSS Found ====' + assert logger.args[1] == 'XSS URL: https://example.com' + assert logger.args[2] == 'Injection Point: Location' + assert logger.args[3] == 'Suggested Exploit: String' + assert logger.args[4] == 'Line: Line of HTML' + + def test_log_SQLi_data(self, monkeypatch, logger): + logger.args = [] + monkeypatch.setattr("mitmproxy.ctx.log", logger) + xss.log_SQLi_data(None) + assert logger.args == [] + xss.log_SQLi_data(xss.SQLiData(b'https://example.com', + b'Location', + b'Oracle.*Driver', + b'Oracle')) + assert logger.args[0] == '===== SQLi Found =====' + assert logger.args[1] == 'SQLi URL: https://example.com' + assert logger.args[2] == 'Injection Point: Location' + assert logger.args[3] == 'Regex used: Oracle.*Driver' + + def test_get_cookies(self): + mocked_req = tutils.treq() + mocked_req.cookies = [("cookieName2", "cookieValue2")] + mocked_flow = tflow.tflow(req=mocked_req) + # It only uses the request cookies + assert xss.get_cookies(mocked_flow) == {"cookieName2": "cookieValue2"} + + def test_response(self, monkeypatch, logger): + logger.args = [] + monkeypatch.setattr("mitmproxy.ctx.log", logger) + monkeypatch.setattr(requests, 'get', self.mocked_requests_invuln) + mocked_flow = tflow.tflow(req=tutils.treq(path=b"index.html?q=1"), resp=tutils.tresp(content=b'<html></html>')) + xss.response(mocked_flow) + assert logger.args == [] + + def test_data_equals(self): + xssData = xss.XSSData("a", "b", "c", "d") + sqliData = xss.SQLiData("a", "b", "c", "d") + assert xssData == xssData + assert sqliData == sqliData diff --git a/test/mitmproxy/net/test_check.py b/test/mitmproxy/net/test_check.py index 9dbc02e0..0ffd6b2e 100644 --- a/test/mitmproxy/net/test_check.py +++ b/test/mitmproxy/net/test_check.py @@ -11,3 +11,4 @@ def test_is_valid_host(): assert check.is_valid_host(b"one.two.") # Allow underscore assert check.is_valid_host(b"one_two") + assert check.is_valid_host(b"::1") diff --git a/test/mitmproxy/net/test_socks.py b/test/mitmproxy/net/test_socks.py index e00dd410..fbd31ef4 100644 --- a/test/mitmproxy/net/test_socks.py +++ b/test/mitmproxy/net/test_socks.py @@ -3,7 +3,6 @@ from io import BytesIO import pytest from mitmproxy.net import socks -from mitmproxy.net import tcp from mitmproxy.test import tutils @@ -176,7 +175,7 @@ def test_message_ipv6(): msg.to_file(out) assert out.getvalue() == raw.getvalue()[:-2] - assert msg.addr.host == ipv6_addr + assert msg.addr[0] == ipv6_addr def test_message_invalid_host(): @@ -196,6 +195,6 @@ def test_message_unknown_atyp(): with pytest.raises(socks.SocksError): socks.Message.from_file(raw) - m = socks.Message(5, 1, 0x02, tcp.Address(("example.com", 5050))) + m = socks.Message(5, 1, 0x02, ("example.com", 5050)) with pytest.raises(socks.SocksError): m.to_file(BytesIO()) diff --git a/test/mitmproxy/net/test_tcp.py b/test/mitmproxy/net/test_tcp.py index ff6362c8..cf010f6e 100644 --- a/test/mitmproxy/net/test_tcp.py +++ b/test/mitmproxy/net/test_tcp.py @@ -116,11 +116,11 @@ class TestServerBind(tservers.ServerTestBase): class TestServerIPv6(tservers.ServerTestBase): handler = EchoHandler - addr = tcp.Address(("localhost", 0), use_ipv6=True) + addr = ("::1", 0) def test_echo(self): testval = b"echo!\n" - c = tcp.TCPClient(tcp.Address(("::1", self.port), use_ipv6=True)) + c = tcp.TCPClient(("::1", self.port)) with c.connect(): c.wfile.write(testval) c.wfile.flush() @@ -132,7 +132,7 @@ class TestEcho(tservers.ServerTestBase): def test_echo(self): testval = b"echo!\n" - c = tcp.TCPClient(("127.0.0.1", self.port)) + c = tcp.TCPClient(("localhost", self.port)) with c.connect(): c.wfile.write(testval) c.wfile.flush() @@ -602,12 +602,6 @@ class TestDHParams(tservers.ServerTestBase): ret = c.get_current_cipher() assert ret[0] == "DHE-RSA-AES256-SHA" - def test_create_dhparams(self): - with tutils.tmpdir() as d: - filename = os.path.join(d, "dhparam.pem") - certs.CertStore.load_dhparam(filename) - assert os.path.exists(filename) - class TestTCPClient: @@ -783,18 +777,6 @@ class TestPeekSSL(TestPeek): return conn.pop() -class TestAddress: - def test_simple(self): - a = tcp.Address(("localhost", 80), True) - assert a.use_ipv6 - b = tcp.Address(("foo.com", 80), True) - assert not a == b - c = tcp.Address(("localhost", 80), True) - assert a == c - assert not a != c - assert repr(a) == "localhost:80" - - class TestSSLKeyLogger(tservers.ServerTestBase): handler = EchoHandler ssl = dict( diff --git a/test/mitmproxy/net/tservers.py b/test/mitmproxy/net/tservers.py index 68a2caa0..ebe6d3eb 100644 --- a/test/mitmproxy/net/tservers.py +++ b/test/mitmproxy/net/tservers.py @@ -86,13 +86,13 @@ class _TServer(tcp.TCPServer): class ServerTestBase: ssl = None handler = None - addr = ("localhost", 0) + addr = ("127.0.0.1", 0) @classmethod def setup_class(cls, **kwargs): cls.q = queue.Queue() s = cls.makeserver(**kwargs) - cls.port = s.address.port + cls.port = s.address[1] cls.server = _ServerThread(s) cls.server.start() diff --git a/test/mitmproxy/proxy/protocol/test_http2.py b/test/mitmproxy/proxy/protocol/test_http2.py index cb9c0474..871d02fe 100644 --- a/test/mitmproxy/proxy/protocol/test_http2.py +++ b/test/mitmproxy/proxy/protocol/test_http2.py @@ -124,10 +124,10 @@ class _Http2TestBase: b'CONNECT', b'', b'localhost', - self.server.server.address.port, + self.server.server.address[1], b'/', b'HTTP/1.1', - [(b'host', b'localhost:%d' % self.server.server.address.port)], + [(b'host', b'localhost:%d' % self.server.server.address[1])], b'', ))) client.wfile.flush() @@ -231,7 +231,7 @@ class TestSimple(_Http2Test): client.wfile, h2_conn, headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -272,75 +272,6 @@ class TestSimple(_Http2Test): @requires_alpn -class TestForbiddenHeaders(_Http2Test): - - @classmethod - def handle_server_event(cls, event, h2_conn, rfile, wfile): - if isinstance(event, h2.events.ConnectionTerminated): - return False - elif isinstance(event, h2.events.StreamEnded): - import warnings - with warnings.catch_warnings(): - # Ignore UnicodeWarning: - # h2/utilities.py:64: UnicodeWarning: Unicode equal comparison - # failed to convert both arguments to Unicode - interpreting - # them as being unequal. - # elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: - - warnings.simplefilter("ignore") - - h2_conn.config.validate_outbound_headers = False - h2_conn.send_headers(event.stream_id, [ - (':status', '200'), - ('keep-alive', 'foobar'), - ]) - h2_conn.send_data(event.stream_id, b'response body') - h2_conn.end_stream(event.stream_id) - wfile.write(h2_conn.data_to_send()) - wfile.flush() - return True - - def test_forbidden_headers(self): - client, h2_conn = self._setup_connection() - - self._send_request( - client.wfile, - h2_conn, - headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), - (':method', 'GET'), - (':scheme', 'https'), - (':path', '/'), - ]) - - done = False - while not done: - try: - raw = b''.join(http2.read_raw_frame(client.rfile)) - events = h2_conn.receive_data(raw) - except exceptions.HttpException: - print(traceback.format_exc()) - assert False - - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - for event in events: - if isinstance(event, h2.events.ResponseReceived): - assert 'keep-alive' not in event.headers - elif isinstance(event, h2.events.StreamEnded): - done = True - - h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() - - assert len(self.master.state.flows) == 1 - assert self.master.state.flows[0].response.status_code == 200 - assert self.master.state.flows[0].response.headers['keep-alive'] == 'foobar' - - -@requires_alpn class TestRequestWithPriority(_Http2Test): @classmethod @@ -384,7 +315,7 @@ class TestRequestWithPriority(_Http2Test): client.wfile, h2_conn, headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -469,7 +400,7 @@ class TestPriority(_Http2Test): client.wfile, h2_conn, headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -527,7 +458,7 @@ class TestStreamResetFromServer(_Http2Test): client.wfile, h2_conn, headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -576,7 +507,7 @@ class TestBodySizeLimit(_Http2Test): client.wfile, h2_conn, headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -672,7 +603,7 @@ class TestPushPromise(_Http2Test): client, h2_conn = self._setup_connection() self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -728,7 +659,7 @@ class TestPushPromise(_Http2Test): client, h2_conn = self._setup_connection() self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -791,7 +722,7 @@ class TestConnectionLost(_Http2Test): client, h2_conn = self._setup_connection() self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -848,7 +779,7 @@ class TestMaxConcurrentStreams(_Http2Test): # this will exceed MAX_CONCURRENT_STREAMS on the server connection # and cause mitmproxy to throttle stream creation to the server self._send_request(client.wfile, h2_conn, stream_id=stream_id, headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), @@ -894,7 +825,7 @@ class TestConnectionTerminated(_Http2Test): client, h2_conn = self._setup_connection() self._send_request(client.wfile, h2_conn, headers=[ - (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), + (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py index 4ea01d34..bac0e527 100644 --- a/test/mitmproxy/proxy/protocol/test_websocket.py +++ b/test/mitmproxy/proxy/protocol/test_websocket.py @@ -87,8 +87,8 @@ class _WebSocketTestBase: "authority", "CONNECT", "", - "localhost", - self.server.server.address.port, + "127.0.0.1", + self.server.server.address[1], "", "HTTP/1.1", content=b'') @@ -105,8 +105,8 @@ class _WebSocketTestBase: "relative", "GET", "http", - "localhost", - self.server.server.address.port, + "127.0.0.1", + self.server.server.address[1], "/ws", "HTTP/1.1", headers=http.Headers( diff --git a/test/mitmproxy/proxy/test_server.py b/test/mitmproxy/proxy/test_server.py index 46beea41..56b09b9a 100644 --- a/test/mitmproxy/proxy/test_server.py +++ b/test/mitmproxy/proxy/test_server.py @@ -17,7 +17,6 @@ from mitmproxy.net import socks from mitmproxy import certs from mitmproxy import exceptions from mitmproxy.net.http import http1 -from mitmproxy.net.tcp import Address from pathod import pathoc from pathod import pathod @@ -581,7 +580,7 @@ class TestHttps2Http(tservers.ReverseProxyTest): def get_options(cls): opts = super().get_options() s = parse_server_spec(opts.upstream_server) - opts.upstream_server = "http://%s" % s.address + opts.upstream_server = "http://{}:{}".format(s.address[0], s.address[1]) return opts def pathoc(self, ssl, sni=None): @@ -740,7 +739,7 @@ class MasterRedirectRequest(tservers.TestMaster): # This part should have no impact, but it should also not cause any exceptions. addr = f.live.server_conn.address - addr2 = Address(("127.0.0.1", self.redirect_port)) + addr2 = ("127.0.0.1", self.redirect_port) f.live.set_server(addr2) f.live.set_server(addr) @@ -750,8 +749,8 @@ class MasterRedirectRequest(tservers.TestMaster): @controller.handler def response(self, f): - f.response.content = bytes(f.client_conn.address.port) - f.response.headers["server-conn-id"] = str(f.server_conn.source_address.port) + f.response.content = bytes(f.client_conn.address[1]) + f.response.headers["server-conn-id"] = str(f.server_conn.source_address[1]) super().response(f) diff --git a/test/mitmproxy/test_certs.py b/test/mitmproxy/test_certs.py index f1eff9ba..9bd3ad25 100644 --- a/test/mitmproxy/test_certs.py +++ b/test/mitmproxy/test_certs.py @@ -117,6 +117,12 @@ class TestCertStore: ret = ca1.get_cert(b"foo.com", []) assert ret[0].serial == dc[0].serial + def test_create_dhparams(self): + with tutils.tmpdir() as d: + filename = os.path.join(d, "dhparam.pem") + certs.CertStore.load_dhparam(filename) + assert os.path.exists(filename) + class TestDummyCert: @@ -127,9 +133,10 @@ class TestDummyCert: ca.default_privatekey, ca.default_ca, b"foo.com", - [b"one.com", b"two.com", b"*.three.com"] + [b"one.com", b"two.com", b"*.three.com", b"127.0.0.1"] ) assert r.cn == b"foo.com" + assert r.altnames == [b'one.com', b'two.com', b'*.three.com'] r = certs.dummy_cert( ca.default_privatekey, @@ -138,6 +145,7 @@ class TestDummyCert: [] ) assert r.cn is None + assert r.altnames == [] class TestSSLCert: @@ -179,3 +187,20 @@ class TestSSLCert: d = f.read() s = certs.SSLCert.from_der(d) assert s.cn + + def test_state(self): + with open(tutils.test_data.path("mitmproxy/net/data/text_cert"), "rb") as f: + d = f.read() + c = certs.SSLCert.from_pem(d) + + c.get_state() + c2 = c.copy() + a = c.get_state() + b = c2.get_state() + assert a == b + assert c == c2 + assert c is not c2 + + x = certs.SSLCert('') + x.set_state(a) + assert x == c diff --git a/test/mitmproxy/test_connections.py b/test/mitmproxy/test_connections.py index 777ab4dd..0083f57c 100644 --- a/test/mitmproxy/test_connections.py +++ b/test/mitmproxy/test_connections.py @@ -1 +1,210 @@ -# TODO: write tests +import socket +import os +import threading +import ssl +import OpenSSL +import pytest +from unittest import mock + +from mitmproxy import connections +from mitmproxy import exceptions +from mitmproxy.net import tcp +from mitmproxy.net.http import http1 +from mitmproxy.test import tflow +from mitmproxy.test import tutils +from .net import tservers +from pathod import test + + +class TestClientConnection: + + def test_send(self): + c = tflow.tclient_conn() + c.send(b'foobar') + c.send([b'foo', b'bar']) + with pytest.raises(TypeError): + c.send('string') + with pytest.raises(TypeError): + c.send(['string', 'not']) + assert c.wfile.getvalue() == b'foobarfoobar' + + def test_repr(self): + c = tflow.tclient_conn() + assert 'address:22' in repr(c) + assert 'ALPN' in repr(c) + assert 'TLS' not in repr(c) + + c.alpn_proto_negotiated = None + c.tls_established = True + assert 'ALPN' not in repr(c) + assert 'TLS' in repr(c) + + def test_tls_established_property(self): + c = tflow.tclient_conn() + c.tls_established = True + assert c.ssl_established + assert c.tls_established + c.tls_established = False + assert not c.ssl_established + assert not c.tls_established + + def test_make_dummy(self): + c = connections.ClientConnection.make_dummy(('foobar', 1234)) + assert c.address == ('foobar', 1234) + + def test_state(self): + c = tflow.tclient_conn() + assert connections.ClientConnection.from_state(c.get_state()).get_state() == \ + c.get_state() + + c2 = tflow.tclient_conn() + c2.address = (c2.address[0], 4242) + assert not c == c2 + + c2.timestamp_start = 42 + c.set_state(c2.get_state()) + assert c.timestamp_start == 42 + + c3 = c.copy() + assert c3.get_state() == c.get_state() + + +class TestServerConnection: + + def test_send(self): + c = tflow.tserver_conn() + c.send(b'foobar') + c.send([b'foo', b'bar']) + with pytest.raises(TypeError): + c.send('string') + with pytest.raises(TypeError): + c.send(['string', 'not']) + assert c.wfile.getvalue() == b'foobarfoobar' + + def test_repr(self): + c = tflow.tserver_conn() + + c.sni = 'foobar' + c.tls_established = True + c.alpn_proto_negotiated = b'h2' + assert 'address:22' in repr(c) + assert 'ALPN' in repr(c) + assert 'TLS: foobar' in repr(c) + + c.sni = None + c.tls_established = True + c.alpn_proto_negotiated = None + assert 'ALPN' not in repr(c) + assert 'TLS' in repr(c) + + c.sni = None + c.tls_established = False + assert 'TLS' not in repr(c) + + def test_tls_established_property(self): + c = tflow.tserver_conn() + c.tls_established = True + assert c.ssl_established + assert c.tls_established + c.tls_established = False + assert not c.ssl_established + assert not c.tls_established + + def test_make_dummy(self): + c = connections.ServerConnection.make_dummy(('foobar', 1234)) + assert c.address == ('foobar', 1234) + + def test_simple(self): + d = test.Daemon() + c = connections.ServerConnection((d.IFACE, d.port)) + c.connect() + f = tflow.tflow() + f.server_conn = c + f.request.path = "/p/200:da" + + # use this protocol just to assemble - not for actual sending + c.wfile.write(http1.assemble_request(f.request)) + c.wfile.flush() + + assert http1.read_response(c.rfile, f.request, 1000) + assert d.last_log() + + c.finish() + d.shutdown() + + def test_terminate_error(self): + d = test.Daemon() + c = connections.ServerConnection((d.IFACE, d.port)) + c.connect() + c.connection = mock.Mock() + c.connection.recv = mock.Mock(return_value=False) + c.connection.flush = mock.Mock(side_effect=exceptions.TcpDisconnect) + c.finish() + d.shutdown() + + def test_sni(self): + c = connections.ServerConnection(('', 1234)) + with pytest.raises(ValueError, matches='sni must be str, not '): + c.establish_ssl(None, b'foobar') + + +class TestClientConnectionTLS: + + @pytest.mark.parametrize("sni", [ + None, + "example.com" + ]) + def test_tls_with_sni(self, sni): + address = ('127.0.0.1', 0) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + sock.listen() + address = sock.getsockname() + + def client_run(): + ctx = ssl.create_default_context() + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + s = socket.create_connection(address) + s = ctx.wrap_socket(s, server_hostname=sni) + s.send(b'foobar') + s.shutdown(socket.SHUT_RDWR) + threading.Thread(target=client_run).start() + + connection, client_address = sock.accept() + c = connections.ClientConnection(connection, client_address, None) + + cert = tutils.test_data.path("mitmproxy/net/data/server.crt") + key = OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, + open(tutils.test_data.path("mitmproxy/net/data/server.key"), "rb").read()) + c.convert_to_ssl(cert, key) + assert c.connected() + assert c.sni == sni + assert c.tls_established + assert c.rfile.read(6) == b'foobar' + c.finish() + + +class TestServerConnectionTLS(tservers.ServerTestBase): + ssl = True + + class handler(tcp.BaseHandler): + def handle(self): + self.finish() + + @pytest.mark.parametrize("clientcert", [ + None, + tutils.test_data.path("mitmproxy/data/clientcert"), + os.path.join(tutils.test_data.path("mitmproxy/data/clientcert"), "client.pem"), + ]) + def test_tls(self, clientcert): + c = connections.ServerConnection(("127.0.0.1", self.port)) + c.connect() + c.establish_ssl(clientcert, "foo.com") + assert c.connected() + assert c.sni == "foo.com" + assert c.tls_established + c.close() + c.finish() diff --git a/test/mitmproxy/test_eventsequence.py b/test/mitmproxy/test_eventsequence.py index fe0f92b3..871d4b9d 100644 --- a/test/mitmproxy/test_eventsequence.py +++ b/test/mitmproxy/test_eventsequence.py @@ -32,6 +32,8 @@ def test_websocket_flow(err): assert len(f.messages) == 1 assert next(i) == ("websocket_message", f) assert len(f.messages) == 2 + assert next(i) == ("websocket_message", f) + assert len(f.messages) == 3 if err: assert next(i) == ("websocket_error", f) assert next(i) == ("websocket_end", f) diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index a78e5f80..0ac3bfd6 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -2,160 +2,18 @@ import io import pytest from mitmproxy.test import tflow -from mitmproxy.net.http import Headers import mitmproxy.io from mitmproxy import flowfilter, options from mitmproxy.contrib import tnetstring -from mitmproxy.exceptions import FlowReadException, Kill +from mitmproxy.exceptions import FlowReadException from mitmproxy import flow from mitmproxy import http -from mitmproxy import connections from mitmproxy.proxy import ProxyConfig from mitmproxy.proxy.server import DummyServer from mitmproxy import master from . import tservers -class TestHTTPFlow: - - def test_copy(self): - f = tflow.tflow(resp=True) - f.get_state() - f2 = f.copy() - a = f.get_state() - b = f2.get_state() - del a["id"] - del b["id"] - assert a == b - assert not f == f2 - assert f is not f2 - assert f.request.get_state() == f2.request.get_state() - assert f.request is not f2.request - assert f.request.headers == f2.request.headers - assert f.request.headers is not f2.request.headers - assert f.response.get_state() == f2.response.get_state() - assert f.response is not f2.response - - f = tflow.tflow(err=True) - f2 = f.copy() - assert f is not f2 - assert f.request is not f2.request - assert f.request.headers == f2.request.headers - assert f.request.headers is not f2.request.headers - assert f.error.get_state() == f2.error.get_state() - assert f.error is not f2.error - - def test_match(self): - f = tflow.tflow(resp=True) - assert not flowfilter.match("~b test", f) - assert flowfilter.match(None, f) - assert not flowfilter.match("~b test", f) - - f = tflow.tflow(err=True) - assert flowfilter.match("~e", f) - - with pytest.raises(ValueError): - flowfilter.match("~", f) - - def test_backup(self): - f = tflow.tflow() - f.response = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp()) - f.request.content = b"foo" - assert not f.modified() - f.backup() - f.request.content = b"bar" - assert f.modified() - f.revert() - assert f.request.content == b"foo" - - def test_backup_idempotence(self): - f = tflow.tflow(resp=True) - f.backup() - f.revert() - f.backup() - f.revert() - - def test_getset_state(self): - f = tflow.tflow(resp=True) - state = f.get_state() - assert f.get_state() == http.HTTPFlow.from_state( - state).get_state() - - f.response = None - f.error = flow.Error("error") - state = f.get_state() - assert f.get_state() == http.HTTPFlow.from_state( - state).get_state() - - f2 = f.copy() - f2.id = f.id # copy creates a different uuid - assert f.get_state() == f2.get_state() - assert not f == f2 - f2.error = flow.Error("e2") - assert not f == f2 - f.set_state(f2.get_state()) - assert f.get_state() == f2.get_state() - - def test_kill(self): - f = tflow.tflow() - f.reply.handle() - f.intercept() - assert f.killable - f.kill() - assert not f.killable - assert f.reply.value == Kill - - def test_resume(self): - f = tflow.tflow() - f.reply.handle() - f.intercept() - assert f.reply.state == "taken" - f.resume() - assert f.reply.state == "committed" - - def test_replace_unicode(self): - f = tflow.tflow(resp=True) - f.response.content = b"\xc2foo" - f.replace(b"foo", u"bar") - - def test_replace_no_content(self): - f = tflow.tflow() - f.request.content = None - assert f.replace("foo", "bar") == 0 - - def test_replace(self): - f = tflow.tflow(resp=True) - f.request.headers["foo"] = "foo" - f.request.content = b"afoob" - - f.response.headers["foo"] = "foo" - f.response.content = b"afoob" - - assert f.replace("foo", "bar") == 6 - - assert f.request.headers["bar"] == "bar" - assert f.request.content == b"abarb" - assert f.response.headers["bar"] == "bar" - assert f.response.content == b"abarb" - - def test_replace_encoded(self): - f = tflow.tflow(resp=True) - f.request.content = b"afoob" - f.request.encode("gzip") - f.response.content = b"afoob" - f.response.encode("gzip") - - f.replace("foo", "bar") - - assert f.request.raw_content != b"abarb" - f.request.decode() - assert f.request.raw_content == b"abarb" - - assert f.response.raw_content != b"abarb" - f.response.decode() - assert f.response.raw_content == b"abarb" - - class TestSerialize: def _treader(self): @@ -307,88 +165,6 @@ class TestFlowMaster: fm.shutdown() -class TestRequest: - - def test_simple(self): - f = tflow.tflow() - r = f.request - u = r.url - r.url = u - with pytest.raises(ValueError): - setattr(r, "url", "") - assert r.url == u - r2 = r.copy() - assert r.get_state() == r2.get_state() - - def test_get_url(self): - r = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq()) - - assert r.url == "http://address:22/path" - - r.scheme = "https" - assert r.url == "https://address:22/path" - - r.host = "host" - r.port = 42 - assert r.url == "https://host:42/path" - - r.host = "address" - r.port = 22 - assert r.url == "https://address:22/path" - - assert r.pretty_url == "https://address:22/path" - r.headers["Host"] = "foo.com:22" - assert r.url == "https://address:22/path" - assert r.pretty_url == "https://foo.com:22/path" - - def test_replace(self): - r = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq()) - r.path = "path/foo" - r.headers["Foo"] = "fOo" - r.content = b"afoob" - assert r.replace("foo(?i)", "boo") == 4 - assert r.path == "path/boo" - assert b"foo" not in r.content - assert r.headers["boo"] == "boo" - - def test_constrain_encoding(self): - r = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq()) - r.headers["accept-encoding"] = "gzip, oink" - r.constrain_encoding() - assert "oink" not in r.headers["accept-encoding"] - - r.headers.set_all("accept-encoding", ["gzip", "oink"]) - r.constrain_encoding() - assert "oink" not in r.headers["accept-encoding"] - - def test_get_content_type(self): - resp = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp()) - resp.headers = Headers(content_type="text/plain") - assert resp.headers["content-type"] == "text/plain" - - -class TestResponse: - - def test_simple(self): - f = tflow.tflow(resp=True) - resp = f.response - resp2 = resp.copy() - assert resp2.get_state() == resp.get_state() - - def test_replace(self): - r = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp()) - r.headers["Foo"] = "fOo" - r.content = b"afoob" - assert r.replace("foo(?i)", "boo") == 3 - assert b"foo" not in r.content - assert r.headers["boo"] == "boo" - - def test_get_content_type(self): - resp = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp()) - resp.headers = Headers(content_type="text/plain") - assert resp.headers["content-type"] == "text/plain" - - class TestError: def test_getset_state(self): @@ -409,23 +185,4 @@ class TestError: def test_repr(self): e = flow.Error("yay") assert repr(e) - - -class TestClientConnection: - def test_state(self): - c = tflow.tclient_conn() - assert connections.ClientConnection.from_state(c.get_state()).get_state() == \ - c.get_state() - - c2 = tflow.tclient_conn() - c2.address.address = (c2.address.host, 4242) - assert not c == c2 - - c2.timestamp_start = 42 - c.set_state(c2.get_state()) - assert c.timestamp_start == 42 - - c3 = c.copy() - assert c3.get_state() == c.get_state() - - assert str(c) + assert str(e) diff --git a/test/mitmproxy/test_flowfilter.py b/test/mitmproxy/test_flowfilter.py index bfce265e..46fff477 100644 --- a/test/mitmproxy/test_flowfilter.py +++ b/test/mitmproxy/test_flowfilter.py @@ -1,4 +1,5 @@ import io +import pytest from unittest.mock import patch from mitmproxy.test import tflow @@ -134,6 +135,12 @@ class TestMatchingHTTPFlow: e = self.err() assert self.q("~e", e) + def test_fmarked(self): + q = self.req() + assert not self.q("~marked", q) + q.marked = True + assert self.q("~marked", q) + def test_head(self): q = self.req() s = self.resp() @@ -221,6 +228,11 @@ class TestMatchingHTTPFlow: assert not self.q("~src :99", q) assert self.q("~src address:22", q) + q.client_conn.address = None + assert not self.q('~src address:22', q) + q.client_conn = None + assert not self.q('~src address:22', q) + def test_dst(self): q = self.req() q.server_conn = tflow.tserver_conn() @@ -230,6 +242,11 @@ class TestMatchingHTTPFlow: assert not self.q("~dst :99", q) assert self.q("~dst address:22", q) + q.server_conn.address = None + assert not self.q('~dst address:22', q) + q.server_conn = None + assert not self.q('~dst address:22', q) + def test_and(self): s = self.resp() assert self.q("~c 200 & ~h head", s) @@ -269,6 +286,7 @@ class TestMatchingTCPFlow: f = self.flow() assert self.q("~tcp", f) assert not self.q("~http", f) + assert not self.q("~websocket", f) def test_ferr(self): e = self.err() @@ -378,6 +396,87 @@ class TestMatchingTCPFlow: assert not self.q("~u whatever", f) +class TestMatchingWebSocketFlow: + + def flow(self): + return tflow.twebsocketflow() + + def err(self): + return tflow.twebsocketflow(err=True) + + def q(self, q, o): + return flowfilter.parse(q)(o) + + def test_websocket(self): + f = self.flow() + assert self.q("~websocket", f) + assert not self.q("~tcp", f) + assert not self.q("~http", f) + + def test_ferr(self): + e = self.err() + assert self.q("~e", e) + + def test_body(self): + f = self.flow() + + # Messages sent by client or server + assert self.q("~b hello", f) + assert self.q("~b me", f) + assert not self.q("~b nonexistent", f) + + # Messages sent by client + assert self.q("~bq hello", f) + assert not self.q("~bq me", f) + assert not self.q("~bq nonexistent", f) + + # Messages sent by server + assert self.q("~bs me", f) + assert not self.q("~bs hello", f) + assert not self.q("~bs nonexistent", f) + + def test_src(self): + f = self.flow() + assert self.q("~src address", f) + assert not self.q("~src foobar", f) + assert self.q("~src :22", f) + assert not self.q("~src :99", f) + assert self.q("~src address:22", f) + + def test_dst(self): + f = self.flow() + f.server_conn = tflow.tserver_conn() + assert self.q("~dst address", f) + assert not self.q("~dst foobar", f) + assert self.q("~dst :22", f) + assert not self.q("~dst :99", f) + assert self.q("~dst address:22", f) + + def test_and(self): + f = self.flow() + f.server_conn = tflow.tserver_conn() + assert self.q("~b hello & ~b me", f) + assert not self.q("~src wrongaddress & ~b hello", f) + assert self.q("(~src :22 & ~dst :22) & ~b hello", f) + assert not self.q("(~src address:22 & ~dst :22) & ~b nonexistent", f) + assert not self.q("(~src address:22 & ~dst :99) & ~b hello", f) + + def test_or(self): + f = self.flow() + f.server_conn = tflow.tserver_conn() + assert self.q("~b hello | ~b me", f) + assert self.q("~src :22 | ~b me", f) + assert not self.q("~src :99 | ~dst :99", f) + assert self.q("(~src :22 | ~dst :22) | ~b me", f) + + def test_not(self): + f = self.flow() + assert not self.q("! ~src :22", f) + assert self.q("! ~src :99", f) + assert self.q("!~src :99 !~src :99", f) + assert not self.q("!~src :99 !~src :22", f) + + class TestMatchingDummyFlow: def flow(self): @@ -411,6 +510,8 @@ class TestMatchingDummyFlow: assert not self.q("~e", f) assert not self.q("~http", f) + assert not self.q("~tcp", f) + assert not self.q("~websocket", f) assert not self.q("~h whatever", f) assert not self.q("~hq whatever", f) @@ -440,3 +541,11 @@ def test_pyparsing_bug(extract_tb): # The text is a string with leading and trailing whitespace stripped; if the source is not available it is None. extract_tb.return_value = [("", 1, "test", None)] assert flowfilter.parse("test") + + +def test_match(): + with pytest.raises(ValueError): + flowfilter.match('[foobar', None) + + assert flowfilter.match(None, None) + assert not flowfilter.match('foobar', None) diff --git a/test/mitmproxy/test_http.py b/test/mitmproxy/test_http.py index 777ab4dd..889eb0a7 100644 --- a/test/mitmproxy/test_http.py +++ b/test/mitmproxy/test_http.py @@ -1 +1,256 @@ -# TODO: write tests +import pytest + +from mitmproxy.test import tflow +from mitmproxy.net.http import Headers +import mitmproxy.io +from mitmproxy import flowfilter +from mitmproxy.exceptions import Kill +from mitmproxy import flow +from mitmproxy import http + + +class TestHTTPRequest: + + def test_simple(self): + f = tflow.tflow() + r = f.request + u = r.url + r.url = u + with pytest.raises(ValueError): + setattr(r, "url", "") + assert r.url == u + r2 = r.copy() + assert r.get_state() == r2.get_state() + assert hash(r) + + def test_get_url(self): + r = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq()) + + assert r.url == "http://address:22/path" + + r.scheme = "https" + assert r.url == "https://address:22/path" + + r.host = "host" + r.port = 42 + assert r.url == "https://host:42/path" + + r.host = "address" + r.port = 22 + assert r.url == "https://address:22/path" + + assert r.pretty_url == "https://address:22/path" + r.headers["Host"] = "foo.com:22" + assert r.url == "https://address:22/path" + assert r.pretty_url == "https://foo.com:22/path" + + def test_replace(self): + r = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq()) + r.path = "path/foo" + r.headers["Foo"] = "fOo" + r.content = b"afoob" + assert r.replace("foo(?i)", "boo") == 4 + assert r.path == "path/boo" + assert b"foo" not in r.content + assert r.headers["boo"] == "boo" + + def test_constrain_encoding(self): + r = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq()) + r.headers["accept-encoding"] = "gzip, oink" + r.constrain_encoding() + assert "oink" not in r.headers["accept-encoding"] + + r.headers.set_all("accept-encoding", ["gzip", "oink"]) + r.constrain_encoding() + assert "oink" not in r.headers["accept-encoding"] + + def test_get_content_type(self): + resp = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp()) + resp.headers = Headers(content_type="text/plain") + assert resp.headers["content-type"] == "text/plain" + + +class TestHTTPResponse: + + def test_simple(self): + f = tflow.tflow(resp=True) + resp = f.response + resp2 = resp.copy() + assert resp2.get_state() == resp.get_state() + + def test_replace(self): + r = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp()) + r.headers["Foo"] = "fOo" + r.content = b"afoob" + assert r.replace("foo(?i)", "boo") == 3 + assert b"foo" not in r.content + assert r.headers["boo"] == "boo" + + def test_get_content_type(self): + resp = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp()) + resp.headers = Headers(content_type="text/plain") + assert resp.headers["content-type"] == "text/plain" + + +class TestHTTPFlow: + + def test_copy(self): + f = tflow.tflow(resp=True) + assert repr(f) + f.get_state() + f2 = f.copy() + a = f.get_state() + b = f2.get_state() + del a["id"] + del b["id"] + assert a == b + assert not f == f2 + assert f is not f2 + assert f.request.get_state() == f2.request.get_state() + assert f.request is not f2.request + assert f.request.headers == f2.request.headers + assert f.request.headers is not f2.request.headers + assert f.response.get_state() == f2.response.get_state() + assert f.response is not f2.response + + f = tflow.tflow(err=True) + f2 = f.copy() + assert f is not f2 + assert f.request is not f2.request + assert f.request.headers == f2.request.headers + assert f.request.headers is not f2.request.headers + assert f.error.get_state() == f2.error.get_state() + assert f.error is not f2.error + + def test_match(self): + f = tflow.tflow(resp=True) + assert not flowfilter.match("~b test", f) + assert flowfilter.match(None, f) + assert not flowfilter.match("~b test", f) + + f = tflow.tflow(err=True) + assert flowfilter.match("~e", f) + + with pytest.raises(ValueError): + flowfilter.match("~", f) + + def test_backup(self): + f = tflow.tflow() + f.response = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp()) + f.request.content = b"foo" + assert not f.modified() + f.backup() + f.request.content = b"bar" + assert f.modified() + f.revert() + assert f.request.content == b"foo" + + def test_backup_idempotence(self): + f = tflow.tflow(resp=True) + f.backup() + f.revert() + f.backup() + f.revert() + + def test_getset_state(self): + f = tflow.tflow(resp=True) + state = f.get_state() + assert f.get_state() == http.HTTPFlow.from_state( + state).get_state() + + f.response = None + f.error = flow.Error("error") + state = f.get_state() + assert f.get_state() == http.HTTPFlow.from_state( + state).get_state() + + f2 = f.copy() + f2.id = f.id # copy creates a different uuid + assert f.get_state() == f2.get_state() + assert not f == f2 + f2.error = flow.Error("e2") + assert not f == f2 + f.set_state(f2.get_state()) + assert f.get_state() == f2.get_state() + + def test_kill(self): + f = tflow.tflow() + f.reply.handle() + f.intercept() + assert f.killable + f.kill() + assert not f.killable + assert f.reply.value == Kill + + def test_resume(self): + f = tflow.tflow() + f.reply.handle() + f.intercept() + assert f.reply.state == "taken" + f.resume() + assert f.reply.state == "committed" + + def test_replace_unicode(self): + f = tflow.tflow(resp=True) + f.response.content = b"\xc2foo" + f.replace(b"foo", u"bar") + + def test_replace_no_content(self): + f = tflow.tflow() + f.request.content = None + assert f.replace("foo", "bar") == 0 + + def test_replace(self): + f = tflow.tflow(resp=True) + f.request.headers["foo"] = "foo" + f.request.content = b"afoob" + + f.response.headers["foo"] = "foo" + f.response.content = b"afoob" + + assert f.replace("foo", "bar") == 6 + + assert f.request.headers["bar"] == "bar" + assert f.request.content == b"abarb" + assert f.response.headers["bar"] == "bar" + assert f.response.content == b"abarb" + + def test_replace_encoded(self): + f = tflow.tflow(resp=True) + f.request.content = b"afoob" + f.request.encode("gzip") + f.response.content = b"afoob" + f.response.encode("gzip") + + f.replace("foo", "bar") + + assert f.request.raw_content != b"abarb" + f.request.decode() + assert f.request.raw_content == b"abarb" + + assert f.response.raw_content != b"abarb" + f.response.decode() + assert f.response.raw_content == b"abarb" + + +def test_make_error_response(): + resp = http.make_error_response(543, 'foobar', Headers()) + assert resp + + +def test_make_connect_request(): + req = http.make_connect_request(('invalidhost', 1234)) + assert req.first_line_format == 'authority' + assert req.method == 'CONNECT' + assert req.http_version == 'HTTP/1.1' + + +def test_make_connect_response(): + resp = http.make_connect_response('foobar') + assert resp.http_version == 'foobar' + assert resp.status_code == 200 + + +def test_expect_continue_response(): + assert http.expect_continue_response.http_version == 'HTTP/1.1' + assert http.expect_continue_response.status_code == 100 diff --git a/test/mitmproxy/test_optmanager.py b/test/mitmproxy/test_optmanager.py index 65691fdf..161b0dcf 100644 --- a/test/mitmproxy/test_optmanager.py +++ b/test/mitmproxy/test_optmanager.py @@ -30,6 +30,14 @@ class TD2(TD): super().__init__(three=three, **kwargs) +class TM(optmanager.OptManager): + def __init__(self, one="one", two=["foo"], three=None): + self.one = one + self.two = two + self.three = three + super().__init__() + + def test_defaults(): assert TD2.default("one") == "done" assert TD2.default("two") == "dtwo" @@ -203,6 +211,9 @@ def test_serialize(): t = "" o2.load(t) + with pytest.raises(exceptions.OptionsError, matches='No such option: foobar'): + o2.load("foobar: '123'") + def test_serialize_defaults(): o = options.Options() @@ -224,13 +235,10 @@ def test_saving(): o.load_paths(dst) assert o.three == "foo" - -class TM(optmanager.OptManager): - def __init__(self, one="one", two=["foo"], three=None): - self.one = one - self.two = two - self.three = three - super().__init__() + with open(dst, 'a') as f: + f.write("foobar: '123'") + with pytest.raises(exceptions.OptionsError, matches=''): + o.load_paths(dst) def test_merge(): diff --git a/test/mitmproxy/test_proxy.py b/test/mitmproxy/test_proxy.py index a14c851e..37cec57a 100644 --- a/test/mitmproxy/test_proxy.py +++ b/test/mitmproxy/test_proxy.py @@ -4,62 +4,17 @@ from unittest import mock from OpenSSL import SSL import pytest -from mitmproxy.test import tflow from mitmproxy.tools import cmdline from mitmproxy import options from mitmproxy.proxy import ProxyConfig -from mitmproxy import connections from mitmproxy.proxy.server import DummyServer, ProxyServer, ConnectionHandler from mitmproxy.proxy import config -from mitmproxy import exceptions -from pathod import test -from mitmproxy.net.http import http1 from mitmproxy.test import tutils from ..conftest import skip_windows -class TestServerConnection: - - def test_simple(self): - self.d = test.Daemon() - sc = connections.ServerConnection((self.d.IFACE, self.d.port)) - sc.connect() - f = tflow.tflow() - f.server_conn = sc - f.request.path = "/p/200:da" - - # use this protocol just to assemble - not for actual sending - sc.wfile.write(http1.assemble_request(f.request)) - sc.wfile.flush() - - assert http1.read_response(sc.rfile, f.request, 1000) - assert self.d.last_log() - - sc.finish() - self.d.shutdown() - - def test_terminate_error(self): - self.d = test.Daemon() - sc = connections.ServerConnection((self.d.IFACE, self.d.port)) - sc.connect() - sc.connection = mock.Mock() - sc.connection.recv = mock.Mock(return_value=False) - sc.connection.flush = mock.Mock(side_effect=exceptions.TcpDisconnect) - sc.finish() - self.d.shutdown() - - def test_repr(self): - sc = tflow.tserver_conn() - assert "address:22" in repr(sc) - assert "ssl" not in repr(sc) - sc.ssl_established = True - assert "ssl" in repr(sc) - sc.sni = "foo" - assert "foo" in repr(sc) - - class MockParser(argparse.ArgumentParser): """ @@ -160,7 +115,7 @@ class TestProxyServer: ProxyServer(conf) def test_err_2(self): - conf = ProxyConfig(options.Options(listen_host="invalidhost")) + conf = ProxyConfig(options.Options(listen_host="256.256.256.256")) with pytest.raises(Exception, match="Error starting proxy server"): ProxyServer(conf) diff --git a/test/mitmproxy/tservers.py b/test/mitmproxy/tservers.py index 298fddcb..9a289ae5 100644 --- a/test/mitmproxy/tservers.py +++ b/test/mitmproxy/tservers.py @@ -98,13 +98,14 @@ class ProxyThread(threading.Thread): threading.Thread.__init__(self) self.tmaster = tmaster self.name = "ProxyThread (%s:%s)" % ( - tmaster.server.address.host, tmaster.server.address.port + tmaster.server.address[0], + tmaster.server.address[1], ) controller.should_exit = False @property def port(self): - return self.tmaster.server.address.port + return self.tmaster.server.address[1] @property def tlog(self): |