aboutsummaryrefslogtreecommitdiffstats
path: root/test
diff options
context:
space:
mode:
authorAldo Cortesi <aldo@nullcube.com>2012-06-19 09:42:32 +1200
committerAldo Cortesi <aldo@nullcube.com>2012-06-19 09:42:32 +1200
commitb558997fd9db8406b2a24a1831d06e283dbf35a6 (patch)
tree7e5236ae407cc8f5f1b95e407cca187fe5bddb9d /test
downloadmitmproxy-b558997fd9db8406b2a24a1831d06e283dbf35a6.tar.gz
mitmproxy-b558997fd9db8406b2a24a1831d06e283dbf35a6.tar.bz2
mitmproxy-b558997fd9db8406b2a24a1831d06e283dbf35a6.zip
Initial checkin.
Diffstat (limited to 'test')
-rw-r--r--test/test_odict.py113
-rw-r--r--test/test_protocol.py163
-rw-r--r--test/test_tcp.py93
-rw-r--r--test/tutils.py56
4 files changed, 425 insertions, 0 deletions
diff --git a/test/test_odict.py b/test/test_odict.py
new file mode 100644
index 00000000..e7453e2d
--- /dev/null
+++ b/test/test_odict.py
@@ -0,0 +1,113 @@
+from netlib import odict
+import tutils
+
+
+class TestODict:
+ def setUp(self):
+ self.od = odict.ODict()
+
+ def test_str_err(self):
+ h = odict.ODict()
+ tutils.raises(ValueError, h.__setitem__, "key", "foo")
+
+ def test_dictToHeader1(self):
+ self.od.add("one", "uno")
+ self.od.add("two", "due")
+ self.od.add("two", "tre")
+ expected = [
+ "one: uno\r\n",
+ "two: due\r\n",
+ "two: tre\r\n",
+ "\r\n"
+ ]
+ out = repr(self.od)
+ for i in expected:
+ assert out.find(i) >= 0
+
+ def test_dictToHeader2(self):
+ self.od["one"] = ["uno"]
+ expected1 = "one: uno\r\n"
+ expected2 = "\r\n"
+ out = repr(self.od)
+ assert out.find(expected1) >= 0
+ assert out.find(expected2) >= 0
+
+ def test_match_re(self):
+ h = odict.ODict()
+ h.add("one", "uno")
+ h.add("two", "due")
+ h.add("two", "tre")
+ assert h.match_re("uno")
+ assert h.match_re("two: due")
+ assert not h.match_re("nonono")
+
+ def test_getset_state(self):
+ self.od.add("foo", 1)
+ self.od.add("foo", 2)
+ self.od.add("bar", 3)
+ state = self.od._get_state()
+ nd = odict.ODict._from_state(state)
+ assert nd == self.od
+
+ def test_in_any(self):
+ self.od["one"] = ["atwoa", "athreea"]
+ assert self.od.in_any("one", "two")
+ assert self.od.in_any("one", "three")
+ assert not self.od.in_any("one", "four")
+ assert not self.od.in_any("nonexistent", "foo")
+ assert not self.od.in_any("one", "TWO")
+ assert self.od.in_any("one", "TWO", True)
+
+ def test_copy(self):
+ self.od.add("foo", 1)
+ self.od.add("foo", 2)
+ self.od.add("bar", 3)
+ assert self.od == self.od.copy()
+
+ def test_del(self):
+ self.od.add("foo", 1)
+ self.od.add("Foo", 2)
+ self.od.add("bar", 3)
+ del self.od["foo"]
+ assert len(self.od.lst) == 2
+
+ def test_replace(self):
+ self.od.add("one", "two")
+ self.od.add("two", "one")
+ assert self.od.replace("one", "vun") == 2
+ assert self.od.lst == [
+ ["vun", "two"],
+ ["two", "vun"],
+ ]
+
+ def test_get(self):
+ self.od.add("one", "two")
+ assert self.od.get("one") == ["two"]
+ assert self.od.get("two") == None
+
+
+class TestODictCaseless:
+ def setUp(self):
+ self.od = odict.ODictCaseless()
+
+ def test_override(self):
+ o = odict.ODictCaseless()
+ o.add('T', 'application/x-www-form-urlencoded; charset=UTF-8')
+ o["T"] = ["foo"]
+ assert o["T"] == ["foo"]
+
+ def test_case_preservation(self):
+ self.od["Foo"] = ["1"]
+ assert "foo" in self.od
+ assert self.od.items()[0][0] == "Foo"
+ assert self.od.get("foo") == ["1"]
+ assert self.od.get("foo", [""]) == ["1"]
+ assert self.od.get("Foo", [""]) == ["1"]
+ assert self.od.get("xx", "yy") == "yy"
+
+ def test_del(self):
+ self.od.add("foo", 1)
+ self.od.add("Foo", 2)
+ self.od.add("bar", 3)
+ del self.od["foo"]
+ assert len(self.od) == 1
diff --git a/test/test_protocol.py b/test/test_protocol.py
new file mode 100644
index 00000000..028faadd
--- /dev/null
+++ b/test/test_protocol.py
@@ -0,0 +1,163 @@
+import cStringIO, textwrap
+from netlib import protocol, odict
+import tutils
+
+def test_has_chunked_encoding():
+ h = odict.ODictCaseless()
+ assert not protocol.has_chunked_encoding(h)
+ h["transfer-encoding"] = ["chunked"]
+ assert protocol.has_chunked_encoding(h)
+
+
+def test_read_chunked():
+ s = cStringIO.StringIO("1\r\na\r\n0\r\n")
+ tutils.raises(IOError, protocol.read_chunked, s, None)
+
+ s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n")
+ assert protocol.read_chunked(s, None) == "a"
+
+ s = cStringIO.StringIO("\r\n")
+ tutils.raises(IOError, protocol.read_chunked, s, None)
+
+ s = cStringIO.StringIO("1\r\nfoo")
+ tutils.raises(IOError, protocol.read_chunked, s, None)
+
+ s = cStringIO.StringIO("foo\r\nfoo")
+ tutils.raises(protocol.ProtocolError, protocol.read_chunked, s, None)
+
+
+def test_request_connection_close():
+ h = odict.ODictCaseless()
+ assert protocol.request_connection_close((1, 0), h)
+ assert not protocol.request_connection_close((1, 1), h)
+
+ h["connection"] = ["keep-alive"]
+ assert not protocol.request_connection_close((1, 1), h)
+
+
+def test_read_http_body():
+ h = odict.ODict()
+ s = cStringIO.StringIO("testing")
+ assert protocol.read_http_body(s, h, False, None) == ""
+
+ h["content-length"] = ["foo"]
+ s = cStringIO.StringIO("testing")
+ tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, None)
+
+ h["content-length"] = [5]
+ s = cStringIO.StringIO("testing")
+ assert len(protocol.read_http_body(s, h, False, None)) == 5
+ s = cStringIO.StringIO("testing")
+ tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, 4)
+
+ h = odict.ODict()
+ s = cStringIO.StringIO("testing")
+ assert len(protocol.read_http_body(s, h, True, 4)) == 4
+ s = cStringIO.StringIO("testing")
+ assert len(protocol.read_http_body(s, h, True, 100)) == 7
+
+def test_parse_http_protocol():
+ assert protocol.parse_http_protocol("HTTP/1.1") == (1, 1)
+ assert protocol.parse_http_protocol("HTTP/0.0") == (0, 0)
+ assert not protocol.parse_http_protocol("foo/0.0")
+
+
+def test_parse_init_connect():
+ assert protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0")
+ assert not protocol.parse_init_connect("bogus")
+ assert not protocol.parse_init_connect("GET host.com:443 HTTP/1.0")
+ assert not protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0")
+ assert not protocol.parse_init_connect("CONNECT host.com:443 foo/1.0")
+
+
+def test_prase_init_proxy():
+ u = "GET http://foo.com:8888/test HTTP/1.1"
+ m, s, h, po, pa, httpversion = protocol.parse_init_proxy(u)
+ assert m == "GET"
+ assert s == "http"
+ assert h == "foo.com"
+ assert po == 8888
+ assert pa == "/test"
+ assert httpversion == (1, 1)
+
+ assert not protocol.parse_init_proxy("invalid")
+ assert not protocol.parse_init_proxy("GET invalid HTTP/1.1")
+ assert not protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1")
+
+
+def test_parse_init_http():
+ u = "GET /test HTTP/1.1"
+ m, u, httpversion= protocol.parse_init_http(u)
+ assert m == "GET"
+ assert u == "/test"
+ assert httpversion == (1, 1)
+
+ assert not protocol.parse_init_http("invalid")
+ assert not protocol.parse_init_http("GET invalid HTTP/1.1")
+ assert not protocol.parse_init_http("GET /test foo/1.1")
+
+
+class TestReadHeaders:
+ def test_read_simple(self):
+ data = """
+ Header: one
+ Header2: two
+ \r\n
+ """
+ data = textwrap.dedent(data)
+ data = data.strip()
+ s = cStringIO.StringIO(data)
+ h = protocol.read_headers(s)
+ assert h == [["Header", "one"], ["Header2", "two"]]
+
+ def test_read_multi(self):
+ data = """
+ Header: one
+ Header: two
+ \r\n
+ """
+ data = textwrap.dedent(data)
+ data = data.strip()
+ s = cStringIO.StringIO(data)
+ h = protocol.read_headers(s)
+ assert h == [["Header", "one"], ["Header", "two"]]
+
+ def test_read_continued(self):
+ data = """
+ Header: one
+ \ttwo
+ Header2: three
+ \r\n
+ """
+ data = textwrap.dedent(data)
+ data = data.strip()
+ s = cStringIO.StringIO(data)
+ h = protocol.read_headers(s)
+ assert h == [["Header", "one\r\n two"], ["Header2", "three"]]
+
+
+def test_parse_url():
+ assert not protocol.parse_url("")
+
+ u = "http://foo.com:8888/test"
+ s, h, po, pa = protocol.parse_url(u)
+ assert s == "http"
+ assert h == "foo.com"
+ assert po == 8888
+ assert pa == "/test"
+
+ s, h, po, pa = protocol.parse_url("http://foo/bar")
+ assert s == "http"
+ assert h == "foo"
+ assert po == 80
+ assert pa == "/bar"
+
+ s, h, po, pa = protocol.parse_url("http://foo")
+ assert pa == "/"
+
+ s, h, po, pa = protocol.parse_url("https://foo")
+ assert po == 443
+
+ assert not protocol.parse_url("https://foo:bar")
+ assert not protocol.parse_url("https://foo:")
+
diff --git a/test/test_tcp.py b/test/test_tcp.py
new file mode 100644
index 00000000..d7d4483e
--- /dev/null
+++ b/test/test_tcp.py
@@ -0,0 +1,93 @@
+import cStringIO, threading, Queue
+from netlib import tcp
+import tutils
+
+class ServerThread(threading.Thread):
+ def __init__(self, server):
+ self.server = server
+ threading.Thread.__init__(self)
+
+ def run(self):
+ self.server.serve_forever()
+
+ def shutdown(self):
+ self.server.shutdown()
+
+
+class ServerTestBase:
+ @classmethod
+ def setupAll(cls):
+ cls.server = ServerThread(cls.makeserver())
+ cls.server.start()
+
+ @classmethod
+ def teardownAll(cls):
+ cls.server.shutdown()
+
+
+class THandler(tcp.BaseHandler):
+ def handle(self):
+ v = self.rfile.readline()
+ if v.startswith("echo"):
+ self.wfile.write(v)
+ elif v.startswith("error"):
+ raise ValueError("Testing an error.")
+ self.wfile.flush()
+
+
+class TServer(tcp.TCPServer):
+ def __init__(self, addr, q):
+ tcp.TCPServer.__init__(self, addr)
+ self.q = q
+
+ def handle_connection(self, request, client_address):
+ THandler(request, client_address, self)
+
+ def handle_error(self, request, client_address):
+ s = cStringIO.StringIO()
+ tcp.TCPServer.handle_error(self, request, client_address, s)
+ self.q.put(s.getvalue())
+
+
+class TestServer(ServerTestBase):
+ @classmethod
+ def makeserver(cls):
+ cls.q = Queue.Queue()
+ s = TServer(("127.0.0.1", 0), cls.q)
+ cls.port = s.port
+ return s
+
+ def test_echo(self):
+ testval = "echo!\n"
+ c = tcp.TCPClient(False, "127.0.0.1", self.port, None)
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
+
+ def test_error(self):
+ testval = "error!\n"
+ c = tcp.TCPClient(False, "127.0.0.1", self.port, None)
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert "Testing an error" in self.q.get()
+
+
+class TestTCPClient:
+ def test_conerr(self):
+ tutils.raises(tcp.NetLibError, tcp.TCPClient, False, "127.0.0.1", 0, None)
+
+
+class TestFileLike:
+ def test_wrap(self):
+ s = cStringIO.StringIO("foobar\nfoobar")
+ s = tcp.FileLike(s)
+ s.flush()
+ assert s.readline() == "foobar\n"
+ assert s.readline() == "foobar"
+ # Test __getattr__
+ assert s.isatty
+
+ def test_limit(self):
+ s = cStringIO.StringIO("foobar\nfoobar")
+ s = tcp.FileLike(s)
+ assert s.readline(3) == "foo"
diff --git a/test/tutils.py b/test/tutils.py
new file mode 100644
index 00000000..c8e06b96
--- /dev/null
+++ b/test/tutils.py
@@ -0,0 +1,56 @@
+import tempfile, os, shutil
+from contextlib import contextmanager
+from libpathod import utils
+
+
+@contextmanager
+def tmpdir(*args, **kwargs):
+ orig_workdir = os.getcwd()
+ temp_workdir = tempfile.mkdtemp(*args, **kwargs)
+ os.chdir(temp_workdir)
+
+ yield temp_workdir
+
+ os.chdir(orig_workdir)
+ shutil.rmtree(temp_workdir)
+
+
+def raises(exc, obj, *args, **kwargs):
+ """
+ Assert that a callable raises a specified exception.
+
+ :exc An exception class or a string. If a class, assert that an
+ exception of this type is raised. If a string, assert that the string
+ occurs in the string representation of the exception, based on a
+ case-insenstivie match.
+
+ :obj A callable object.
+
+ :args Arguments to be passsed to the callable.
+
+ :kwargs Arguments to be passed to the callable.
+ """
+ try:
+ apply(obj, args, kwargs)
+ except Exception, v:
+ if isinstance(exc, basestring):
+ if exc.lower() in str(v).lower():
+ return
+ else:
+ raise AssertionError(
+ "Expected %s, but caught %s"%(
+ repr(str(exc)), v
+ )
+ )
+ else:
+ if isinstance(v, exc):
+ return
+ else:
+ raise AssertionError(
+ "Expected %s, but caught %s %s"%(
+ exc.__name__, v.__class__.__name__, str(v)
+ )
+ )
+ raise AssertionError("No exception raised.")
+
+test_data = utils.Data(__name__)