aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.coveragerc2
-rw-r--r--.gitignore9
-rw-r--r--README2
-rw-r--r--netlib/__init__.py0
-rw-r--r--netlib/odict.py160
-rw-r--r--netlib/protocol.py218
-rw-r--r--netlib/tcp.py182
-rw-r--r--test/test_odict.py113
-rw-r--r--test/test_protocol.py163
-rw-r--r--test/test_tcp.py93
-rw-r--r--test/tutils.py56
11 files changed, 998 insertions, 0 deletions
diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 00000000..99f57cb0
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,2 @@
+[report]
+include = *netlib*
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..f53cd2e2
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,9 @@
+MANIFEST
+/build
+/dist
+/tmp
+/doc
+*.py[cdo]
+*.swp
+*.swo
+.coverage
diff --git a/README b/README
new file mode 100644
index 00000000..1c86738c
--- /dev/null
+++ b/README
@@ -0,0 +1,2 @@
+Netlib is a collection of common utility functions, used by the pathod and
+mitmproxy projects.
diff --git a/netlib/__init__.py b/netlib/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/netlib/__init__.py
diff --git a/netlib/odict.py b/netlib/odict.py
new file mode 100644
index 00000000..afc33caa
--- /dev/null
+++ b/netlib/odict.py
@@ -0,0 +1,160 @@
+import re, copy
+
+def safe_subn(pattern, repl, target, *args, **kwargs):
+ """
+ There are Unicode conversion problems with re.subn. We try to smooth
+ that over by casting the pattern and replacement to strings. We really
+ need a better solution that is aware of the actual content ecoding.
+ """
+ return re.subn(str(pattern), str(repl), target, *args, **kwargs)
+
+
+class ODict:
+ """
+ A dictionary-like object for managing ordered (key, value) data.
+ """
+ def __init__(self, lst=None):
+ self.lst = lst or []
+
+ def _kconv(self, s):
+ return s
+
+ def __eq__(self, other):
+ return self.lst == other.lst
+
+ def __getitem__(self, k):
+ """
+ Returns a list of values matching key.
+ """
+ ret = []
+ k = self._kconv(k)
+ for i in self.lst:
+ if self._kconv(i[0]) == k:
+ ret.append(i[1])
+ return ret
+
+ def _filter_lst(self, k, lst):
+ k = self._kconv(k)
+ new = []
+ for i in lst:
+ if self._kconv(i[0]) != k:
+ new.append(i)
+ return new
+
+ def __len__(self):
+ """
+ Total number of (key, value) pairs.
+ """
+ return len(self.lst)
+
+ def __setitem__(self, k, valuelist):
+ """
+ Sets the values for key k. If there are existing values for this
+ key, they are cleared.
+ """
+ if isinstance(valuelist, basestring):
+ raise ValueError("ODict valuelist should be lists.")
+ new = self._filter_lst(k, self.lst)
+ for i in valuelist:
+ new.append([k, i])
+ self.lst = new
+
+ def __delitem__(self, k):
+ """
+ Delete all items matching k.
+ """
+ self.lst = self._filter_lst(k, self.lst)
+
+ def __contains__(self, k):
+ for i in self.lst:
+ if self._kconv(i[0]) == self._kconv(k):
+ return True
+ return False
+
+ def add(self, key, value):
+ self.lst.append([key, str(value)])
+
+ def get(self, k, d=None):
+ if k in self:
+ return self[k]
+ else:
+ return d
+
+ def items(self):
+ return self.lst[:]
+
+ def _get_state(self):
+ return [tuple(i) for i in self.lst]
+
+ @classmethod
+ def _from_state(klass, state):
+ return klass([list(i) for i in state])
+
+ def copy(self):
+ """
+ Returns a copy of this object.
+ """
+ lst = copy.deepcopy(self.lst)
+ return self.__class__(lst)
+
+ def __repr__(self):
+ elements = []
+ for itm in self.lst:
+ elements.append(itm[0] + ": " + itm[1])
+ elements.append("")
+ return "\r\n".join(elements)
+
+ def in_any(self, key, value, caseless=False):
+ """
+ Do any of the values matching key contain value?
+
+ If caseless is true, value comparison is case-insensitive.
+ """
+ if caseless:
+ value = value.lower()
+ for i in self[key]:
+ if caseless:
+ i = i.lower()
+ if value in i:
+ return True
+ return False
+
+ def match_re(self, expr):
+ """
+ Match the regular expression against each (key, value) pair. For
+ each pair a string of the following format is matched against:
+
+ "key: value"
+ """
+ for k, v in self.lst:
+ s = "%s: %s"%(k, v)
+ if re.search(expr, s):
+ return True
+ return False
+
+ def replace(self, pattern, repl, *args, **kwargs):
+ """
+ Replaces a regular expression pattern with repl in both keys and
+ values. Encoded content will be decoded before replacement, and
+ re-encoded afterwards.
+
+ Returns the number of replacements made.
+ """
+ nlst, count = [], 0
+ for i in self.lst:
+ k, c = safe_subn(pattern, repl, i[0], *args, **kwargs)
+ count += c
+ v, c = safe_subn(pattern, repl, i[1], *args, **kwargs)
+ count += c
+ nlst.append([k, v])
+ self.lst = nlst
+ return count
+
+
+class ODictCaseless(ODict):
+ """
+ A variant of ODict with "caseless" keys. This version _preserves_ key
+ case, but does not consider case when setting or getting items.
+ """
+ def _kconv(self, s):
+ return s.lower()
diff --git a/netlib/protocol.py b/netlib/protocol.py
new file mode 100644
index 00000000..55bcf440
--- /dev/null
+++ b/netlib/protocol.py
@@ -0,0 +1,218 @@
+import string, urlparse
+
+class ProtocolError(Exception):
+ def __init__(self, code, msg):
+ self.code, self.msg = code, msg
+
+ def __str__(self):
+ return "ProtocolError(%s, %s)"%(self.code, self.msg)
+
+
+def parse_url(url):
+ """
+ Returns a (scheme, host, port, path) tuple, or None on error.
+ """
+ scheme, netloc, path, params, query, fragment = urlparse.urlparse(url)
+ if not scheme:
+ return None
+ if ':' in netloc:
+ host, port = string.rsplit(netloc, ':', maxsplit=1)
+ try:
+ port = int(port)
+ except ValueError:
+ return None
+ else:
+ host = netloc
+ if scheme == "https":
+ port = 443
+ else:
+ port = 80
+ path = urlparse.urlunparse(('', '', path, params, query, fragment))
+ if not path.startswith("/"):
+ path = "/" + path
+ return scheme, host, port, path
+
+
+def read_headers(fp):
+ """
+ Read a set of headers from a file pointer. Stop once a blank line
+ is reached. Return a ODictCaseless object.
+ """
+ ret = []
+ name = ''
+ while 1:
+ line = fp.readline()
+ if not line or line == '\r\n' or line == '\n':
+ break
+ if line[0] in ' \t':
+ # continued header
+ ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip()
+ else:
+ i = line.find(':')
+ # We're being liberal in what we accept, here.
+ if i > 0:
+ name = line[:i]
+ value = line[i+1:].strip()
+ ret.append([name, value])
+ return ret
+
+
+def read_chunked(fp, limit):
+ content = ""
+ total = 0
+ while 1:
+ line = fp.readline(128)
+ if line == "":
+ raise IOError("Connection closed")
+ if line == '\r\n' or line == '\n':
+ continue
+ try:
+ length = int(line,16)
+ except ValueError:
+ # FIXME: Not strictly correct - this could be from the server, in which
+ # case we should send a 502.
+ raise ProtocolError(400, "Invalid chunked encoding length: %s"%line)
+ if not length:
+ break
+ total += length
+ if limit is not None and total > limit:
+ msg = "HTTP Body too large."\
+ " Limit is %s, chunked content length was at least %s"%(limit, total)
+ raise ProtocolError(509, msg)
+ content += fp.read(length)
+ line = fp.readline(5)
+ if line != '\r\n':
+ raise IOError("Malformed chunked body")
+ while 1:
+ line = fp.readline()
+ if line == "":
+ raise IOError("Connection closed")
+ if line == '\r\n' or line == '\n':
+ break
+ return content
+
+
+def has_chunked_encoding(headers):
+ for i in headers["transfer-encoding"]:
+ for j in i.split(","):
+ if j.lower() == "chunked":
+ return True
+ return False
+
+
+def read_http_body(rfile, headers, all, limit):
+ if has_chunked_encoding(headers):
+ content = read_chunked(rfile, limit)
+ elif "content-length" in headers:
+ try:
+ l = int(headers["content-length"][0])
+ except ValueError:
+ # FIXME: Not strictly correct - this could be from the server, in which
+ # case we should send a 502.
+ raise ProtocolError(400, "Invalid content-length header: %s"%headers["content-length"])
+ if limit is not None and l > limit:
+ raise ProtocolError(509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l))
+ content = rfile.read(l)
+ elif all:
+ content = rfile.read(limit if limit else None)
+ else:
+ content = ""
+ return content
+
+
+def parse_http_protocol(s):
+ if not s.startswith("HTTP/"):
+ return None
+ major, minor = s.split('/')[1].split('.')
+ major = int(major)
+ minor = int(minor)
+ return major, minor
+
+
+def parse_init_connect(line):
+ try:
+ method, url, protocol = string.split(line)
+ except ValueError:
+ return None
+ if method != 'CONNECT':
+ return None
+ try:
+ host, port = url.split(":")
+ except ValueError:
+ return None
+ port = int(port)
+ httpversion = parse_http_protocol(protocol)
+ if not httpversion:
+ return None
+ return host, port, httpversion
+
+
+def parse_init_proxy(line):
+ try:
+ method, url, protocol = string.split(line)
+ except ValueError:
+ return None
+ parts = parse_url(url)
+ if not parts:
+ return None
+ scheme, host, port, path = parts
+ httpversion = parse_http_protocol(protocol)
+ if not httpversion:
+ return None
+ return method, scheme, host, port, path, httpversion
+
+
+def parse_init_http(line):
+ """
+ Returns (method, url, httpversion)
+ """
+ try:
+ method, url, protocol = string.split(line)
+ except ValueError:
+ return None
+ if not (url.startswith("/") or url == "*"):
+ return None
+ httpversion = parse_http_protocol(protocol)
+ if not httpversion:
+ return None
+ return method, url, httpversion
+
+
+def request_connection_close(httpversion, headers):
+ """
+ Checks the request to see if the client connection should be closed.
+ """
+ if "connection" in headers:
+ for value in ",".join(headers['connection']).split(","):
+ value = value.strip()
+ if value == "close":
+ return True
+ elif value == "keep-alive":
+ return False
+ # HTTP 1.1 connections are assumed to be persistent
+ if httpversion == (1, 1):
+ return False
+ return True
+
+
+def response_connection_close(httpversion, headers):
+ """
+ Checks the response to see if the client connection should be closed.
+ """
+ if request_connection_close(httpversion, headers):
+ return True
+ elif not has_chunked_encoding(headers) and "content-length" in headers:
+ return True
+ return False
+
+
+def read_http_body_request(rfile, wfile, headers, httpversion, limit):
+ if "expect" in headers:
+ # FIXME: Should be forwarded upstream
+ expect = ",".join(headers['expect'])
+ if expect == "100-continue" and httpversion >= (1, 1):
+ wfile.write('HTTP/1.1 100 Continue\r\n')
+ wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION)
+ wfile.write('\r\n')
+ del headers['expect']
+ return read_http_body(rfile, headers, False, limit)
diff --git a/netlib/tcp.py b/netlib/tcp.py
new file mode 100644
index 00000000..08ccba09
--- /dev/null
+++ b/netlib/tcp.py
@@ -0,0 +1,182 @@
+import select, socket, threading, traceback, sys
+from OpenSSL import SSL
+
+
+class NetLibError(Exception): pass
+
+
+class FileLike:
+ def __init__(self, o):
+ self.o = o
+
+ def __getattr__(self, attr):
+ return getattr(self.o, attr)
+
+ def flush(self):
+ pass
+
+ def read(self, length):
+ result = ''
+ while len(result) < length:
+ try:
+ data = self.o.read(length)
+ except SSL.ZeroReturnError:
+ break
+ if not data:
+ break
+ result += data
+ return result
+
+ def write(self, v):
+ self.o.sendall(v)
+
+ def readline(self, size = None):
+ result = ''
+ bytes_read = 0
+ while True:
+ if size is not None and bytes_read >= size:
+ break
+ ch = self.read(1)
+ bytes_read += 1
+ if not ch:
+ break
+ else:
+ result += ch
+ if ch == '\n':
+ break
+ return result
+
+
+class TCPClient:
+ def __init__(self, ssl, host, port, clientcert):
+ self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert
+ self.connection, self.rfile, self.wfile = None, None, None
+ self.cert = None
+ self.connect()
+
+ def connect(self):
+ try:
+ addr = socket.gethostbyname(self.host)
+ server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ if self.ssl:
+ context = SSL.Context(SSL.SSLv23_METHOD)
+ if self.clientcert:
+ context.use_certificate_file(self.clientcert)
+ server = SSL.Connection(context, server)
+ server.connect((addr, self.port))
+ if self.ssl:
+ self.cert = server.get_peer_certificate()
+ self.rfile, self.wfile = FileLike(server), FileLike(server)
+ else:
+ self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb')
+ except socket.error, err:
+ raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
+ self.connection = server
+
+
+class BaseHandler:
+ rbufsize = -1
+ wbufsize = 0
+ def __init__(self, connection, client_address, server):
+ self.connection = connection
+ self.rfile = self.connection.makefile('rb', self.rbufsize)
+ self.wfile = self.connection.makefile('wb', self.wbufsize)
+
+ self.client_address = client_address
+ self.server = server
+ self.handle()
+ self.finish()
+
+ def convert_to_ssl(self, cert, key):
+ ctx = SSL.Context(SSL.SSLv23_METHOD)
+ ctx.use_privatekey_file(key)
+ ctx.use_certificate_file(cert)
+ self.connection = SSL.Connection(ctx, self.connection)
+ self.connection.set_accept_state()
+ self.rfile = FileLike(self.connection)
+ self.wfile = FileLike(self.connection)
+
+ def finish(self):
+ try:
+ if not getattr(self.wfile, "closed", False):
+ self.wfile.flush()
+ self.connection.close()
+ self.wfile.close()
+ self.rfile.close()
+ except IOError: # pragma: no cover
+ pass
+
+ def handle(self): # pragma: no cover
+ raise NotImplementedError
+
+
+class TCPServer:
+ request_queue_size = 20
+ def __init__(self, server_address):
+ self.server_address = server_address
+ self.__is_shut_down = threading.Event()
+ self.__shutdown_request = False
+ self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ self.socket.bind(self.server_address)
+ self.server_address = self.socket.getsockname()
+ self.socket.listen(self.request_queue_size)
+ self.port = self.socket.getsockname()[1]
+
+ def request_thread(self, request, client_address):
+ try:
+ self.handle_connection(request, client_address)
+ request.close()
+ except:
+ self.handle_error(request, client_address)
+ request.close()
+
+ def serve_forever(self, poll_interval=0.5):
+ self.__is_shut_down.clear()
+ try:
+ while not self.__shutdown_request:
+ r, w, e = select.select([self.socket], [], [], poll_interval)
+ if self.socket in r:
+ try:
+ request, client_address = self.socket.accept()
+ except socket.error:
+ return
+ try:
+ t = threading.Thread(
+ target = self.request_thread,
+ args = (request, client_address)
+ )
+ t.setDaemon(1)
+ t.start()
+ except:
+ self.handle_error(request, client_address)
+ request.close()
+ finally:
+ self.__shutdown_request = False
+ self.__is_shut_down.set()
+
+ def shutdown(self):
+ self.__shutdown_request = True
+ self.__is_shut_down.wait()
+ self.handle_shutdown()
+
+ def handle_error(self, request, client_address, fp=sys.stderr):
+ """
+ Called when handle_connection raises an exception.
+ """
+ print >> fp, '-'*40
+ print >> fp, "Error processing of request from %s:%s"%client_address
+ print >> fp, traceback.format_exc()
+ print >> fp, '-'*40
+
+ def handle_connection(self, request, client_address): # pragma: no cover
+ """
+ Called after client connection.
+ """
+ raise NotImplementedError
+
+ def handle_shutdown(self):
+ """
+ Called after server shutdown.
+ """
+ pass
diff --git a/test/test_odict.py b/test/test_odict.py
new file mode 100644
index 00000000..e7453e2d
--- /dev/null
+++ b/test/test_odict.py
@@ -0,0 +1,113 @@
+from netlib import odict
+import tutils
+
+
+class TestODict:
+ def setUp(self):
+ self.od = odict.ODict()
+
+ def test_str_err(self):
+ h = odict.ODict()
+ tutils.raises(ValueError, h.__setitem__, "key", "foo")
+
+ def test_dictToHeader1(self):
+ self.od.add("one", "uno")
+ self.od.add("two", "due")
+ self.od.add("two", "tre")
+ expected = [
+ "one: uno\r\n",
+ "two: due\r\n",
+ "two: tre\r\n",
+ "\r\n"
+ ]
+ out = repr(self.od)
+ for i in expected:
+ assert out.find(i) >= 0
+
+ def test_dictToHeader2(self):
+ self.od["one"] = ["uno"]
+ expected1 = "one: uno\r\n"
+ expected2 = "\r\n"
+ out = repr(self.od)
+ assert out.find(expected1) >= 0
+ assert out.find(expected2) >= 0
+
+ def test_match_re(self):
+ h = odict.ODict()
+ h.add("one", "uno")
+ h.add("two", "due")
+ h.add("two", "tre")
+ assert h.match_re("uno")
+ assert h.match_re("two: due")
+ assert not h.match_re("nonono")
+
+ def test_getset_state(self):
+ self.od.add("foo", 1)
+ self.od.add("foo", 2)
+ self.od.add("bar", 3)
+ state = self.od._get_state()
+ nd = odict.ODict._from_state(state)
+ assert nd == self.od
+
+ def test_in_any(self):
+ self.od["one"] = ["atwoa", "athreea"]
+ assert self.od.in_any("one", "two")
+ assert self.od.in_any("one", "three")
+ assert not self.od.in_any("one", "four")
+ assert not self.od.in_any("nonexistent", "foo")
+ assert not self.od.in_any("one", "TWO")
+ assert self.od.in_any("one", "TWO", True)
+
+ def test_copy(self):
+ self.od.add("foo", 1)
+ self.od.add("foo", 2)
+ self.od.add("bar", 3)
+ assert self.od == self.od.copy()
+
+ def test_del(self):
+ self.od.add("foo", 1)
+ self.od.add("Foo", 2)
+ self.od.add("bar", 3)
+ del self.od["foo"]
+ assert len(self.od.lst) == 2
+
+ def test_replace(self):
+ self.od.add("one", "two")
+ self.od.add("two", "one")
+ assert self.od.replace("one", "vun") == 2
+ assert self.od.lst == [
+ ["vun", "two"],
+ ["two", "vun"],
+ ]
+
+ def test_get(self):
+ self.od.add("one", "two")
+ assert self.od.get("one") == ["two"]
+ assert self.od.get("two") == None
+
+
+class TestODictCaseless:
+ def setUp(self):
+ self.od = odict.ODictCaseless()
+
+ def test_override(self):
+ o = odict.ODictCaseless()
+ o.add('T', 'application/x-www-form-urlencoded; charset=UTF-8')
+ o["T"] = ["foo"]
+ assert o["T"] == ["foo"]
+
+ def test_case_preservation(self):
+ self.od["Foo"] = ["1"]
+ assert "foo" in self.od
+ assert self.od.items()[0][0] == "Foo"
+ assert self.od.get("foo") == ["1"]
+ assert self.od.get("foo", [""]) == ["1"]
+ assert self.od.get("Foo", [""]) == ["1"]
+ assert self.od.get("xx", "yy") == "yy"
+
+ def test_del(self):
+ self.od.add("foo", 1)
+ self.od.add("Foo", 2)
+ self.od.add("bar", 3)
+ del self.od["foo"]
+ assert len(self.od) == 1
diff --git a/test/test_protocol.py b/test/test_protocol.py
new file mode 100644
index 00000000..028faadd
--- /dev/null
+++ b/test/test_protocol.py
@@ -0,0 +1,163 @@
+import cStringIO, textwrap
+from netlib import protocol, odict
+import tutils
+
+def test_has_chunked_encoding():
+ h = odict.ODictCaseless()
+ assert not protocol.has_chunked_encoding(h)
+ h["transfer-encoding"] = ["chunked"]
+ assert protocol.has_chunked_encoding(h)
+
+
+def test_read_chunked():
+ s = cStringIO.StringIO("1\r\na\r\n0\r\n")
+ tutils.raises(IOError, protocol.read_chunked, s, None)
+
+ s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n")
+ assert protocol.read_chunked(s, None) == "a"
+
+ s = cStringIO.StringIO("\r\n")
+ tutils.raises(IOError, protocol.read_chunked, s, None)
+
+ s = cStringIO.StringIO("1\r\nfoo")
+ tutils.raises(IOError, protocol.read_chunked, s, None)
+
+ s = cStringIO.StringIO("foo\r\nfoo")
+ tutils.raises(protocol.ProtocolError, protocol.read_chunked, s, None)
+
+
+def test_request_connection_close():
+ h = odict.ODictCaseless()
+ assert protocol.request_connection_close((1, 0), h)
+ assert not protocol.request_connection_close((1, 1), h)
+
+ h["connection"] = ["keep-alive"]
+ assert not protocol.request_connection_close((1, 1), h)
+
+
+def test_read_http_body():
+ h = odict.ODict()
+ s = cStringIO.StringIO("testing")
+ assert protocol.read_http_body(s, h, False, None) == ""
+
+ h["content-length"] = ["foo"]
+ s = cStringIO.StringIO("testing")
+ tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, None)
+
+ h["content-length"] = [5]
+ s = cStringIO.StringIO("testing")
+ assert len(protocol.read_http_body(s, h, False, None)) == 5
+ s = cStringIO.StringIO("testing")
+ tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, 4)
+
+ h = odict.ODict()
+ s = cStringIO.StringIO("testing")
+ assert len(protocol.read_http_body(s, h, True, 4)) == 4
+ s = cStringIO.StringIO("testing")
+ assert len(protocol.read_http_body(s, h, True, 100)) == 7
+
+def test_parse_http_protocol():
+ assert protocol.parse_http_protocol("HTTP/1.1") == (1, 1)
+ assert protocol.parse_http_protocol("HTTP/0.0") == (0, 0)
+ assert not protocol.parse_http_protocol("foo/0.0")
+
+
+def test_parse_init_connect():
+ assert protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0")
+ assert not protocol.parse_init_connect("bogus")
+ assert not protocol.parse_init_connect("GET host.com:443 HTTP/1.0")
+ assert not protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0")
+ assert not protocol.parse_init_connect("CONNECT host.com:443 foo/1.0")
+
+
+def test_prase_init_proxy():
+ u = "GET http://foo.com:8888/test HTTP/1.1"
+ m, s, h, po, pa, httpversion = protocol.parse_init_proxy(u)
+ assert m == "GET"
+ assert s == "http"
+ assert h == "foo.com"
+ assert po == 8888
+ assert pa == "/test"
+ assert httpversion == (1, 1)
+
+ assert not protocol.parse_init_proxy("invalid")
+ assert not protocol.parse_init_proxy("GET invalid HTTP/1.1")
+ assert not protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1")
+
+
+def test_parse_init_http():
+ u = "GET /test HTTP/1.1"
+ m, u, httpversion= protocol.parse_init_http(u)
+ assert m == "GET"
+ assert u == "/test"
+ assert httpversion == (1, 1)
+
+ assert not protocol.parse_init_http("invalid")
+ assert not protocol.parse_init_http("GET invalid HTTP/1.1")
+ assert not protocol.parse_init_http("GET /test foo/1.1")
+
+
+class TestReadHeaders:
+ def test_read_simple(self):
+ data = """
+ Header: one
+ Header2: two
+ \r\n
+ """
+ data = textwrap.dedent(data)
+ data = data.strip()
+ s = cStringIO.StringIO(data)
+ h = protocol.read_headers(s)
+ assert h == [["Header", "one"], ["Header2", "two"]]
+
+ def test_read_multi(self):
+ data = """
+ Header: one
+ Header: two
+ \r\n
+ """
+ data = textwrap.dedent(data)
+ data = data.strip()
+ s = cStringIO.StringIO(data)
+ h = protocol.read_headers(s)
+ assert h == [["Header", "one"], ["Header", "two"]]
+
+ def test_read_continued(self):
+ data = """
+ Header: one
+ \ttwo
+ Header2: three
+ \r\n
+ """
+ data = textwrap.dedent(data)
+ data = data.strip()
+ s = cStringIO.StringIO(data)
+ h = protocol.read_headers(s)
+ assert h == [["Header", "one\r\n two"], ["Header2", "three"]]
+
+
+def test_parse_url():
+ assert not protocol.parse_url("")
+
+ u = "http://foo.com:8888/test"
+ s, h, po, pa = protocol.parse_url(u)
+ assert s == "http"
+ assert h == "foo.com"
+ assert po == 8888
+ assert pa == "/test"
+
+ s, h, po, pa = protocol.parse_url("http://foo/bar")
+ assert s == "http"
+ assert h == "foo"
+ assert po == 80
+ assert pa == "/bar"
+
+ s, h, po, pa = protocol.parse_url("http://foo")
+ assert pa == "/"
+
+ s, h, po, pa = protocol.parse_url("https://foo")
+ assert po == 443
+
+ assert not protocol.parse_url("https://foo:bar")
+ assert not protocol.parse_url("https://foo:")
+
diff --git a/test/test_tcp.py b/test/test_tcp.py
new file mode 100644
index 00000000..d7d4483e
--- /dev/null
+++ b/test/test_tcp.py
@@ -0,0 +1,93 @@
+import cStringIO, threading, Queue
+from netlib import tcp
+import tutils
+
+class ServerThread(threading.Thread):
+ def __init__(self, server):
+ self.server = server
+ threading.Thread.__init__(self)
+
+ def run(self):
+ self.server.serve_forever()
+
+ def shutdown(self):
+ self.server.shutdown()
+
+
+class ServerTestBase:
+ @classmethod
+ def setupAll(cls):
+ cls.server = ServerThread(cls.makeserver())
+ cls.server.start()
+
+ @classmethod
+ def teardownAll(cls):
+ cls.server.shutdown()
+
+
+class THandler(tcp.BaseHandler):
+ def handle(self):
+ v = self.rfile.readline()
+ if v.startswith("echo"):
+ self.wfile.write(v)
+ elif v.startswith("error"):
+ raise ValueError("Testing an error.")
+ self.wfile.flush()
+
+
+class TServer(tcp.TCPServer):
+ def __init__(self, addr, q):
+ tcp.TCPServer.__init__(self, addr)
+ self.q = q
+
+ def handle_connection(self, request, client_address):
+ THandler(request, client_address, self)
+
+ def handle_error(self, request, client_address):
+ s = cStringIO.StringIO()
+ tcp.TCPServer.handle_error(self, request, client_address, s)
+ self.q.put(s.getvalue())
+
+
+class TestServer(ServerTestBase):
+ @classmethod
+ def makeserver(cls):
+ cls.q = Queue.Queue()
+ s = TServer(("127.0.0.1", 0), cls.q)
+ cls.port = s.port
+ return s
+
+ def test_echo(self):
+ testval = "echo!\n"
+ c = tcp.TCPClient(False, "127.0.0.1", self.port, None)
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
+
+ def test_error(self):
+ testval = "error!\n"
+ c = tcp.TCPClient(False, "127.0.0.1", self.port, None)
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert "Testing an error" in self.q.get()
+
+
+class TestTCPClient:
+ def test_conerr(self):
+ tutils.raises(tcp.NetLibError, tcp.TCPClient, False, "127.0.0.1", 0, None)
+
+
+class TestFileLike:
+ def test_wrap(self):
+ s = cStringIO.StringIO("foobar\nfoobar")
+ s = tcp.FileLike(s)
+ s.flush()
+ assert s.readline() == "foobar\n"
+ assert s.readline() == "foobar"
+ # Test __getattr__
+ assert s.isatty
+
+ def test_limit(self):
+ s = cStringIO.StringIO("foobar\nfoobar")
+ s = tcp.FileLike(s)
+ assert s.readline(3) == "foo"
diff --git a/test/tutils.py b/test/tutils.py
new file mode 100644
index 00000000..c8e06b96
--- /dev/null
+++ b/test/tutils.py
@@ -0,0 +1,56 @@
+import tempfile, os, shutil
+from contextlib import contextmanager
+from libpathod import utils
+
+
+@contextmanager
+def tmpdir(*args, **kwargs):
+ orig_workdir = os.getcwd()
+ temp_workdir = tempfile.mkdtemp(*args, **kwargs)
+ os.chdir(temp_workdir)
+
+ yield temp_workdir
+
+ os.chdir(orig_workdir)
+ shutil.rmtree(temp_workdir)
+
+
+def raises(exc, obj, *args, **kwargs):
+ """
+ Assert that a callable raises a specified exception.
+
+ :exc An exception class or a string. If a class, assert that an
+ exception of this type is raised. If a string, assert that the string
+ occurs in the string representation of the exception, based on a
+ case-insenstivie match.
+
+ :obj A callable object.
+
+ :args Arguments to be passsed to the callable.
+
+ :kwargs Arguments to be passed to the callable.
+ """
+ try:
+ apply(obj, args, kwargs)
+ except Exception, v:
+ if isinstance(exc, basestring):
+ if exc.lower() in str(v).lower():
+ return
+ else:
+ raise AssertionError(
+ "Expected %s, but caught %s"%(
+ repr(str(exc)), v
+ )
+ )
+ else:
+ if isinstance(v, exc):
+ return
+ else:
+ raise AssertionError(
+ "Expected %s, but caught %s %s"%(
+ exc.__name__, v.__class__.__name__, str(v)
+ )
+ )
+ raise AssertionError("No exception raised.")
+
+test_data = utils.Data(__name__)