diff options
| author | Aldo Cortesi <aldo@nullcube.com> | 2012-06-19 09:42:32 +1200 |
|---|---|---|
| committer | Aldo Cortesi <aldo@nullcube.com> | 2012-06-19 09:42:32 +1200 |
| commit | b558997fd9db8406b2a24a1831d06e283dbf35a6 (patch) | |
| tree | 7e5236ae407cc8f5f1b95e407cca187fe5bddb9d /test | |
| download | mitmproxy-b558997fd9db8406b2a24a1831d06e283dbf35a6.tar.gz mitmproxy-b558997fd9db8406b2a24a1831d06e283dbf35a6.tar.bz2 mitmproxy-b558997fd9db8406b2a24a1831d06e283dbf35a6.zip | |
Initial checkin.
Diffstat (limited to 'test')
| -rw-r--r-- | test/test_odict.py | 113 | ||||
| -rw-r--r-- | test/test_protocol.py | 163 | ||||
| -rw-r--r-- | test/test_tcp.py | 93 | ||||
| -rw-r--r-- | test/tutils.py | 56 |
4 files changed, 425 insertions, 0 deletions
diff --git a/test/test_odict.py b/test/test_odict.py new file mode 100644 index 00000000..e7453e2d --- /dev/null +++ b/test/test_odict.py @@ -0,0 +1,113 @@ +from netlib import odict +import tutils + + +class TestODict: + def setUp(self): + self.od = odict.ODict() + + def test_str_err(self): + h = odict.ODict() + tutils.raises(ValueError, h.__setitem__, "key", "foo") + + def test_dictToHeader1(self): + self.od.add("one", "uno") + self.od.add("two", "due") + self.od.add("two", "tre") + expected = [ + "one: uno\r\n", + "two: due\r\n", + "two: tre\r\n", + "\r\n" + ] + out = repr(self.od) + for i in expected: + assert out.find(i) >= 0 + + def test_dictToHeader2(self): + self.od["one"] = ["uno"] + expected1 = "one: uno\r\n" + expected2 = "\r\n" + out = repr(self.od) + assert out.find(expected1) >= 0 + assert out.find(expected2) >= 0 + + def test_match_re(self): + h = odict.ODict() + h.add("one", "uno") + h.add("two", "due") + h.add("two", "tre") + assert h.match_re("uno") + assert h.match_re("two: due") + assert not h.match_re("nonono") + + def test_getset_state(self): + self.od.add("foo", 1) + self.od.add("foo", 2) + self.od.add("bar", 3) + state = self.od._get_state() + nd = odict.ODict._from_state(state) + assert nd == self.od + + def test_in_any(self): + self.od["one"] = ["atwoa", "athreea"] + assert self.od.in_any("one", "two") + assert self.od.in_any("one", "three") + assert not self.od.in_any("one", "four") + assert not self.od.in_any("nonexistent", "foo") + assert not self.od.in_any("one", "TWO") + assert self.od.in_any("one", "TWO", True) + + def test_copy(self): + self.od.add("foo", 1) + self.od.add("foo", 2) + self.od.add("bar", 3) + assert self.od == self.od.copy() + + def test_del(self): + self.od.add("foo", 1) + self.od.add("Foo", 2) + self.od.add("bar", 3) + del self.od["foo"] + assert len(self.od.lst) == 2 + + def test_replace(self): + self.od.add("one", "two") + self.od.add("two", "one") + assert self.od.replace("one", "vun") == 2 + assert self.od.lst == [ + ["vun", "two"], + ["two", "vun"], + ] + + def test_get(self): + self.od.add("one", "two") + assert self.od.get("one") == ["two"] + assert self.od.get("two") == None + + +class TestODictCaseless: + def setUp(self): + self.od = odict.ODictCaseless() + + def test_override(self): + o = odict.ODictCaseless() + o.add('T', 'application/x-www-form-urlencoded; charset=UTF-8') + o["T"] = ["foo"] + assert o["T"] == ["foo"] + + def test_case_preservation(self): + self.od["Foo"] = ["1"] + assert "foo" in self.od + assert self.od.items()[0][0] == "Foo" + assert self.od.get("foo") == ["1"] + assert self.od.get("foo", [""]) == ["1"] + assert self.od.get("Foo", [""]) == ["1"] + assert self.od.get("xx", "yy") == "yy" + + def test_del(self): + self.od.add("foo", 1) + self.od.add("Foo", 2) + self.od.add("bar", 3) + del self.od["foo"] + assert len(self.od) == 1 diff --git a/test/test_protocol.py b/test/test_protocol.py new file mode 100644 index 00000000..028faadd --- /dev/null +++ b/test/test_protocol.py @@ -0,0 +1,163 @@ +import cStringIO, textwrap +from netlib import protocol, odict +import tutils + +def test_has_chunked_encoding(): + h = odict.ODictCaseless() + assert not protocol.has_chunked_encoding(h) + h["transfer-encoding"] = ["chunked"] + assert protocol.has_chunked_encoding(h) + + +def test_read_chunked(): + s = cStringIO.StringIO("1\r\na\r\n0\r\n") + tutils.raises(IOError, protocol.read_chunked, s, None) + + s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") + assert protocol.read_chunked(s, None) == "a" + + s = cStringIO.StringIO("\r\n") + tutils.raises(IOError, protocol.read_chunked, s, None) + + s = cStringIO.StringIO("1\r\nfoo") + tutils.raises(IOError, protocol.read_chunked, s, None) + + s = cStringIO.StringIO("foo\r\nfoo") + tutils.raises(protocol.ProtocolError, protocol.read_chunked, s, None) + + +def test_request_connection_close(): + h = odict.ODictCaseless() + assert protocol.request_connection_close((1, 0), h) + assert not protocol.request_connection_close((1, 1), h) + + h["connection"] = ["keep-alive"] + assert not protocol.request_connection_close((1, 1), h) + + +def test_read_http_body(): + h = odict.ODict() + s = cStringIO.StringIO("testing") + assert protocol.read_http_body(s, h, False, None) == "" + + h["content-length"] = ["foo"] + s = cStringIO.StringIO("testing") + tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, None) + + h["content-length"] = [5] + s = cStringIO.StringIO("testing") + assert len(protocol.read_http_body(s, h, False, None)) == 5 + s = cStringIO.StringIO("testing") + tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, 4) + + h = odict.ODict() + s = cStringIO.StringIO("testing") + assert len(protocol.read_http_body(s, h, True, 4)) == 4 + s = cStringIO.StringIO("testing") + assert len(protocol.read_http_body(s, h, True, 100)) == 7 + +def test_parse_http_protocol(): + assert protocol.parse_http_protocol("HTTP/1.1") == (1, 1) + assert protocol.parse_http_protocol("HTTP/0.0") == (0, 0) + assert not protocol.parse_http_protocol("foo/0.0") + + +def test_parse_init_connect(): + assert protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0") + assert not protocol.parse_init_connect("bogus") + assert not protocol.parse_init_connect("GET host.com:443 HTTP/1.0") + assert not protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0") + assert not protocol.parse_init_connect("CONNECT host.com:443 foo/1.0") + + +def test_prase_init_proxy(): + u = "GET http://foo.com:8888/test HTTP/1.1" + m, s, h, po, pa, httpversion = protocol.parse_init_proxy(u) + assert m == "GET" + assert s == "http" + assert h == "foo.com" + assert po == 8888 + assert pa == "/test" + assert httpversion == (1, 1) + + assert not protocol.parse_init_proxy("invalid") + assert not protocol.parse_init_proxy("GET invalid HTTP/1.1") + assert not protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") + + +def test_parse_init_http(): + u = "GET /test HTTP/1.1" + m, u, httpversion= protocol.parse_init_http(u) + assert m == "GET" + assert u == "/test" + assert httpversion == (1, 1) + + assert not protocol.parse_init_http("invalid") + assert not protocol.parse_init_http("GET invalid HTTP/1.1") + assert not protocol.parse_init_http("GET /test foo/1.1") + + +class TestReadHeaders: + def test_read_simple(self): + data = """ + Header: one + Header2: two + \r\n + """ + data = textwrap.dedent(data) + data = data.strip() + s = cStringIO.StringIO(data) + h = protocol.read_headers(s) + assert h == [["Header", "one"], ["Header2", "two"]] + + def test_read_multi(self): + data = """ + Header: one + Header: two + \r\n + """ + data = textwrap.dedent(data) + data = data.strip() + s = cStringIO.StringIO(data) + h = protocol.read_headers(s) + assert h == [["Header", "one"], ["Header", "two"]] + + def test_read_continued(self): + data = """ + Header: one + \ttwo + Header2: three + \r\n + """ + data = textwrap.dedent(data) + data = data.strip() + s = cStringIO.StringIO(data) + h = protocol.read_headers(s) + assert h == [["Header", "one\r\n two"], ["Header2", "three"]] + + +def test_parse_url(): + assert not protocol.parse_url("") + + u = "http://foo.com:8888/test" + s, h, po, pa = protocol.parse_url(u) + assert s == "http" + assert h == "foo.com" + assert po == 8888 + assert pa == "/test" + + s, h, po, pa = protocol.parse_url("http://foo/bar") + assert s == "http" + assert h == "foo" + assert po == 80 + assert pa == "/bar" + + s, h, po, pa = protocol.parse_url("http://foo") + assert pa == "/" + + s, h, po, pa = protocol.parse_url("https://foo") + assert po == 443 + + assert not protocol.parse_url("https://foo:bar") + assert not protocol.parse_url("https://foo:") + diff --git a/test/test_tcp.py b/test/test_tcp.py new file mode 100644 index 00000000..d7d4483e --- /dev/null +++ b/test/test_tcp.py @@ -0,0 +1,93 @@ +import cStringIO, threading, Queue +from netlib import tcp +import tutils + +class ServerThread(threading.Thread): + def __init__(self, server): + self.server = server + threading.Thread.__init__(self) + + def run(self): + self.server.serve_forever() + + def shutdown(self): + self.server.shutdown() + + +class ServerTestBase: + @classmethod + def setupAll(cls): + cls.server = ServerThread(cls.makeserver()) + cls.server.start() + + @classmethod + def teardownAll(cls): + cls.server.shutdown() + + +class THandler(tcp.BaseHandler): + def handle(self): + v = self.rfile.readline() + if v.startswith("echo"): + self.wfile.write(v) + elif v.startswith("error"): + raise ValueError("Testing an error.") + self.wfile.flush() + + +class TServer(tcp.TCPServer): + def __init__(self, addr, q): + tcp.TCPServer.__init__(self, addr) + self.q = q + + def handle_connection(self, request, client_address): + THandler(request, client_address, self) + + def handle_error(self, request, client_address): + s = cStringIO.StringIO() + tcp.TCPServer.handle_error(self, request, client_address, s) + self.q.put(s.getvalue()) + + +class TestServer(ServerTestBase): + @classmethod + def makeserver(cls): + cls.q = Queue.Queue() + s = TServer(("127.0.0.1", 0), cls.q) + cls.port = s.port + return s + + def test_echo(self): + testval = "echo!\n" + c = tcp.TCPClient(False, "127.0.0.1", self.port, None) + c.wfile.write(testval) + c.wfile.flush() + assert c.rfile.readline() == testval + + def test_error(self): + testval = "error!\n" + c = tcp.TCPClient(False, "127.0.0.1", self.port, None) + c.wfile.write(testval) + c.wfile.flush() + assert "Testing an error" in self.q.get() + + +class TestTCPClient: + def test_conerr(self): + tutils.raises(tcp.NetLibError, tcp.TCPClient, False, "127.0.0.1", 0, None) + + +class TestFileLike: + def test_wrap(self): + s = cStringIO.StringIO("foobar\nfoobar") + s = tcp.FileLike(s) + s.flush() + assert s.readline() == "foobar\n" + assert s.readline() == "foobar" + # Test __getattr__ + assert s.isatty + + def test_limit(self): + s = cStringIO.StringIO("foobar\nfoobar") + s = tcp.FileLike(s) + assert s.readline(3) == "foo" diff --git a/test/tutils.py b/test/tutils.py new file mode 100644 index 00000000..c8e06b96 --- /dev/null +++ b/test/tutils.py @@ -0,0 +1,56 @@ +import tempfile, os, shutil +from contextlib import contextmanager +from libpathod import utils + + +@contextmanager +def tmpdir(*args, **kwargs): + orig_workdir = os.getcwd() + temp_workdir = tempfile.mkdtemp(*args, **kwargs) + os.chdir(temp_workdir) + + yield temp_workdir + + os.chdir(orig_workdir) + shutil.rmtree(temp_workdir) + + +def raises(exc, obj, *args, **kwargs): + """ + Assert that a callable raises a specified exception. + + :exc An exception class or a string. If a class, assert that an + exception of this type is raised. If a string, assert that the string + occurs in the string representation of the exception, based on a + case-insenstivie match. + + :obj A callable object. + + :args Arguments to be passsed to the callable. + + :kwargs Arguments to be passed to the callable. + """ + try: + apply(obj, args, kwargs) + except Exception, v: + if isinstance(exc, basestring): + if exc.lower() in str(v).lower(): + return + else: + raise AssertionError( + "Expected %s, but caught %s"%( + repr(str(exc)), v + ) + ) + else: + if isinstance(v, exc): + return + else: + raise AssertionError( + "Expected %s, but caught %s %s"%( + exc.__name__, v.__class__.__name__, str(v) + ) + ) + raise AssertionError("No exception raised.") + +test_data = utils.Data(__name__) |
