From a12c3d3f8ea255dd03bb7e993fa85eb00f47ab29 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Tue, 14 Feb 2017 22:48:43 +0100 Subject: restructure and move test files add empty test files to satisfy linter --- test/pathod/language/test_actions.py | 134 ++++++++ test/pathod/language/test_base.py | 354 +++++++++++++++++++++ test/pathod/language/test_exceptions.py | 1 + test/pathod/language/test_generators.py | 45 +++ test/pathod/language/test_http.py | 355 +++++++++++++++++++++ test/pathod/language/test_http2.py | 236 ++++++++++++++ test/pathod/language/test_message.py | 1 + test/pathod/language/test_websockets.py | 142 +++++++++ test/pathod/language/test_writer.py | 90 ++++++ test/pathod/protocols/test_http.py | 1 + test/pathod/protocols/test_http2.py | 514 +++++++++++++++++++++++++++++++ test/pathod/protocols/test_websockets.py | 1 + test/pathod/test_language_actions.py | 134 -------- test/pathod/test_language_base.py | 354 --------------------- test/pathod/test_language_generators.py | 45 --- test/pathod/test_language_http.py | 355 --------------------- test/pathod/test_language_http2.py | 236 -------------- test/pathod/test_language_websocket.py | 142 --------- test/pathod/test_language_writer.py | 90 ------ test/pathod/test_protocols_http2.py | 514 ------------------------------- 20 files changed, 1874 insertions(+), 1870 deletions(-) create mode 100644 test/pathod/language/test_actions.py create mode 100644 test/pathod/language/test_base.py create mode 100644 test/pathod/language/test_exceptions.py create mode 100644 test/pathod/language/test_generators.py create mode 100644 test/pathod/language/test_http.py create mode 100644 test/pathod/language/test_http2.py create mode 100644 test/pathod/language/test_message.py create mode 100644 test/pathod/language/test_websockets.py create mode 100644 test/pathod/language/test_writer.py create mode 100644 test/pathod/protocols/test_http.py create mode 100644 test/pathod/protocols/test_http2.py create mode 100644 test/pathod/protocols/test_websockets.py delete mode 100644 test/pathod/test_language_actions.py delete mode 100644 test/pathod/test_language_base.py delete mode 100644 test/pathod/test_language_generators.py delete mode 100644 test/pathod/test_language_http.py delete mode 100644 test/pathod/test_language_http2.py delete mode 100644 test/pathod/test_language_websocket.py delete mode 100644 test/pathod/test_language_writer.py delete mode 100644 test/pathod/test_protocols_http2.py (limited to 'test/pathod') diff --git a/test/pathod/language/test_actions.py b/test/pathod/language/test_actions.py new file mode 100644 index 00000000..9740e5c7 --- /dev/null +++ b/test/pathod/language/test_actions.py @@ -0,0 +1,134 @@ +import io + +from pathod.language import actions, parse_pathoc, parse_pathod, serve + + +def parse_request(s): + return next(parse_pathoc(s)) + + +def test_unique_name(): + assert not actions.PauseAt(0, "f").unique_name + assert actions.DisconnectAt(0).unique_name + + +class TestDisconnects: + + def test_parse_pathod(self): + a = next(parse_pathod("400:d0")).actions[0] + assert a.spec() == "d0" + a = next(parse_pathod("400:dr")).actions[0] + assert a.spec() == "dr" + + def test_at(self): + e = actions.DisconnectAt.expr() + v = e.parseString("d0")[0] + assert isinstance(v, actions.DisconnectAt) + assert v.offset == 0 + + v = e.parseString("d100")[0] + assert v.offset == 100 + + e = actions.DisconnectAt.expr() + v = e.parseString("dr")[0] + assert v.offset == "r" + + def test_spec(self): + assert actions.DisconnectAt("r").spec() == "dr" + assert actions.DisconnectAt(10).spec() == "d10" + + +class TestInject: + + def test_parse_pathod(self): + a = next(parse_pathod("400:ir,@100")).actions[0] + assert a.offset == "r" + assert a.value.datatype == "bytes" + assert a.value.usize == 100 + + a = next(parse_pathod("400:ia,@100")).actions[0] + assert a.offset == "a" + + def test_at(self): + e = actions.InjectAt.expr() + v = e.parseString("i0,'foo'")[0] + assert v.value.val == b"foo" + assert v.offset == 0 + assert isinstance(v, actions.InjectAt) + + v = e.parseString("ir,'foo'")[0] + assert v.offset == "r" + + def test_serve(self): + s = io.BytesIO() + r = next(parse_pathod("400:i0,'foo'")) + assert serve(r, s, {}) + + def test_spec(self): + e = actions.InjectAt.expr() + v = e.parseString("i0,'foo'")[0] + assert v.spec() == "i0,'foo'" + + def test_spec2(self): + e = actions.InjectAt.expr() + v = e.parseString("i0,@100")[0] + v2 = v.freeze({}) + v3 = v2.freeze({}) + assert v2.value.val == v3.value.val + + +class TestPauses: + + def test_parse_pathod(self): + e = actions.PauseAt.expr() + v = e.parseString("p10,10")[0] + assert v.seconds == 10 + assert v.offset == 10 + + v = e.parseString("p10,f")[0] + assert v.seconds == "f" + + v = e.parseString("pr,f")[0] + assert v.offset == "r" + + v = e.parseString("pa,f")[0] + assert v.offset == "a" + + def test_request(self): + r = next(parse_pathod('400:p10,10')) + assert r.actions[0].spec() == "p10,10" + + def test_spec(self): + assert actions.PauseAt("r", 5).spec() == "pr,5" + assert actions.PauseAt(0, 5).spec() == "p0,5" + assert actions.PauseAt(0, "f").spec() == "p0,f" + + def test_freeze(self): + l = actions.PauseAt("r", 5) + assert l.freeze({}).spec() == l.spec() + + +class Test_Action: + + def test_cmp(self): + a = actions.DisconnectAt(0) + b = actions.DisconnectAt(1) + c = actions.DisconnectAt(0) + assert a < b + assert a == c + l = sorted([b, a]) + assert l[0].offset == 0 + + def test_resolve(self): + r = parse_request('GET:"/foo"') + e = actions.DisconnectAt("r") + ret = e.resolve({}, r) + assert isinstance(ret.offset, int) + + def test_repr(self): + e = actions.DisconnectAt("r") + assert repr(e) + + def test_freeze(self): + l = actions.DisconnectAt(5) + assert l.freeze({}).spec() == l.spec() diff --git a/test/pathod/language/test_base.py b/test/pathod/language/test_base.py new file mode 100644 index 00000000..85e9e53b --- /dev/null +++ b/test/pathod/language/test_base.py @@ -0,0 +1,354 @@ +import os +import pytest + +from pathod import language +from pathod.language import base, exceptions + +from mitmproxy.test import tutils + + +def parse_request(s): + return language.parse_pathoc(s).next() + + +def test_times(): + reqs = list(language.parse_pathoc("get:/:x5")) + assert len(reqs) == 5 + assert not reqs[0].times + + +def test_caseless_literal(): + class CL(base.CaselessLiteral): + TOK = "foo" + v = CL("foo") + assert v.expr() + assert v.values(language.Settings()) + + +class TestTokValueNakedLiteral: + + def test_expr(self): + v = base.TokValueNakedLiteral("foo") + assert v.expr() + + def test_spec(self): + v = base.TokValueNakedLiteral("foo") + assert v.spec() == repr(v) == "foo" + + v = base.TokValueNakedLiteral("f\x00oo") + assert v.spec() == repr(v) == r"f\x00oo" + + +class TestTokValueLiteral: + + def test_expr(self): + v = base.TokValueLiteral("foo") + assert v.expr() + assert v.val == b"foo" + + v = base.TokValueLiteral("foo\n") + assert v.expr() + assert v.val == b"foo\n" + assert repr(v) + + def test_spec(self): + v = base.TokValueLiteral("foo") + assert v.spec() == r"'foo'" + + v = base.TokValueLiteral("f\x00oo") + assert v.spec() == repr(v) == r"'f\x00oo'" + + v = base.TokValueLiteral('"') + assert v.spec() == repr(v) == """ '"' """.strip() + + # While pyparsing has a escChar argument for QuotedString, + # escChar only performs scapes single-character escapes and does not work for e.g. r"\x02". + # Thus, we cannot use that option, which means we cannot have single quotes in strings. + # To fix this, we represent single quotes as r"\x07". + v = base.TokValueLiteral("'") + assert v.spec() == r"'\x27'" + + def roundtrip(self, spec): + e = base.TokValueLiteral.expr() + v = base.TokValueLiteral(spec) + v2 = e.parseString(v.spec()) + assert v.val == v2[0].val + assert v.spec() == v2[0].spec() + + def test_roundtrip(self): + self.roundtrip("'") + self.roundtrip(r"\'") + self.roundtrip("a") + self.roundtrip("\"") + # self.roundtrip("\\") + self.roundtrip("200:b'foo':i23,'\\''") + self.roundtrip("\a") + + +class TestTokValueGenerate: + + def test_basic(self): + v = base.TokValue.parseString("@10b")[0] + assert v.usize == 10 + assert v.unit == "b" + assert v.bytes() == 10 + v = base.TokValue.parseString("@10")[0] + assert v.unit == "b" + v = base.TokValue.parseString("@10k")[0] + assert v.bytes() == 10240 + v = base.TokValue.parseString("@10g")[0] + assert v.bytes() == 1024 ** 3 * 10 + + v = base.TokValue.parseString("@10g,digits")[0] + assert v.datatype == "digits" + g = v.get_generator({}) + assert g[:100] + + v = base.TokValue.parseString("@10,digits")[0] + assert v.unit == "b" + assert v.datatype == "digits" + + def test_spec(self): + v = base.TokValueGenerate(1, "b", "bytes") + assert v.spec() == repr(v) == "@1" + + v = base.TokValueGenerate(1, "k", "bytes") + assert v.spec() == repr(v) == "@1k" + + v = base.TokValueGenerate(1, "k", "ascii") + assert v.spec() == repr(v) == "@1k,ascii" + + v = base.TokValueGenerate(1, "b", "ascii") + assert v.spec() == repr(v) == "@1,ascii" + + def test_freeze(self): + v = base.TokValueGenerate(100, "b", "ascii") + f = v.freeze(language.Settings()) + assert len(f.val) == 100 + + +class TestTokValueFile: + + def test_file_value(self): + v = base.TokValue.parseString("<'one two'")[0] + assert str(v) + assert v.path == "one two" + + v = base.TokValue.parseString(" 100 + + def test_path_generator(self): + r = parse_request("GET:@100").freeze(language.Settings()) + assert len(r.spec()) > 100 + + def test_websocket(self): + r = parse_request('ws:/path/') + res = r.resolve(language.Settings()) + assert res.method.string().lower() == b"get" + assert res.tok(http.Path).value.val == b"/path/" + assert res.tok(http.Method).value.val.lower() == b"get" + assert http.get_header(b"Upgrade", res.headers).value.val == b"websocket" + + r = parse_request('ws:put:/path/') + res = r.resolve(language.Settings()) + assert r.method.string().lower() == b"put" + assert res.tok(http.Path).value.val == b"/path/" + assert res.tok(http.Method).value.val.lower() == b"put" + assert http.get_header(b"Upgrade", res.headers).value.val == b"websocket" + + +class TestResponse: + + def dummy_response(self): + return next(language.parse_pathod("400'msg'")) + + def test_response(self): + r = next(language.parse_pathod("400:m'msg'")) + assert r.status_code.string() == b"400" + assert r.reason.string() == b"msg" + + r = next(language.parse_pathod("400:m'msg':b@100b")) + assert r.reason.string() == b"msg" + assert r.body.values({}) + assert str(r) + + r = next(language.parse_pathod("200")) + assert r.status_code.string() == b"200" + assert not r.reason + assert b"OK" in [i[:] for i in r.preamble({})] + + def test_render(self): + s = io.BytesIO() + r = next(language.parse_pathod("400:m'msg'")) + assert language.serve(r, s, {}) + + r = next(language.parse_pathod("400:p0,100:dr")) + assert "p0" in r.spec() + s = r.preview_safe() + assert "p0" not in s.spec() + + def test_raw(self): + s = io.BytesIO() + r = next(language.parse_pathod("400:b'foo'")) + language.serve(r, s, {}) + v = s.getvalue() + assert b"Content-Length" in v + + s = io.BytesIO() + r = next(language.parse_pathod("400:b'foo':r")) + language.serve(r, s, {}) + v = s.getvalue() + assert b"Content-Length" not in v + + def test_length(self): + def testlen(x): + s = io.BytesIO() + x = next(x) + language.serve(x, s, language.Settings()) + assert x.length(language.Settings()) == len(s.getvalue()) + testlen(language.parse_pathod("400:m'msg':r")) + testlen(language.parse_pathod("400:m'msg':h'foo'='bar':r")) + testlen(language.parse_pathod("400:m'msg':h'foo'='bar':b@100b:r")) + + def test_maximum_length(self): + def testlen(x): + x = next(x) + s = io.BytesIO() + m = x.maximum_length({}) + language.serve(x, s, {}) + assert m >= len(s.getvalue()) + + r = language.parse_pathod("400:m'msg':b@100:d0") + testlen(r) + + r = language.parse_pathod("400:m'msg':b@100:d0:i0,'foo'") + testlen(r) + + r = language.parse_pathod("400:m'msg':b@100:d0:i0,'foo'") + testlen(r) + + def test_parse_err(self): + with pytest.raises(language.ParseException): + language.parse_pathod("400:msg,b:") + try: + language.parse_pathod("400'msg':b:") + except language.ParseException as v: + assert v.marked() + assert str(v) + + def test_nonascii(self): + with pytest.raises(Exception, match="ASCII"): + language.parse_pathod("foo:b\xf0") + + def test_parse_header(self): + r = next(language.parse_pathod('400:h"foo"="bar"')) + assert http.get_header(b"foo", r.headers) + + def test_parse_pause_before(self): + r = next(language.parse_pathod("400:p0,10")) + assert r.actions[0].spec() == "p0,10" + + def test_parse_pause_after(self): + r = next(language.parse_pathod("400:pa,10")) + assert r.actions[0].spec() == "pa,10" + + def test_parse_pause_random(self): + r = next(language.parse_pathod("400:pr,10")) + assert r.actions[0].spec() == "pr,10" + + def test_parse_stress(self): + # While larger values are known to work on linux, len() technically + # returns an int and a python 2.7 int on windows has 32bit precision. + # Therefore, we should keep the body length < 2147483647 bytes in our + # tests. + r = next(language.parse_pathod("400:b@1g")) + assert r.length({}) + + def test_spec(self): + def rt(s): + s = next(language.parse_pathod(s)).spec() + assert next(language.parse_pathod(s)).spec() == s + rt("400:b@100g") + rt("400") + rt("400:da") + + def test_websockets(self): + r = next(language.parse_pathod("ws")) + with pytest.raises(Exception, match="No websocket key"): + r.resolve(language.Settings()) + res = r.resolve(language.Settings(websocket_key=b"foo")) + assert res.status_code.string() == b"101" + + +def test_ctype_shortcut(): + e = http.ShortcutContentType.expr() + v = e.parseString("c'foo'")[0] + assert v.key.val == b"Content-Type" + assert v.value.val == b"foo" + + s = v.spec() + assert s == e.parseString(s)[0].spec() + + e = http.ShortcutContentType.expr() + v = e.parseString("c@100")[0] + v2 = v.freeze({}) + v3 = v2.freeze({}) + assert v2.value.val == v3.value.val + + +def test_location_shortcut(): + e = http.ShortcutLocation.expr() + v = e.parseString("l'foo'")[0] + assert v.key.val == b"Location" + assert v.value.val == b"foo" + + s = v.spec() + assert s == e.parseString(s)[0].spec() + + e = http.ShortcutLocation.expr() + v = e.parseString("l@100")[0] + v2 = v.freeze({}) + v3 = v2.freeze({}) + assert v2.value.val == v3.value.val + + +def test_shortcuts(): + assert next(language.parse_pathod( + "400:c'foo'")).headers[0].key.val == b"Content-Type" + assert next(language.parse_pathod( + "400:l'foo'")).headers[0].key.val == b"Location" + + assert b"Android" in tservers.render(parse_request("get:/:ua")) + assert b"User-Agent" in tservers.render(parse_request("get:/:ua")) + + +def test_user_agent(): + e = http.ShortcutUserAgent.expr() + v = e.parseString("ua")[0] + assert b"Android" in v.string() + + e = http.ShortcutUserAgent.expr() + v = e.parseString("u'a'")[0] + assert b"Android" not in v.string() + + v = e.parseString("u@100'")[0] + assert len(str(v.freeze({}).value)) > 100 + v2 = v.freeze({}) + v3 = v2.freeze({}) + assert v2.value.val == v3.value.val + + +def test_nested_response(): + e = http.NestedResponse.expr() + v = e.parseString("s'200'")[0] + assert v.value.val == b"200" + with pytest.raises(language.ParseException): + e.parseString("s'foo'") + + v = e.parseString('s"200:b@1"')[0] + assert "@1" in v.spec() + f = v.freeze({}) + assert "@1" not in f.spec() + + +def test_nested_response_freeze(): + e = http.NestedResponse( + base.TokValueLiteral( + r"200:b\'foo\':i10,\'\\x27\'" + ) + ) + assert e.freeze({}) + assert e.values({}) + + +def test_unique_components(): + with pytest.raises(Exception, match="multiple body clauses"): + language.parse_pathod("400:b@1:b@1") diff --git a/test/pathod/language/test_http2.py b/test/pathod/language/test_http2.py new file mode 100644 index 00000000..4f89adb8 --- /dev/null +++ b/test/pathod/language/test_http2.py @@ -0,0 +1,236 @@ +import io +import pytest + +from mitmproxy.net import tcp +from mitmproxy.net.http import user_agents + +from pathod import language +from pathod.language import http2 +from pathod.protocols.http2 import HTTP2StateProtocol + + +def parse_request(s): + return next(language.parse_pathoc(s, True)) + + +def parse_response(s): + return next(language.parse_pathod(s, True)) + + +def default_settings(): + return language.Settings( + request_host="foo.com", + protocol=HTTP2StateProtocol(tcp.TCPClient(('localhost', 1234))) + ) + + +def test_make_error_response(): + d = io.BytesIO() + s = http2.make_error_response("foo", "bar") + language.serve(s, d, default_settings()) + + +class TestRequest: + + def test_cached_values(self): + req = parse_request("get:/") + req_id = id(req) + assert req_id == id(req.resolve(default_settings())) + assert req.values(default_settings()) == req.values(default_settings()) + + def test_nonascii(self): + with pytest.raises(Exception, match="ASCII"): + parse_request("get:\xf0") + + def test_err(self): + with pytest.raises(language.ParseException): + parse_request('GET') + + def test_simple(self): + r = parse_request('GET:"/foo"') + assert r.method.string() == b"GET" + assert r.path.string() == b"/foo" + r = parse_request('GET:/foo') + assert r.path.string() == b"/foo" + + def test_multiple(self): + r = list(language.parse_pathoc("GET:/ PUT:/")) + assert r[0].method.string() == b"GET" + assert r[1].method.string() == b"PUT" + assert len(r) == 2 + + l = """ + GET + "/foo" + + PUT + + "/foo + + + + bar" + """ + r = list(language.parse_pathoc(l, True)) + assert len(r) == 2 + assert r[0].method.string() == b"GET" + assert r[1].method.string() == b"PUT" + + l = """ + get:"http://localhost:9999/p/200" + get:"http://localhost:9999/p/200" + """ + r = list(language.parse_pathoc(l, True)) + assert len(r) == 2 + assert r[0].method.string() == b"GET" + assert r[1].method.string() == b"GET" + + def test_render_simple(self): + s = io.BytesIO() + r = parse_request("GET:'/foo'") + assert language.serve( + r, + s, + default_settings(), + ) + + def test_raw_content_length(self): + r = parse_request('GET:/:r') + assert len(r.headers) == 0 + + r = parse_request('GET:/:r:b"foobar"') + assert len(r.headers) == 0 + + r = parse_request('GET:/') + assert len(r.headers) == 1 + assert r.headers[0].values(default_settings()) == (b"content-length", b"0") + + r = parse_request('GET:/:b"foobar"') + assert len(r.headers) == 1 + assert r.headers[0].values(default_settings()) == (b"content-length", b"6") + + r = parse_request('GET:/:b"foobar":h"content-length"="42"') + assert len(r.headers) == 1 + assert r.headers[0].values(default_settings()) == (b"content-length", b"42") + + r = parse_request('GET:/:r:b"foobar":h"content-length"="42"') + assert len(r.headers) == 1 + assert r.headers[0].values(default_settings()) == (b"content-length", b"42") + + def test_content_type(self): + r = parse_request('GET:/:r:c"foobar"') + assert len(r.headers) == 1 + assert r.headers[0].values(default_settings()) == (b"content-type", b"foobar") + + def test_user_agent(self): + r = parse_request('GET:/:r:ua') + assert len(r.headers) == 1 + assert r.headers[0].values(default_settings()) == (b"user-agent", user_agents.get_by_shortcut('a')[2].encode()) + + def test_render_with_headers(self): + s = io.BytesIO() + r = parse_request('GET:/foo:h"foo"="bar"') + assert language.serve( + r, + s, + default_settings(), + ) + + def test_nested_response(self): + l = "get:/p/:s'200'" + r = parse_request(l) + assert len(r.tokens) == 3 + assert isinstance(r.tokens[2], http2.NestedResponse) + assert r.values(default_settings()) + + def test_render_with_body(self): + s = io.BytesIO() + r = parse_request("GET:'/foo':bfoobar") + assert language.serve( + r, + s, + default_settings(), + ) + + def test_spec(self): + def rt(s): + s = parse_request(s).spec() + assert parse_request(s).spec() == s + rt("get:/foo") + + +class TestResponse: + + def test_cached_values(self): + res = parse_response("200") + res_id = id(res) + assert res_id == id(res.resolve(default_settings())) + assert res.values(default_settings()) == res.values(default_settings()) + + def test_nonascii(self): + with pytest.raises(Exception, match="ASCII"): + parse_response("200:\xf0") + + def test_err(self): + with pytest.raises(language.ParseException): + parse_response('GET:/') + + def test_raw_content_length(self): + r = parse_response('200:r') + assert len(r.headers) == 0 + + r = parse_response('200') + assert len(r.headers) == 1 + assert r.headers[0].values(default_settings()) == (b"content-length", b"0") + + def test_content_type(self): + r = parse_response('200:r:c"foobar"') + assert len(r.headers) == 1 + assert r.headers[0].values(default_settings()) == (b"content-type", b"foobar") + + def test_simple(self): + r = parse_response('200:r:h"foo"="bar"') + assert r.status_code.string() == b"200" + assert len(r.headers) == 1 + assert r.headers[0].values(default_settings()) == (b"foo", b"bar") + assert r.body is None + + r = parse_response('200:r:h"foo"="bar":bfoobar:h"bla"="fasel"') + assert r.status_code.string() == b"200" + assert len(r.headers) == 2 + assert r.headers[0].values(default_settings()) == (b"foo", b"bar") + assert r.headers[1].values(default_settings()) == (b"bla", b"fasel") + assert r.body.string() == b"foobar" + + def test_render_simple(self): + s = io.BytesIO() + r = parse_response('200') + assert language.serve( + r, + s, + default_settings(), + ) + + def test_render_with_headers(self): + s = io.BytesIO() + r = parse_response('200:h"foo"="bar"') + assert language.serve( + r, + s, + default_settings(), + ) + + def test_render_with_body(self): + s = io.BytesIO() + r = parse_response('200:bfoobar') + assert language.serve( + r, + s, + default_settings(), + ) + + def test_spec(self): + def rt(s): + s = parse_response(s).spec() + assert parse_response(s).spec() == s + rt("200:bfoobar") diff --git a/test/pathod/language/test_message.py b/test/pathod/language/test_message.py new file mode 100644 index 00000000..777ab4dd --- /dev/null +++ b/test/pathod/language/test_message.py @@ -0,0 +1 @@ +# TODO: write tests diff --git a/test/pathod/language/test_websockets.py b/test/pathod/language/test_websockets.py new file mode 100644 index 00000000..e5046591 --- /dev/null +++ b/test/pathod/language/test_websockets.py @@ -0,0 +1,142 @@ +import pytest + +from pathod import language +from pathod.language import websockets +import mitmproxy.net.websockets + +from . import tservers + + +def parse_request(s): + return next(language.parse_pathoc(s)) + + +class TestWebsocketFrame: + + def _test_messages(self, specs, message_klass): + for i in specs: + wf = parse_request(i) + assert isinstance(wf, message_klass) + assert wf + assert wf.values(language.Settings()) + assert wf.resolve(language.Settings()) + + spec = wf.spec() + wf2 = parse_request(spec) + assert wf2.spec() == spec + + def test_server_values(self): + specs = [ + "wf", + "wf:dr", + "wf:b'foo'", + "wf:mask:r'foo'", + "wf:l1024:b'foo'", + "wf:cbinary", + "wf:c1", + "wf:mask:knone", + "wf:fin", + "wf:fin:rsv1:rsv2:rsv3:mask", + "wf:-fin:-rsv1:-rsv2:-rsv3:-mask", + "wf:k@4", + "wf:x10", + ] + self._test_messages(specs, websockets.WebsocketFrame) + + def test_parse_websocket_frames(self): + wf = language.parse_websocket_frame("wf:x10") + assert len(list(wf)) == 10 + with pytest.raises(language.ParseException): + language.parse_websocket_frame("wf:x") + + def test_client_values(self): + specs = [ + "wf:f'wf'", + ] + self._test_messages(specs, websockets.WebsocketClientFrame) + + def test_nested_frame(self): + wf = parse_request("wf:f'wf'") + assert wf.nested_frame + + def test_flags(self): + wf = parse_request("wf:fin:mask:rsv1:rsv2:rsv3") + frm = mitmproxy.net.websockets.Frame.from_bytes(tservers.render(wf)) + assert frm.header.fin + assert frm.header.mask + assert frm.header.rsv1 + assert frm.header.rsv2 + assert frm.header.rsv3 + + wf = parse_request("wf:-fin:-mask:-rsv1:-rsv2:-rsv3") + frm = mitmproxy.net.websockets.Frame.from_bytes(tservers.render(wf)) + assert not frm.header.fin + assert not frm.header.mask + assert not frm.header.rsv1 + assert not frm.header.rsv2 + assert not frm.header.rsv3 + + def fr(self, spec, **kwargs): + settings = language.base.Settings(**kwargs) + wf = parse_request(spec) + return mitmproxy.net.websockets.Frame.from_bytes(tservers.render(wf, settings)) + + def test_construction(self): + assert self.fr("wf:c1").header.opcode == 1 + assert self.fr("wf:c0").header.opcode == 0 + assert self.fr("wf:cbinary").header.opcode ==\ + mitmproxy.net.websockets.OPCODE.BINARY + assert self.fr("wf:ctext").header.opcode ==\ + mitmproxy.net.websockets.OPCODE.TEXT + + def test_rawbody(self): + frm = self.fr("wf:mask:r'foo'") + assert len(frm.payload) == 3 + assert frm.payload != b"foo" + + assert self.fr("wf:r'foo'").payload == b"foo" + + def test_construction_2(self): + # Simple server frame + frm = self.fr("wf:b'foo'") + assert not frm.header.mask + assert not frm.header.masking_key + + # Simple client frame + frm = self.fr("wf:b'foo'", is_client=True) + assert frm.header.mask + assert frm.header.masking_key + frm = self.fr("wf:b'foo':k'abcd'", is_client=True) + assert frm.header.mask + assert frm.header.masking_key == b'abcd' + + # Server frame, mask explicitly set + frm = self.fr("wf:b'foo':mask") + assert frm.header.mask + assert frm.header.masking_key + frm = self.fr("wf:b'foo':k'abcd'") + assert frm.header.mask + assert frm.header.masking_key == b'abcd' + + # Client frame, mask explicitly unset + frm = self.fr("wf:b'foo':-mask", is_client=True) + assert not frm.header.mask + assert not frm.header.masking_key + + frm = self.fr("wf:b'foo':-mask:k'abcd'", is_client=True) + assert not frm.header.mask + # We're reading back a corrupted frame - the first 3 characters of the + # mask is mis-interpreted as the payload + assert frm.payload == b"abc" + + def test_knone(self): + with pytest.raises(Exception, match="Expected 4 bytes"): + self.fr("wf:b'foo':mask:knone") + + def test_length(self): + assert self.fr("wf:l3:b'foo'").header.payload_length == 3 + frm = self.fr("wf:l2:b'foo'") + assert frm.header.payload_length == 2 + assert frm.payload == b"fo" + with pytest.raises(Exception, match="Expected 1024 bytes"): + self.fr("wf:l1024:b'foo'") diff --git a/test/pathod/language/test_writer.py b/test/pathod/language/test_writer.py new file mode 100644 index 00000000..7feb985d --- /dev/null +++ b/test/pathod/language/test_writer.py @@ -0,0 +1,90 @@ +import io +from pathod import language +from pathod.language import writer + + +def test_send_chunk(): + v = b"foobarfoobar" + for bs in range(1, len(v) + 2): + s = io.BytesIO() + writer.send_chunk(s, v, bs, 0, len(v)) + assert s.getvalue() == v + for start in range(len(v)): + for end in range(len(v)): + s = io.BytesIO() + writer.send_chunk(s, v, bs, start, end) + assert s.getvalue() == v[start:end] + + +def test_write_values_inject(): + tst = b"foo" + + s = io.BytesIO() + writer.write_values(s, [tst], [(0, "inject", b"aaa")], blocksize=5) + assert s.getvalue() == b"aaafoo" + + s = io.BytesIO() + writer.write_values(s, [tst], [(1, "inject", b"aaa")], blocksize=5) + assert s.getvalue() == b"faaaoo" + + s = io.BytesIO() + writer.write_values(s, [tst], [(1, "inject", b"aaa")], blocksize=5) + assert s.getvalue() == b"faaaoo" + + +def test_write_values_disconnects(): + s = io.BytesIO() + tst = b"foo" * 100 + writer.write_values(s, [tst], [(0, "disconnect")], blocksize=5) + assert not s.getvalue() + + +def test_write_values(): + tst = b"foobarvoing" + s = io.BytesIO() + writer.write_values(s, [tst], []) + assert s.getvalue() == tst + + for bs in range(1, len(tst) + 2): + for off in range(len(tst)): + s = io.BytesIO() + writer.write_values( + s, [tst], [(off, "disconnect")], blocksize=bs + ) + assert s.getvalue() == tst[:off] + + +def test_write_values_pauses(): + tst = "".join(str(i) for i in range(10)).encode() + for i in range(2, 10): + s = io.BytesIO() + writer.write_values( + s, [tst], [(2, "pause", 0), (1, "pause", 0)], blocksize=i + ) + assert s.getvalue() == tst + + for i in range(2, 10): + s = io.BytesIO() + writer.write_values(s, [tst], [(1, "pause", 0)], blocksize=i) + assert s.getvalue() == tst + + tst = [tst] * 5 + for i in range(2, 10): + s = io.BytesIO() + writer.write_values(s, tst[:], [(1, "pause", 0)], blocksize=i) + assert s.getvalue() == b"".join(tst) + + +def test_write_values_after(): + s = io.BytesIO() + r = next(language.parse_pathod("400:da")) + language.serve(r, s, {}) + + s = io.BytesIO() + r = next(language.parse_pathod("400:pa,0")) + language.serve(r, s, {}) + + s = io.BytesIO() + r = next(language.parse_pathod("400:ia,'xx'")) + language.serve(r, s, {}) + assert s.getvalue().endswith(b'xx') diff --git a/test/pathod/protocols/test_http.py b/test/pathod/protocols/test_http.py new file mode 100644 index 00000000..777ab4dd --- /dev/null +++ b/test/pathod/protocols/test_http.py @@ -0,0 +1 @@ +# TODO: write tests diff --git a/test/pathod/protocols/test_http2.py b/test/pathod/protocols/test_http2.py new file mode 100644 index 00000000..5bb31031 --- /dev/null +++ b/test/pathod/protocols/test_http2.py @@ -0,0 +1,514 @@ +from unittest import mock +import codecs +import pytest +import hyperframe + +from mitmproxy.net import tcp, http +from mitmproxy.net.http import http2 +from mitmproxy import exceptions + +from ..mitmproxy.net import tservers as net_tservers + +from pathod.protocols.http2 import HTTP2StateProtocol, TCPHandler + +from ..conftest import requires_alpn + + +class TestTCPHandlerWrapper: + def test_wrapped(self): + h = TCPHandler(rfile='foo', wfile='bar') + p = HTTP2StateProtocol(h) + assert p.tcp_handler.rfile == 'foo' + assert p.tcp_handler.wfile == 'bar' + + def test_direct(self): + p = HTTP2StateProtocol(rfile='foo', wfile='bar') + assert isinstance(p.tcp_handler, TCPHandler) + assert p.tcp_handler.rfile == 'foo' + assert p.tcp_handler.wfile == 'bar' + + +class EchoHandler(tcp.BaseHandler): + sni = None + + def handle(self): + while True: + v = self.rfile.safe_read(1) + self.wfile.write(v) + self.wfile.flush() + + +class TestProtocol: + @mock.patch("pathod.protocols.http2.HTTP2StateProtocol.perform_server_connection_preface") + @mock.patch("pathod.protocols.http2.HTTP2StateProtocol.perform_client_connection_preface") + def test_perform_connection_preface(self, mock_client_method, mock_server_method): + protocol = HTTP2StateProtocol(is_server=False) + protocol.connection_preface_performed = True + + protocol.perform_connection_preface() + assert not mock_client_method.called + assert not mock_server_method.called + + protocol.perform_connection_preface(force=True) + assert mock_client_method.called + assert not mock_server_method.called + + @mock.patch("pathod.protocols.http2.HTTP2StateProtocol.perform_server_connection_preface") + @mock.patch("pathod.protocols.http2.HTTP2StateProtocol.perform_client_connection_preface") + def test_perform_connection_preface_server(self, mock_client_method, mock_server_method): + protocol = HTTP2StateProtocol(is_server=True) + protocol.connection_preface_performed = True + + protocol.perform_connection_preface() + assert not mock_client_method.called + assert not mock_server_method.called + + protocol.perform_connection_preface(force=True) + assert not mock_client_method.called + assert mock_server_method.called + + +@requires_alpn +class TestCheckALPNMatch(net_tservers.ServerTestBase): + handler = EchoHandler + ssl = dict( + alpn_select=b'h2', + ) + + def test_check_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + with c.connect(): + c.convert_to_ssl(alpn_protos=[b'h2']) + protocol = HTTP2StateProtocol(c) + assert protocol.check_alpn() + + +@requires_alpn +class TestCheckALPNMismatch(net_tservers.ServerTestBase): + handler = EchoHandler + ssl = dict( + alpn_select=None, + ) + + def test_check_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + with c.connect(): + c.convert_to_ssl(alpn_protos=[b'h2']) + protocol = HTTP2StateProtocol(c) + with pytest.raises(NotImplementedError): + protocol.check_alpn() + + +class TestPerformServerConnectionPreface(net_tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + # send magic + self.wfile.write(codecs.decode('505249202a20485454502f322e300d0a0d0a534d0d0a0d0a', 'hex_codec')) + self.wfile.flush() + + # send empty settings frame + self.wfile.write(codecs.decode('000000040000000000', 'hex_codec')) + self.wfile.flush() + + # check empty settings frame + raw = http2.read_raw_frame(self.rfile) + assert raw == codecs.decode('00000c040000000000000200000000000300000001', 'hex_codec') + + # check settings acknowledgement + raw = http2.read_raw_frame(self.rfile) + assert raw == codecs.decode('000000040100000000', 'hex_codec') + + # send settings acknowledgement + self.wfile.write(codecs.decode('000000040100000000', 'hex_codec')) + self.wfile.flush() + + def test_perform_server_connection_preface(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + with c.connect(): + protocol = HTTP2StateProtocol(c) + + assert not protocol.connection_preface_performed + protocol.perform_server_connection_preface() + assert protocol.connection_preface_performed + + with pytest.raises(exceptions.TcpDisconnect): + protocol.perform_server_connection_preface(force=True) + + +class TestPerformClientConnectionPreface(net_tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + # check magic + assert self.rfile.read(24) == HTTP2StateProtocol.CLIENT_CONNECTION_PREFACE + + # check empty settings frame + assert self.rfile.read(9) ==\ + codecs.decode('000000040000000000', 'hex_codec') + + # send empty settings frame + self.wfile.write(codecs.decode('000000040000000000', 'hex_codec')) + self.wfile.flush() + + # check settings acknowledgement + assert self.rfile.read(9) == \ + codecs.decode('000000040100000000', 'hex_codec') + + # send settings acknowledgement + self.wfile.write(codecs.decode('000000040100000000', 'hex_codec')) + self.wfile.flush() + + def test_perform_client_connection_preface(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + with c.connect(): + protocol = HTTP2StateProtocol(c) + + assert not protocol.connection_preface_performed + protocol.perform_client_connection_preface() + assert protocol.connection_preface_performed + + +class TestClientStreamIds: + c = tcp.TCPClient(("127.0.0.1", 0)) + protocol = HTTP2StateProtocol(c) + + def test_client_stream_ids(self): + assert self.protocol.current_stream_id is None + assert self.protocol._next_stream_id() == 1 + assert self.protocol.current_stream_id == 1 + assert self.protocol._next_stream_id() == 3 + assert self.protocol.current_stream_id == 3 + assert self.protocol._next_stream_id() == 5 + assert self.protocol.current_stream_id == 5 + + +class TestserverstreamIds: + c = tcp.TCPClient(("127.0.0.1", 0)) + protocol = HTTP2StateProtocol(c, is_server=True) + + def test_server_stream_ids(self): + assert self.protocol.current_stream_id is None + assert self.protocol._next_stream_id() == 2 + assert self.protocol.current_stream_id == 2 + assert self.protocol._next_stream_id() == 4 + assert self.protocol.current_stream_id == 4 + assert self.protocol._next_stream_id() == 6 + assert self.protocol.current_stream_id == 6 + + +class TestApplySettings(net_tservers.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + # check settings acknowledgement + assert self.rfile.read(9) == codecs.decode('000000040100000000', 'hex_codec') + self.wfile.write("OK") + self.wfile.flush() + self.rfile.safe_read(9) # just to keep the connection alive a bit longer + + ssl = True + + def test_apply_settings(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + with c.connect(): + c.convert_to_ssl() + protocol = HTTP2StateProtocol(c) + + protocol._apply_settings({ + hyperframe.frame.SettingsFrame.ENABLE_PUSH: 'foo', + hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 'bar', + hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE: 'deadbeef', + }) + + assert c.rfile.safe_read(2) == b"OK" + + assert protocol.http2_settings[ + hyperframe.frame.SettingsFrame.ENABLE_PUSH] == 'foo' + assert protocol.http2_settings[ + hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS] == 'bar' + assert protocol.http2_settings[ + hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE] == 'deadbeef' + + +class TestCreateHeaders: + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_create_headers(self): + headers = http.Headers([ + (b':method', b'GET'), + (b':path', b'index.html'), + (b':scheme', b'https'), + (b'foo', b'bar')]) + + bytes = HTTP2StateProtocol(self.c)._create_headers( + headers, 1, end_stream=True) + assert b''.join(bytes) ==\ + codecs.decode('000014010500000001824488355217caf3a69a3f87408294e7838c767f', 'hex_codec') + + bytes = HTTP2StateProtocol(self.c)._create_headers( + headers, 1, end_stream=False) + assert b''.join(bytes) ==\ + codecs.decode('000014010400000001824488355217caf3a69a3f87408294e7838c767f', 'hex_codec') + + def test_create_headers_multiple_frames(self): + headers = http.Headers([ + (b':method', b'GET'), + (b':path', b'/'), + (b':scheme', b'https'), + (b'foo', b'bar'), + (b'server', b'version')]) + + protocol = HTTP2StateProtocol(self.c) + protocol.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE] = 8 + bytes = protocol._create_headers(headers, 1, end_stream=True) + assert len(bytes) == 3 + assert bytes[0] == codecs.decode('000008010100000001828487408294e783', 'hex_codec') + assert bytes[1] == codecs.decode('0000080900000000018c767f7685ee5b10', 'hex_codec') + assert bytes[2] == codecs.decode('00000209040000000163d5', 'hex_codec') + + +class TestCreateBody: + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_create_body_empty(self): + protocol = HTTP2StateProtocol(self.c) + bytes = protocol._create_body(b'', 1) + assert b''.join(bytes) == b'' + + def test_create_body_single_frame(self): + protocol = HTTP2StateProtocol(self.c) + bytes = protocol._create_body(b'foobar', 1) + assert b''.join(bytes) == codecs.decode('000006000100000001666f6f626172', 'hex_codec') + + def test_create_body_multiple_frames(self): + protocol = HTTP2StateProtocol(self.c) + protocol.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE] = 5 + bytes = protocol._create_body(b'foobarmehm42', 1) + assert len(bytes) == 3 + assert bytes[0] == codecs.decode('000005000000000001666f6f6261', 'hex_codec') + assert bytes[1] == codecs.decode('000005000000000001726d65686d', 'hex_codec') + assert bytes[2] == codecs.decode('0000020001000000013432', 'hex_codec') + + +class TestReadRequest(net_tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + self.wfile.write( + codecs.decode('000003010400000001828487', 'hex_codec')) + self.wfile.write( + codecs.decode('000006000100000001666f6f626172', 'hex_codec')) + self.wfile.flush() + self.rfile.safe_read(9) # just to keep the connection alive a bit longer + + ssl = True + + def test_read_request(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + with c.connect(): + c.convert_to_ssl() + protocol = HTTP2StateProtocol(c, is_server=True) + protocol.connection_preface_performed = True + + req = protocol.read_request(NotImplemented) + + assert req.stream_id + assert req.headers.fields == () + assert req.method == "GET" + assert req.path == "/" + assert req.scheme == "https" + assert req.content == b'foobar' + + +class TestReadRequestRelative(net_tservers.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + self.wfile.write( + codecs.decode('00000c0105000000014287d5af7e4d5a777f4481f9', 'hex_codec')) + self.wfile.flush() + + ssl = True + + def test_asterisk_form(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + with c.connect(): + c.convert_to_ssl() + protocol = HTTP2StateProtocol(c, is_server=True) + protocol.connection_preface_performed = True + + req = protocol.read_request(NotImplemented) + + assert req.first_line_format == "relative" + assert req.method == "OPTIONS" + assert req.path == "*" + + +class TestReadRequestAbsolute(net_tservers.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + self.wfile.write( + codecs.decode('00001901050000000182448d9d29aee30c0e492c2a1170426366871c92585422e085', 'hex_codec')) + self.wfile.flush() + + ssl = True + + def test_absolute_form(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + with c.connect(): + c.convert_to_ssl() + protocol = HTTP2StateProtocol(c, is_server=True) + protocol.connection_preface_performed = True + + req = protocol.read_request(NotImplemented) + + assert req.first_line_format == "absolute" + assert req.scheme == "http" + assert req.host == "address" + assert req.port == 22 + + +class TestReadResponse(net_tservers.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + self.wfile.write( + codecs.decode('00000801040000002a88628594e78c767f', 'hex_codec')) + self.wfile.write( + codecs.decode('00000600010000002a666f6f626172', 'hex_codec')) + self.wfile.flush() + self.rfile.safe_read(9) # just to keep the connection alive a bit longer + + ssl = True + + def test_read_response(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + with c.connect(): + c.convert_to_ssl() + protocol = HTTP2StateProtocol(c) + protocol.connection_preface_performed = True + + resp = protocol.read_response(NotImplemented, stream_id=42) + + assert resp.http_version == "HTTP/2.0" + assert resp.status_code == 200 + assert resp.reason == '' + assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar')) + assert resp.content == b'foobar' + assert resp.timestamp_end + + +class TestReadEmptyResponse(net_tservers.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + self.wfile.write( + codecs.decode('00000801050000002a88628594e78c767f', 'hex_codec')) + self.wfile.flush() + + ssl = True + + def test_read_empty_response(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + with c.connect(): + c.convert_to_ssl() + protocol = HTTP2StateProtocol(c) + protocol.connection_preface_performed = True + + resp = protocol.read_response(NotImplemented, stream_id=42) + + assert resp.stream_id == 42 + assert resp.http_version == "HTTP/2.0" + assert resp.status_code == 200 + assert resp.reason == '' + assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar')) + assert resp.content == b'' + + +class TestAssembleRequest: + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_request_simple(self): + bytes = HTTP2StateProtocol(self.c).assemble_request(http.Request( + b'', + b'GET', + b'https', + b'', + b'', + b'/', + b"HTTP/2.0", + (), + None, + )) + assert len(bytes) == 1 + assert bytes[0] == codecs.decode('00000d0105000000018284874188089d5c0b8170dc07', 'hex_codec') + + def test_request_with_stream_id(self): + req = http.Request( + b'', + b'GET', + b'https', + b'', + b'', + b'/', + b"HTTP/2.0", + (), + None, + ) + req.stream_id = 0x42 + bytes = HTTP2StateProtocol(self.c).assemble_request(req) + assert len(bytes) == 1 + assert bytes[0] == codecs.decode('00000d0105000000428284874188089d5c0b8170dc07', 'hex_codec') + + def test_request_with_body(self): + bytes = HTTP2StateProtocol(self.c).assemble_request(http.Request( + b'', + b'GET', + b'https', + b'', + b'', + b'/', + b"HTTP/2.0", + http.Headers([(b'foo', b'bar')]), + b'foobar', + )) + assert len(bytes) == 2 + assert bytes[0] ==\ + codecs.decode('0000150104000000018284874188089d5c0b8170dc07408294e7838c767f', 'hex_codec') + assert bytes[1] ==\ + codecs.decode('000006000100000001666f6f626172', 'hex_codec') + + +class TestAssembleResponse: + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_simple(self): + bytes = HTTP2StateProtocol(self.c, is_server=True).assemble_response(http.Response( + b"HTTP/2.0", + 200, + )) + assert len(bytes) == 1 + assert bytes[0] ==\ + codecs.decode('00000101050000000288', 'hex_codec') + + def test_with_stream_id(self): + resp = http.Response( + b"HTTP/2.0", + 200, + ) + resp.stream_id = 0x42 + bytes = HTTP2StateProtocol(self.c, is_server=True).assemble_response(resp) + assert len(bytes) == 1 + assert bytes[0] ==\ + codecs.decode('00000101050000004288', 'hex_codec') + + def test_with_body(self): + bytes = HTTP2StateProtocol(self.c, is_server=True).assemble_response(http.Response( + b"HTTP/2.0", + 200, + b'', + http.Headers(foo=b"bar"), + b'foobar' + )) + assert len(bytes) == 2 + assert bytes[0] ==\ + codecs.decode('00000901040000000288408294e7838c767f', 'hex_codec') + assert bytes[1] ==\ + codecs.decode('000006000100000002666f6f626172', 'hex_codec') diff --git a/test/pathod/protocols/test_websockets.py b/test/pathod/protocols/test_websockets.py new file mode 100644 index 00000000..777ab4dd --- /dev/null +++ b/test/pathod/protocols/test_websockets.py @@ -0,0 +1 @@ +# TODO: write tests diff --git a/test/pathod/test_language_actions.py b/test/pathod/test_language_actions.py deleted file mode 100644 index 9740e5c7..00000000 --- a/test/pathod/test_language_actions.py +++ /dev/null @@ -1,134 +0,0 @@ -import io - -from pathod.language import actions, parse_pathoc, parse_pathod, serve - - -def parse_request(s): - return next(parse_pathoc(s)) - - -def test_unique_name(): - assert not actions.PauseAt(0, "f").unique_name - assert actions.DisconnectAt(0).unique_name - - -class TestDisconnects: - - def test_parse_pathod(self): - a = next(parse_pathod("400:d0")).actions[0] - assert a.spec() == "d0" - a = next(parse_pathod("400:dr")).actions[0] - assert a.spec() == "dr" - - def test_at(self): - e = actions.DisconnectAt.expr() - v = e.parseString("d0")[0] - assert isinstance(v, actions.DisconnectAt) - assert v.offset == 0 - - v = e.parseString("d100")[0] - assert v.offset == 100 - - e = actions.DisconnectAt.expr() - v = e.parseString("dr")[0] - assert v.offset == "r" - - def test_spec(self): - assert actions.DisconnectAt("r").spec() == "dr" - assert actions.DisconnectAt(10).spec() == "d10" - - -class TestInject: - - def test_parse_pathod(self): - a = next(parse_pathod("400:ir,@100")).actions[0] - assert a.offset == "r" - assert a.value.datatype == "bytes" - assert a.value.usize == 100 - - a = next(parse_pathod("400:ia,@100")).actions[0] - assert a.offset == "a" - - def test_at(self): - e = actions.InjectAt.expr() - v = e.parseString("i0,'foo'")[0] - assert v.value.val == b"foo" - assert v.offset == 0 - assert isinstance(v, actions.InjectAt) - - v = e.parseString("ir,'foo'")[0] - assert v.offset == "r" - - def test_serve(self): - s = io.BytesIO() - r = next(parse_pathod("400:i0,'foo'")) - assert serve(r, s, {}) - - def test_spec(self): - e = actions.InjectAt.expr() - v = e.parseString("i0,'foo'")[0] - assert v.spec() == "i0,'foo'" - - def test_spec2(self): - e = actions.InjectAt.expr() - v = e.parseString("i0,@100")[0] - v2 = v.freeze({}) - v3 = v2.freeze({}) - assert v2.value.val == v3.value.val - - -class TestPauses: - - def test_parse_pathod(self): - e = actions.PauseAt.expr() - v = e.parseString("p10,10")[0] - assert v.seconds == 10 - assert v.offset == 10 - - v = e.parseString("p10,f")[0] - assert v.seconds == "f" - - v = e.parseString("pr,f")[0] - assert v.offset == "r" - - v = e.parseString("pa,f")[0] - assert v.offset == "a" - - def test_request(self): - r = next(parse_pathod('400:p10,10')) - assert r.actions[0].spec() == "p10,10" - - def test_spec(self): - assert actions.PauseAt("r", 5).spec() == "pr,5" - assert actions.PauseAt(0, 5).spec() == "p0,5" - assert actions.PauseAt(0, "f").spec() == "p0,f" - - def test_freeze(self): - l = actions.PauseAt("r", 5) - assert l.freeze({}).spec() == l.spec() - - -class Test_Action: - - def test_cmp(self): - a = actions.DisconnectAt(0) - b = actions.DisconnectAt(1) - c = actions.DisconnectAt(0) - assert a < b - assert a == c - l = sorted([b, a]) - assert l[0].offset == 0 - - def test_resolve(self): - r = parse_request('GET:"/foo"') - e = actions.DisconnectAt("r") - ret = e.resolve({}, r) - assert isinstance(ret.offset, int) - - def test_repr(self): - e = actions.DisconnectAt("r") - assert repr(e) - - def test_freeze(self): - l = actions.DisconnectAt(5) - assert l.freeze({}).spec() == l.spec() diff --git a/test/pathod/test_language_base.py b/test/pathod/test_language_base.py deleted file mode 100644 index 85e9e53b..00000000 --- a/test/pathod/test_language_base.py +++ /dev/null @@ -1,354 +0,0 @@ -import os -import pytest - -from pathod import language -from pathod.language import base, exceptions - -from mitmproxy.test import tutils - - -def parse_request(s): - return language.parse_pathoc(s).next() - - -def test_times(): - reqs = list(language.parse_pathoc("get:/:x5")) - assert len(reqs) == 5 - assert not reqs[0].times - - -def test_caseless_literal(): - class CL(base.CaselessLiteral): - TOK = "foo" - v = CL("foo") - assert v.expr() - assert v.values(language.Settings()) - - -class TestTokValueNakedLiteral: - - def test_expr(self): - v = base.TokValueNakedLiteral("foo") - assert v.expr() - - def test_spec(self): - v = base.TokValueNakedLiteral("foo") - assert v.spec() == repr(v) == "foo" - - v = base.TokValueNakedLiteral("f\x00oo") - assert v.spec() == repr(v) == r"f\x00oo" - - -class TestTokValueLiteral: - - def test_expr(self): - v = base.TokValueLiteral("foo") - assert v.expr() - assert v.val == b"foo" - - v = base.TokValueLiteral("foo\n") - assert v.expr() - assert v.val == b"foo\n" - assert repr(v) - - def test_spec(self): - v = base.TokValueLiteral("foo") - assert v.spec() == r"'foo'" - - v = base.TokValueLiteral("f\x00oo") - assert v.spec() == repr(v) == r"'f\x00oo'" - - v = base.TokValueLiteral('"') - assert v.spec() == repr(v) == """ '"' """.strip() - - # While pyparsing has a escChar argument for QuotedString, - # escChar only performs scapes single-character escapes and does not work for e.g. r"\x02". - # Thus, we cannot use that option, which means we cannot have single quotes in strings. - # To fix this, we represent single quotes as r"\x07". - v = base.TokValueLiteral("'") - assert v.spec() == r"'\x27'" - - def roundtrip(self, spec): - e = base.TokValueLiteral.expr() - v = base.TokValueLiteral(spec) - v2 = e.parseString(v.spec()) - assert v.val == v2[0].val - assert v.spec() == v2[0].spec() - - def test_roundtrip(self): - self.roundtrip("'") - self.roundtrip(r"\'") - self.roundtrip("a") - self.roundtrip("\"") - # self.roundtrip("\\") - self.roundtrip("200:b'foo':i23,'\\''") - self.roundtrip("\a") - - -class TestTokValueGenerate: - - def test_basic(self): - v = base.TokValue.parseString("@10b")[0] - assert v.usize == 10 - assert v.unit == "b" - assert v.bytes() == 10 - v = base.TokValue.parseString("@10")[0] - assert v.unit == "b" - v = base.TokValue.parseString("@10k")[0] - assert v.bytes() == 10240 - v = base.TokValue.parseString("@10g")[0] - assert v.bytes() == 1024 ** 3 * 10 - - v = base.TokValue.parseString("@10g,digits")[0] - assert v.datatype == "digits" - g = v.get_generator({}) - assert g[:100] - - v = base.TokValue.parseString("@10,digits")[0] - assert v.unit == "b" - assert v.datatype == "digits" - - def test_spec(self): - v = base.TokValueGenerate(1, "b", "bytes") - assert v.spec() == repr(v) == "@1" - - v = base.TokValueGenerate(1, "k", "bytes") - assert v.spec() == repr(v) == "@1k" - - v = base.TokValueGenerate(1, "k", "ascii") - assert v.spec() == repr(v) == "@1k,ascii" - - v = base.TokValueGenerate(1, "b", "ascii") - assert v.spec() == repr(v) == "@1,ascii" - - def test_freeze(self): - v = base.TokValueGenerate(100, "b", "ascii") - f = v.freeze(language.Settings()) - assert len(f.val) == 100 - - -class TestTokValueFile: - - def test_file_value(self): - v = base.TokValue.parseString("<'one two'")[0] - assert str(v) - assert v.path == "one two" - - v = base.TokValue.parseString(" 100 - - def test_path_generator(self): - r = parse_request("GET:@100").freeze(language.Settings()) - assert len(r.spec()) > 100 - - def test_websocket(self): - r = parse_request('ws:/path/') - res = r.resolve(language.Settings()) - assert res.method.string().lower() == b"get" - assert res.tok(http.Path).value.val == b"/path/" - assert res.tok(http.Method).value.val.lower() == b"get" - assert http.get_header(b"Upgrade", res.headers).value.val == b"websocket" - - r = parse_request('ws:put:/path/') - res = r.resolve(language.Settings()) - assert r.method.string().lower() == b"put" - assert res.tok(http.Path).value.val == b"/path/" - assert res.tok(http.Method).value.val.lower() == b"put" - assert http.get_header(b"Upgrade", res.headers).value.val == b"websocket" - - -class TestResponse: - - def dummy_response(self): - return next(language.parse_pathod("400'msg'")) - - def test_response(self): - r = next(language.parse_pathod("400:m'msg'")) - assert r.status_code.string() == b"400" - assert r.reason.string() == b"msg" - - r = next(language.parse_pathod("400:m'msg':b@100b")) - assert r.reason.string() == b"msg" - assert r.body.values({}) - assert str(r) - - r = next(language.parse_pathod("200")) - assert r.status_code.string() == b"200" - assert not r.reason - assert b"OK" in [i[:] for i in r.preamble({})] - - def test_render(self): - s = io.BytesIO() - r = next(language.parse_pathod("400:m'msg'")) - assert language.serve(r, s, {}) - - r = next(language.parse_pathod("400:p0,100:dr")) - assert "p0" in r.spec() - s = r.preview_safe() - assert "p0" not in s.spec() - - def test_raw(self): - s = io.BytesIO() - r = next(language.parse_pathod("400:b'foo'")) - language.serve(r, s, {}) - v = s.getvalue() - assert b"Content-Length" in v - - s = io.BytesIO() - r = next(language.parse_pathod("400:b'foo':r")) - language.serve(r, s, {}) - v = s.getvalue() - assert b"Content-Length" not in v - - def test_length(self): - def testlen(x): - s = io.BytesIO() - x = next(x) - language.serve(x, s, language.Settings()) - assert x.length(language.Settings()) == len(s.getvalue()) - testlen(language.parse_pathod("400:m'msg':r")) - testlen(language.parse_pathod("400:m'msg':h'foo'='bar':r")) - testlen(language.parse_pathod("400:m'msg':h'foo'='bar':b@100b:r")) - - def test_maximum_length(self): - def testlen(x): - x = next(x) - s = io.BytesIO() - m = x.maximum_length({}) - language.serve(x, s, {}) - assert m >= len(s.getvalue()) - - r = language.parse_pathod("400:m'msg':b@100:d0") - testlen(r) - - r = language.parse_pathod("400:m'msg':b@100:d0:i0,'foo'") - testlen(r) - - r = language.parse_pathod("400:m'msg':b@100:d0:i0,'foo'") - testlen(r) - - def test_parse_err(self): - with pytest.raises(language.ParseException): - language.parse_pathod("400:msg,b:") - try: - language.parse_pathod("400'msg':b:") - except language.ParseException as v: - assert v.marked() - assert str(v) - - def test_nonascii(self): - with pytest.raises(Exception, match="ASCII"): - language.parse_pathod("foo:b\xf0") - - def test_parse_header(self): - r = next(language.parse_pathod('400:h"foo"="bar"')) - assert http.get_header(b"foo", r.headers) - - def test_parse_pause_before(self): - r = next(language.parse_pathod("400:p0,10")) - assert r.actions[0].spec() == "p0,10" - - def test_parse_pause_after(self): - r = next(language.parse_pathod("400:pa,10")) - assert r.actions[0].spec() == "pa,10" - - def test_parse_pause_random(self): - r = next(language.parse_pathod("400:pr,10")) - assert r.actions[0].spec() == "pr,10" - - def test_parse_stress(self): - # While larger values are known to work on linux, len() technically - # returns an int and a python 2.7 int on windows has 32bit precision. - # Therefore, we should keep the body length < 2147483647 bytes in our - # tests. - r = next(language.parse_pathod("400:b@1g")) - assert r.length({}) - - def test_spec(self): - def rt(s): - s = next(language.parse_pathod(s)).spec() - assert next(language.parse_pathod(s)).spec() == s - rt("400:b@100g") - rt("400") - rt("400:da") - - def test_websockets(self): - r = next(language.parse_pathod("ws")) - with pytest.raises(Exception, match="No websocket key"): - r.resolve(language.Settings()) - res = r.resolve(language.Settings(websocket_key=b"foo")) - assert res.status_code.string() == b"101" - - -def test_ctype_shortcut(): - e = http.ShortcutContentType.expr() - v = e.parseString("c'foo'")[0] - assert v.key.val == b"Content-Type" - assert v.value.val == b"foo" - - s = v.spec() - assert s == e.parseString(s)[0].spec() - - e = http.ShortcutContentType.expr() - v = e.parseString("c@100")[0] - v2 = v.freeze({}) - v3 = v2.freeze({}) - assert v2.value.val == v3.value.val - - -def test_location_shortcut(): - e = http.ShortcutLocation.expr() - v = e.parseString("l'foo'")[0] - assert v.key.val == b"Location" - assert v.value.val == b"foo" - - s = v.spec() - assert s == e.parseString(s)[0].spec() - - e = http.ShortcutLocation.expr() - v = e.parseString("l@100")[0] - v2 = v.freeze({}) - v3 = v2.freeze({}) - assert v2.value.val == v3.value.val - - -def test_shortcuts(): - assert next(language.parse_pathod( - "400:c'foo'")).headers[0].key.val == b"Content-Type" - assert next(language.parse_pathod( - "400:l'foo'")).headers[0].key.val == b"Location" - - assert b"Android" in tservers.render(parse_request("get:/:ua")) - assert b"User-Agent" in tservers.render(parse_request("get:/:ua")) - - -def test_user_agent(): - e = http.ShortcutUserAgent.expr() - v = e.parseString("ua")[0] - assert b"Android" in v.string() - - e = http.ShortcutUserAgent.expr() - v = e.parseString("u'a'")[0] - assert b"Android" not in v.string() - - v = e.parseString("u@100'")[0] - assert len(str(v.freeze({}).value)) > 100 - v2 = v.freeze({}) - v3 = v2.freeze({}) - assert v2.value.val == v3.value.val - - -def test_nested_response(): - e = http.NestedResponse.expr() - v = e.parseString("s'200'")[0] - assert v.value.val == b"200" - with pytest.raises(language.ParseException): - e.parseString("s'foo'") - - v = e.parseString('s"200:b@1"')[0] - assert "@1" in v.spec() - f = v.freeze({}) - assert "@1" not in f.spec() - - -def test_nested_response_freeze(): - e = http.NestedResponse( - base.TokValueLiteral( - r"200:b\'foo\':i10,\'\\x27\'" - ) - ) - assert e.freeze({}) - assert e.values({}) - - -def test_unique_components(): - with pytest.raises(Exception, match="multiple body clauses"): - language.parse_pathod("400:b@1:b@1") diff --git a/test/pathod/test_language_http2.py b/test/pathod/test_language_http2.py deleted file mode 100644 index 4f89adb8..00000000 --- a/test/pathod/test_language_http2.py +++ /dev/null @@ -1,236 +0,0 @@ -import io -import pytest - -from mitmproxy.net import tcp -from mitmproxy.net.http import user_agents - -from pathod import language -from pathod.language import http2 -from pathod.protocols.http2 import HTTP2StateProtocol - - -def parse_request(s): - return next(language.parse_pathoc(s, True)) - - -def parse_response(s): - return next(language.parse_pathod(s, True)) - - -def default_settings(): - return language.Settings( - request_host="foo.com", - protocol=HTTP2StateProtocol(tcp.TCPClient(('localhost', 1234))) - ) - - -def test_make_error_response(): - d = io.BytesIO() - s = http2.make_error_response("foo", "bar") - language.serve(s, d, default_settings()) - - -class TestRequest: - - def test_cached_values(self): - req = parse_request("get:/") - req_id = id(req) - assert req_id == id(req.resolve(default_settings())) - assert req.values(default_settings()) == req.values(default_settings()) - - def test_nonascii(self): - with pytest.raises(Exception, match="ASCII"): - parse_request("get:\xf0") - - def test_err(self): - with pytest.raises(language.ParseException): - parse_request('GET') - - def test_simple(self): - r = parse_request('GET:"/foo"') - assert r.method.string() == b"GET" - assert r.path.string() == b"/foo" - r = parse_request('GET:/foo') - assert r.path.string() == b"/foo" - - def test_multiple(self): - r = list(language.parse_pathoc("GET:/ PUT:/")) - assert r[0].method.string() == b"GET" - assert r[1].method.string() == b"PUT" - assert len(r) == 2 - - l = """ - GET - "/foo" - - PUT - - "/foo - - - - bar" - """ - r = list(language.parse_pathoc(l, True)) - assert len(r) == 2 - assert r[0].method.string() == b"GET" - assert r[1].method.string() == b"PUT" - - l = """ - get:"http://localhost:9999/p/200" - get:"http://localhost:9999/p/200" - """ - r = list(language.parse_pathoc(l, True)) - assert len(r) == 2 - assert r[0].method.string() == b"GET" - assert r[1].method.string() == b"GET" - - def test_render_simple(self): - s = io.BytesIO() - r = parse_request("GET:'/foo'") - assert language.serve( - r, - s, - default_settings(), - ) - - def test_raw_content_length(self): - r = parse_request('GET:/:r') - assert len(r.headers) == 0 - - r = parse_request('GET:/:r:b"foobar"') - assert len(r.headers) == 0 - - r = parse_request('GET:/') - assert len(r.headers) == 1 - assert r.headers[0].values(default_settings()) == (b"content-length", b"0") - - r = parse_request('GET:/:b"foobar"') - assert len(r.headers) == 1 - assert r.headers[0].values(default_settings()) == (b"content-length", b"6") - - r = parse_request('GET:/:b"foobar":h"content-length"="42"') - assert len(r.headers) == 1 - assert r.headers[0].values(default_settings()) == (b"content-length", b"42") - - r = parse_request('GET:/:r:b"foobar":h"content-length"="42"') - assert len(r.headers) == 1 - assert r.headers[0].values(default_settings()) == (b"content-length", b"42") - - def test_content_type(self): - r = parse_request('GET:/:r:c"foobar"') - assert len(r.headers) == 1 - assert r.headers[0].values(default_settings()) == (b"content-type", b"foobar") - - def test_user_agent(self): - r = parse_request('GET:/:r:ua') - assert len(r.headers) == 1 - assert r.headers[0].values(default_settings()) == (b"user-agent", user_agents.get_by_shortcut('a')[2].encode()) - - def test_render_with_headers(self): - s = io.BytesIO() - r = parse_request('GET:/foo:h"foo"="bar"') - assert language.serve( - r, - s, - default_settings(), - ) - - def test_nested_response(self): - l = "get:/p/:s'200'" - r = parse_request(l) - assert len(r.tokens) == 3 - assert isinstance(r.tokens[2], http2.NestedResponse) - assert r.values(default_settings()) - - def test_render_with_body(self): - s = io.BytesIO() - r = parse_request("GET:'/foo':bfoobar") - assert language.serve( - r, - s, - default_settings(), - ) - - def test_spec(self): - def rt(s): - s = parse_request(s).spec() - assert parse_request(s).spec() == s - rt("get:/foo") - - -class TestResponse: - - def test_cached_values(self): - res = parse_response("200") - res_id = id(res) - assert res_id == id(res.resolve(default_settings())) - assert res.values(default_settings()) == res.values(default_settings()) - - def test_nonascii(self): - with pytest.raises(Exception, match="ASCII"): - parse_response("200:\xf0") - - def test_err(self): - with pytest.raises(language.ParseException): - parse_response('GET:/') - - def test_raw_content_length(self): - r = parse_response('200:r') - assert len(r.headers) == 0 - - r = parse_response('200') - assert len(r.headers) == 1 - assert r.headers[0].values(default_settings()) == (b"content-length", b"0") - - def test_content_type(self): - r = parse_response('200:r:c"foobar"') - assert len(r.headers) == 1 - assert r.headers[0].values(default_settings()) == (b"content-type", b"foobar") - - def test_simple(self): - r = parse_response('200:r:h"foo"="bar"') - assert r.status_code.string() == b"200" - assert len(r.headers) == 1 - assert r.headers[0].values(default_settings()) == (b"foo", b"bar") - assert r.body is None - - r = parse_response('200:r:h"foo"="bar":bfoobar:h"bla"="fasel"') - assert r.status_code.string() == b"200" - assert len(r.headers) == 2 - assert r.headers[0].values(default_settings()) == (b"foo", b"bar") - assert r.headers[1].values(default_settings()) == (b"bla", b"fasel") - assert r.body.string() == b"foobar" - - def test_render_simple(self): - s = io.BytesIO() - r = parse_response('200') - assert language.serve( - r, - s, - default_settings(), - ) - - def test_render_with_headers(self): - s = io.BytesIO() - r = parse_response('200:h"foo"="bar"') - assert language.serve( - r, - s, - default_settings(), - ) - - def test_render_with_body(self): - s = io.BytesIO() - r = parse_response('200:bfoobar') - assert language.serve( - r, - s, - default_settings(), - ) - - def test_spec(self): - def rt(s): - s = parse_response(s).spec() - assert parse_response(s).spec() == s - rt("200:bfoobar") diff --git a/test/pathod/test_language_websocket.py b/test/pathod/test_language_websocket.py deleted file mode 100644 index e5046591..00000000 --- a/test/pathod/test_language_websocket.py +++ /dev/null @@ -1,142 +0,0 @@ -import pytest - -from pathod import language -from pathod.language import websockets -import mitmproxy.net.websockets - -from . import tservers - - -def parse_request(s): - return next(language.parse_pathoc(s)) - - -class TestWebsocketFrame: - - def _test_messages(self, specs, message_klass): - for i in specs: - wf = parse_request(i) - assert isinstance(wf, message_klass) - assert wf - assert wf.values(language.Settings()) - assert wf.resolve(language.Settings()) - - spec = wf.spec() - wf2 = parse_request(spec) - assert wf2.spec() == spec - - def test_server_values(self): - specs = [ - "wf", - "wf:dr", - "wf:b'foo'", - "wf:mask:r'foo'", - "wf:l1024:b'foo'", - "wf:cbinary", - "wf:c1", - "wf:mask:knone", - "wf:fin", - "wf:fin:rsv1:rsv2:rsv3:mask", - "wf:-fin:-rsv1:-rsv2:-rsv3:-mask", - "wf:k@4", - "wf:x10", - ] - self._test_messages(specs, websockets.WebsocketFrame) - - def test_parse_websocket_frames(self): - wf = language.parse_websocket_frame("wf:x10") - assert len(list(wf)) == 10 - with pytest.raises(language.ParseException): - language.parse_websocket_frame("wf:x") - - def test_client_values(self): - specs = [ - "wf:f'wf'", - ] - self._test_messages(specs, websockets.WebsocketClientFrame) - - def test_nested_frame(self): - wf = parse_request("wf:f'wf'") - assert wf.nested_frame - - def test_flags(self): - wf = parse_request("wf:fin:mask:rsv1:rsv2:rsv3") - frm = mitmproxy.net.websockets.Frame.from_bytes(tservers.render(wf)) - assert frm.header.fin - assert frm.header.mask - assert frm.header.rsv1 - assert frm.header.rsv2 - assert frm.header.rsv3 - - wf = parse_request("wf:-fin:-mask:-rsv1:-rsv2:-rsv3") - frm = mitmproxy.net.websockets.Frame.from_bytes(tservers.render(wf)) - assert not frm.header.fin - assert not frm.header.mask - assert not frm.header.rsv1 - assert not frm.header.rsv2 - assert not frm.header.rsv3 - - def fr(self, spec, **kwargs): - settings = language.base.Settings(**kwargs) - wf = parse_request(spec) - return mitmproxy.net.websockets.Frame.from_bytes(tservers.render(wf, settings)) - - def test_construction(self): - assert self.fr("wf:c1").header.opcode == 1 - assert self.fr("wf:c0").header.opcode == 0 - assert self.fr("wf:cbinary").header.opcode ==\ - mitmproxy.net.websockets.OPCODE.BINARY - assert self.fr("wf:ctext").header.opcode ==\ - mitmproxy.net.websockets.OPCODE.TEXT - - def test_rawbody(self): - frm = self.fr("wf:mask:r'foo'") - assert len(frm.payload) == 3 - assert frm.payload != b"foo" - - assert self.fr("wf:r'foo'").payload == b"foo" - - def test_construction_2(self): - # Simple server frame - frm = self.fr("wf:b'foo'") - assert not frm.header.mask - assert not frm.header.masking_key - - # Simple client frame - frm = self.fr("wf:b'foo'", is_client=True) - assert frm.header.mask - assert frm.header.masking_key - frm = self.fr("wf:b'foo':k'abcd'", is_client=True) - assert frm.header.mask - assert frm.header.masking_key == b'abcd' - - # Server frame, mask explicitly set - frm = self.fr("wf:b'foo':mask") - assert frm.header.mask - assert frm.header.masking_key - frm = self.fr("wf:b'foo':k'abcd'") - assert frm.header.mask - assert frm.header.masking_key == b'abcd' - - # Client frame, mask explicitly unset - frm = self.fr("wf:b'foo':-mask", is_client=True) - assert not frm.header.mask - assert not frm.header.masking_key - - frm = self.fr("wf:b'foo':-mask:k'abcd'", is_client=True) - assert not frm.header.mask - # We're reading back a corrupted frame - the first 3 characters of the - # mask is mis-interpreted as the payload - assert frm.payload == b"abc" - - def test_knone(self): - with pytest.raises(Exception, match="Expected 4 bytes"): - self.fr("wf:b'foo':mask:knone") - - def test_length(self): - assert self.fr("wf:l3:b'foo'").header.payload_length == 3 - frm = self.fr("wf:l2:b'foo'") - assert frm.header.payload_length == 2 - assert frm.payload == b"fo" - with pytest.raises(Exception, match="Expected 1024 bytes"): - self.fr("wf:l1024:b'foo'") diff --git a/test/pathod/test_language_writer.py b/test/pathod/test_language_writer.py deleted file mode 100644 index 7feb985d..00000000 --- a/test/pathod/test_language_writer.py +++ /dev/null @@ -1,90 +0,0 @@ -import io -from pathod import language -from pathod.language import writer - - -def test_send_chunk(): - v = b"foobarfoobar" - for bs in range(1, len(v) + 2): - s = io.BytesIO() - writer.send_chunk(s, v, bs, 0, len(v)) - assert s.getvalue() == v - for start in range(len(v)): - for end in range(len(v)): - s = io.BytesIO() - writer.send_chunk(s, v, bs, start, end) - assert s.getvalue() == v[start:end] - - -def test_write_values_inject(): - tst = b"foo" - - s = io.BytesIO() - writer.write_values(s, [tst], [(0, "inject", b"aaa")], blocksize=5) - assert s.getvalue() == b"aaafoo" - - s = io.BytesIO() - writer.write_values(s, [tst], [(1, "inject", b"aaa")], blocksize=5) - assert s.getvalue() == b"faaaoo" - - s = io.BytesIO() - writer.write_values(s, [tst], [(1, "inject", b"aaa")], blocksize=5) - assert s.getvalue() == b"faaaoo" - - -def test_write_values_disconnects(): - s = io.BytesIO() - tst = b"foo" * 100 - writer.write_values(s, [tst], [(0, "disconnect")], blocksize=5) - assert not s.getvalue() - - -def test_write_values(): - tst = b"foobarvoing" - s = io.BytesIO() - writer.write_values(s, [tst], []) - assert s.getvalue() == tst - - for bs in range(1, len(tst) + 2): - for off in range(len(tst)): - s = io.BytesIO() - writer.write_values( - s, [tst], [(off, "disconnect")], blocksize=bs - ) - assert s.getvalue() == tst[:off] - - -def test_write_values_pauses(): - tst = "".join(str(i) for i in range(10)).encode() - for i in range(2, 10): - s = io.BytesIO() - writer.write_values( - s, [tst], [(2, "pause", 0), (1, "pause", 0)], blocksize=i - ) - assert s.getvalue() == tst - - for i in range(2, 10): - s = io.BytesIO() - writer.write_values(s, [tst], [(1, "pause", 0)], blocksize=i) - assert s.getvalue() == tst - - tst = [tst] * 5 - for i in range(2, 10): - s = io.BytesIO() - writer.write_values(s, tst[:], [(1, "pause", 0)], blocksize=i) - assert s.getvalue() == b"".join(tst) - - -def test_write_values_after(): - s = io.BytesIO() - r = next(language.parse_pathod("400:da")) - language.serve(r, s, {}) - - s = io.BytesIO() - r = next(language.parse_pathod("400:pa,0")) - language.serve(r, s, {}) - - s = io.BytesIO() - r = next(language.parse_pathod("400:ia,'xx'")) - language.serve(r, s, {}) - assert s.getvalue().endswith(b'xx') diff --git a/test/pathod/test_protocols_http2.py b/test/pathod/test_protocols_http2.py deleted file mode 100644 index 5bb31031..00000000 --- a/test/pathod/test_protocols_http2.py +++ /dev/null @@ -1,514 +0,0 @@ -from unittest import mock -import codecs -import pytest -import hyperframe - -from mitmproxy.net import tcp, http -from mitmproxy.net.http import http2 -from mitmproxy import exceptions - -from ..mitmproxy.net import tservers as net_tservers - -from pathod.protocols.http2 import HTTP2StateProtocol, TCPHandler - -from ..conftest import requires_alpn - - -class TestTCPHandlerWrapper: - def test_wrapped(self): - h = TCPHandler(rfile='foo', wfile='bar') - p = HTTP2StateProtocol(h) - assert p.tcp_handler.rfile == 'foo' - assert p.tcp_handler.wfile == 'bar' - - def test_direct(self): - p = HTTP2StateProtocol(rfile='foo', wfile='bar') - assert isinstance(p.tcp_handler, TCPHandler) - assert p.tcp_handler.rfile == 'foo' - assert p.tcp_handler.wfile == 'bar' - - -class EchoHandler(tcp.BaseHandler): - sni = None - - def handle(self): - while True: - v = self.rfile.safe_read(1) - self.wfile.write(v) - self.wfile.flush() - - -class TestProtocol: - @mock.patch("pathod.protocols.http2.HTTP2StateProtocol.perform_server_connection_preface") - @mock.patch("pathod.protocols.http2.HTTP2StateProtocol.perform_client_connection_preface") - def test_perform_connection_preface(self, mock_client_method, mock_server_method): - protocol = HTTP2StateProtocol(is_server=False) - protocol.connection_preface_performed = True - - protocol.perform_connection_preface() - assert not mock_client_method.called - assert not mock_server_method.called - - protocol.perform_connection_preface(force=True) - assert mock_client_method.called - assert not mock_server_method.called - - @mock.patch("pathod.protocols.http2.HTTP2StateProtocol.perform_server_connection_preface") - @mock.patch("pathod.protocols.http2.HTTP2StateProtocol.perform_client_connection_preface") - def test_perform_connection_preface_server(self, mock_client_method, mock_server_method): - protocol = HTTP2StateProtocol(is_server=True) - protocol.connection_preface_performed = True - - protocol.perform_connection_preface() - assert not mock_client_method.called - assert not mock_server_method.called - - protocol.perform_connection_preface(force=True) - assert not mock_client_method.called - assert mock_server_method.called - - -@requires_alpn -class TestCheckALPNMatch(net_tservers.ServerTestBase): - handler = EchoHandler - ssl = dict( - alpn_select=b'h2', - ) - - def test_check_alpn(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_ssl(alpn_protos=[b'h2']) - protocol = HTTP2StateProtocol(c) - assert protocol.check_alpn() - - -@requires_alpn -class TestCheckALPNMismatch(net_tservers.ServerTestBase): - handler = EchoHandler - ssl = dict( - alpn_select=None, - ) - - def test_check_alpn(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_ssl(alpn_protos=[b'h2']) - protocol = HTTP2StateProtocol(c) - with pytest.raises(NotImplementedError): - protocol.check_alpn() - - -class TestPerformServerConnectionPreface(net_tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - # send magic - self.wfile.write(codecs.decode('505249202a20485454502f322e300d0a0d0a534d0d0a0d0a', 'hex_codec')) - self.wfile.flush() - - # send empty settings frame - self.wfile.write(codecs.decode('000000040000000000', 'hex_codec')) - self.wfile.flush() - - # check empty settings frame - raw = http2.read_raw_frame(self.rfile) - assert raw == codecs.decode('00000c040000000000000200000000000300000001', 'hex_codec') - - # check settings acknowledgement - raw = http2.read_raw_frame(self.rfile) - assert raw == codecs.decode('000000040100000000', 'hex_codec') - - # send settings acknowledgement - self.wfile.write(codecs.decode('000000040100000000', 'hex_codec')) - self.wfile.flush() - - def test_perform_server_connection_preface(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - protocol = HTTP2StateProtocol(c) - - assert not protocol.connection_preface_performed - protocol.perform_server_connection_preface() - assert protocol.connection_preface_performed - - with pytest.raises(exceptions.TcpDisconnect): - protocol.perform_server_connection_preface(force=True) - - -class TestPerformClientConnectionPreface(net_tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - # check magic - assert self.rfile.read(24) == HTTP2StateProtocol.CLIENT_CONNECTION_PREFACE - - # check empty settings frame - assert self.rfile.read(9) ==\ - codecs.decode('000000040000000000', 'hex_codec') - - # send empty settings frame - self.wfile.write(codecs.decode('000000040000000000', 'hex_codec')) - self.wfile.flush() - - # check settings acknowledgement - assert self.rfile.read(9) == \ - codecs.decode('000000040100000000', 'hex_codec') - - # send settings acknowledgement - self.wfile.write(codecs.decode('000000040100000000', 'hex_codec')) - self.wfile.flush() - - def test_perform_client_connection_preface(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - protocol = HTTP2StateProtocol(c) - - assert not protocol.connection_preface_performed - protocol.perform_client_connection_preface() - assert protocol.connection_preface_performed - - -class TestClientStreamIds: - c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = HTTP2StateProtocol(c) - - def test_client_stream_ids(self): - assert self.protocol.current_stream_id is None - assert self.protocol._next_stream_id() == 1 - assert self.protocol.current_stream_id == 1 - assert self.protocol._next_stream_id() == 3 - assert self.protocol.current_stream_id == 3 - assert self.protocol._next_stream_id() == 5 - assert self.protocol.current_stream_id == 5 - - -class TestserverstreamIds: - c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = HTTP2StateProtocol(c, is_server=True) - - def test_server_stream_ids(self): - assert self.protocol.current_stream_id is None - assert self.protocol._next_stream_id() == 2 - assert self.protocol.current_stream_id == 2 - assert self.protocol._next_stream_id() == 4 - assert self.protocol.current_stream_id == 4 - assert self.protocol._next_stream_id() == 6 - assert self.protocol.current_stream_id == 6 - - -class TestApplySettings(net_tservers.ServerTestBase): - class handler(tcp.BaseHandler): - def handle(self): - # check settings acknowledgement - assert self.rfile.read(9) == codecs.decode('000000040100000000', 'hex_codec') - self.wfile.write("OK") - self.wfile.flush() - self.rfile.safe_read(9) # just to keep the connection alive a bit longer - - ssl = True - - def test_apply_settings(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_ssl() - protocol = HTTP2StateProtocol(c) - - protocol._apply_settings({ - hyperframe.frame.SettingsFrame.ENABLE_PUSH: 'foo', - hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 'bar', - hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE: 'deadbeef', - }) - - assert c.rfile.safe_read(2) == b"OK" - - assert protocol.http2_settings[ - hyperframe.frame.SettingsFrame.ENABLE_PUSH] == 'foo' - assert protocol.http2_settings[ - hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS] == 'bar' - assert protocol.http2_settings[ - hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE] == 'deadbeef' - - -class TestCreateHeaders: - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_create_headers(self): - headers = http.Headers([ - (b':method', b'GET'), - (b':path', b'index.html'), - (b':scheme', b'https'), - (b'foo', b'bar')]) - - bytes = HTTP2StateProtocol(self.c)._create_headers( - headers, 1, end_stream=True) - assert b''.join(bytes) ==\ - codecs.decode('000014010500000001824488355217caf3a69a3f87408294e7838c767f', 'hex_codec') - - bytes = HTTP2StateProtocol(self.c)._create_headers( - headers, 1, end_stream=False) - assert b''.join(bytes) ==\ - codecs.decode('000014010400000001824488355217caf3a69a3f87408294e7838c767f', 'hex_codec') - - def test_create_headers_multiple_frames(self): - headers = http.Headers([ - (b':method', b'GET'), - (b':path', b'/'), - (b':scheme', b'https'), - (b'foo', b'bar'), - (b'server', b'version')]) - - protocol = HTTP2StateProtocol(self.c) - protocol.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE] = 8 - bytes = protocol._create_headers(headers, 1, end_stream=True) - assert len(bytes) == 3 - assert bytes[0] == codecs.decode('000008010100000001828487408294e783', 'hex_codec') - assert bytes[1] == codecs.decode('0000080900000000018c767f7685ee5b10', 'hex_codec') - assert bytes[2] == codecs.decode('00000209040000000163d5', 'hex_codec') - - -class TestCreateBody: - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_create_body_empty(self): - protocol = HTTP2StateProtocol(self.c) - bytes = protocol._create_body(b'', 1) - assert b''.join(bytes) == b'' - - def test_create_body_single_frame(self): - protocol = HTTP2StateProtocol(self.c) - bytes = protocol._create_body(b'foobar', 1) - assert b''.join(bytes) == codecs.decode('000006000100000001666f6f626172', 'hex_codec') - - def test_create_body_multiple_frames(self): - protocol = HTTP2StateProtocol(self.c) - protocol.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE] = 5 - bytes = protocol._create_body(b'foobarmehm42', 1) - assert len(bytes) == 3 - assert bytes[0] == codecs.decode('000005000000000001666f6f6261', 'hex_codec') - assert bytes[1] == codecs.decode('000005000000000001726d65686d', 'hex_codec') - assert bytes[2] == codecs.decode('0000020001000000013432', 'hex_codec') - - -class TestReadRequest(net_tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - self.wfile.write( - codecs.decode('000003010400000001828487', 'hex_codec')) - self.wfile.write( - codecs.decode('000006000100000001666f6f626172', 'hex_codec')) - self.wfile.flush() - self.rfile.safe_read(9) # just to keep the connection alive a bit longer - - ssl = True - - def test_read_request(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_ssl() - protocol = HTTP2StateProtocol(c, is_server=True) - protocol.connection_preface_performed = True - - req = protocol.read_request(NotImplemented) - - assert req.stream_id - assert req.headers.fields == () - assert req.method == "GET" - assert req.path == "/" - assert req.scheme == "https" - assert req.content == b'foobar' - - -class TestReadRequestRelative(net_tservers.ServerTestBase): - class handler(tcp.BaseHandler): - def handle(self): - self.wfile.write( - codecs.decode('00000c0105000000014287d5af7e4d5a777f4481f9', 'hex_codec')) - self.wfile.flush() - - ssl = True - - def test_asterisk_form(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_ssl() - protocol = HTTP2StateProtocol(c, is_server=True) - protocol.connection_preface_performed = True - - req = protocol.read_request(NotImplemented) - - assert req.first_line_format == "relative" - assert req.method == "OPTIONS" - assert req.path == "*" - - -class TestReadRequestAbsolute(net_tservers.ServerTestBase): - class handler(tcp.BaseHandler): - def handle(self): - self.wfile.write( - codecs.decode('00001901050000000182448d9d29aee30c0e492c2a1170426366871c92585422e085', 'hex_codec')) - self.wfile.flush() - - ssl = True - - def test_absolute_form(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_ssl() - protocol = HTTP2StateProtocol(c, is_server=True) - protocol.connection_preface_performed = True - - req = protocol.read_request(NotImplemented) - - assert req.first_line_format == "absolute" - assert req.scheme == "http" - assert req.host == "address" - assert req.port == 22 - - -class TestReadResponse(net_tservers.ServerTestBase): - class handler(tcp.BaseHandler): - def handle(self): - self.wfile.write( - codecs.decode('00000801040000002a88628594e78c767f', 'hex_codec')) - self.wfile.write( - codecs.decode('00000600010000002a666f6f626172', 'hex_codec')) - self.wfile.flush() - self.rfile.safe_read(9) # just to keep the connection alive a bit longer - - ssl = True - - def test_read_response(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_ssl() - protocol = HTTP2StateProtocol(c) - protocol.connection_preface_performed = True - - resp = protocol.read_response(NotImplemented, stream_id=42) - - assert resp.http_version == "HTTP/2.0" - assert resp.status_code == 200 - assert resp.reason == '' - assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar')) - assert resp.content == b'foobar' - assert resp.timestamp_end - - -class TestReadEmptyResponse(net_tservers.ServerTestBase): - class handler(tcp.BaseHandler): - def handle(self): - self.wfile.write( - codecs.decode('00000801050000002a88628594e78c767f', 'hex_codec')) - self.wfile.flush() - - ssl = True - - def test_read_empty_response(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_ssl() - protocol = HTTP2StateProtocol(c) - protocol.connection_preface_performed = True - - resp = protocol.read_response(NotImplemented, stream_id=42) - - assert resp.stream_id == 42 - assert resp.http_version == "HTTP/2.0" - assert resp.status_code == 200 - assert resp.reason == '' - assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar')) - assert resp.content == b'' - - -class TestAssembleRequest: - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_request_simple(self): - bytes = HTTP2StateProtocol(self.c).assemble_request(http.Request( - b'', - b'GET', - b'https', - b'', - b'', - b'/', - b"HTTP/2.0", - (), - None, - )) - assert len(bytes) == 1 - assert bytes[0] == codecs.decode('00000d0105000000018284874188089d5c0b8170dc07', 'hex_codec') - - def test_request_with_stream_id(self): - req = http.Request( - b'', - b'GET', - b'https', - b'', - b'', - b'/', - b"HTTP/2.0", - (), - None, - ) - req.stream_id = 0x42 - bytes = HTTP2StateProtocol(self.c).assemble_request(req) - assert len(bytes) == 1 - assert bytes[0] == codecs.decode('00000d0105000000428284874188089d5c0b8170dc07', 'hex_codec') - - def test_request_with_body(self): - bytes = HTTP2StateProtocol(self.c).assemble_request(http.Request( - b'', - b'GET', - b'https', - b'', - b'', - b'/', - b"HTTP/2.0", - http.Headers([(b'foo', b'bar')]), - b'foobar', - )) - assert len(bytes) == 2 - assert bytes[0] ==\ - codecs.decode('0000150104000000018284874188089d5c0b8170dc07408294e7838c767f', 'hex_codec') - assert bytes[1] ==\ - codecs.decode('000006000100000001666f6f626172', 'hex_codec') - - -class TestAssembleResponse: - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_simple(self): - bytes = HTTP2StateProtocol(self.c, is_server=True).assemble_response(http.Response( - b"HTTP/2.0", - 200, - )) - assert len(bytes) == 1 - assert bytes[0] ==\ - codecs.decode('00000101050000000288', 'hex_codec') - - def test_with_stream_id(self): - resp = http.Response( - b"HTTP/2.0", - 200, - ) - resp.stream_id = 0x42 - bytes = HTTP2StateProtocol(self.c, is_server=True).assemble_response(resp) - assert len(bytes) == 1 - assert bytes[0] ==\ - codecs.decode('00000101050000004288', 'hex_codec') - - def test_with_body(self): - bytes = HTTP2StateProtocol(self.c, is_server=True).assemble_response(http.Response( - b"HTTP/2.0", - 200, - b'', - http.Headers(foo=b"bar"), - b'foobar' - )) - assert len(bytes) == 2 - assert bytes[0] ==\ - codecs.decode('00000901040000000288408294e7838c767f', 'hex_codec') - assert bytes[1] ==\ - codecs.decode('000006000100000002666f6f626172', 'hex_codec') -- cgit v1.2.3