From b558997fd9db8406b2a24a1831d06e283dbf35a6 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 19 Jun 2012 09:42:32 +1200 Subject: Initial checkin. --- .coveragerc | 2 + .gitignore | 9 +++ README | 2 + netlib/__init__.py | 0 netlib/odict.py | 160 ++++++++++++++++++++++++++++++++++++ netlib/protocol.py | 218 ++++++++++++++++++++++++++++++++++++++++++++++++++ netlib/tcp.py | 182 +++++++++++++++++++++++++++++++++++++++++ test/test_odict.py | 113 ++++++++++++++++++++++++++ test/test_protocol.py | 163 +++++++++++++++++++++++++++++++++++++ test/test_tcp.py | 93 +++++++++++++++++++++ test/tutils.py | 56 +++++++++++++ 11 files changed, 998 insertions(+) create mode 100644 .coveragerc create mode 100644 .gitignore create mode 100644 README create mode 100644 netlib/__init__.py create mode 100644 netlib/odict.py create mode 100644 netlib/protocol.py create mode 100644 netlib/tcp.py create mode 100644 test/test_odict.py create mode 100644 test/test_protocol.py create mode 100644 test/test_tcp.py create mode 100644 test/tutils.py diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 00000000..99f57cb0 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,2 @@ +[report] +include = *netlib* diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..f53cd2e2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +MANIFEST +/build +/dist +/tmp +/doc +*.py[cdo] +*.swp +*.swo +.coverage diff --git a/README b/README new file mode 100644 index 00000000..1c86738c --- /dev/null +++ b/README @@ -0,0 +1,2 @@ +Netlib is a collection of common utility functions, used by the pathod and +mitmproxy projects. diff --git a/netlib/__init__.py b/netlib/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/netlib/odict.py b/netlib/odict.py new file mode 100644 index 00000000..afc33caa --- /dev/null +++ b/netlib/odict.py @@ -0,0 +1,160 @@ +import re, copy + +def safe_subn(pattern, repl, target, *args, **kwargs): + """ + There are Unicode conversion problems with re.subn. We try to smooth + that over by casting the pattern and replacement to strings. We really + need a better solution that is aware of the actual content ecoding. + """ + return re.subn(str(pattern), str(repl), target, *args, **kwargs) + + +class ODict: + """ + A dictionary-like object for managing ordered (key, value) data. + """ + def __init__(self, lst=None): + self.lst = lst or [] + + def _kconv(self, s): + return s + + def __eq__(self, other): + return self.lst == other.lst + + def __getitem__(self, k): + """ + Returns a list of values matching key. + """ + ret = [] + k = self._kconv(k) + for i in self.lst: + if self._kconv(i[0]) == k: + ret.append(i[1]) + return ret + + def _filter_lst(self, k, lst): + k = self._kconv(k) + new = [] + for i in lst: + if self._kconv(i[0]) != k: + new.append(i) + return new + + def __len__(self): + """ + Total number of (key, value) pairs. + """ + return len(self.lst) + + def __setitem__(self, k, valuelist): + """ + Sets the values for key k. If there are existing values for this + key, they are cleared. + """ + if isinstance(valuelist, basestring): + raise ValueError("ODict valuelist should be lists.") + new = self._filter_lst(k, self.lst) + for i in valuelist: + new.append([k, i]) + self.lst = new + + def __delitem__(self, k): + """ + Delete all items matching k. + """ + self.lst = self._filter_lst(k, self.lst) + + def __contains__(self, k): + for i in self.lst: + if self._kconv(i[0]) == self._kconv(k): + return True + return False + + def add(self, key, value): + self.lst.append([key, str(value)]) + + def get(self, k, d=None): + if k in self: + return self[k] + else: + return d + + def items(self): + return self.lst[:] + + def _get_state(self): + return [tuple(i) for i in self.lst] + + @classmethod + def _from_state(klass, state): + return klass([list(i) for i in state]) + + def copy(self): + """ + Returns a copy of this object. + """ + lst = copy.deepcopy(self.lst) + return self.__class__(lst) + + def __repr__(self): + elements = [] + for itm in self.lst: + elements.append(itm[0] + ": " + itm[1]) + elements.append("") + return "\r\n".join(elements) + + def in_any(self, key, value, caseless=False): + """ + Do any of the values matching key contain value? + + If caseless is true, value comparison is case-insensitive. + """ + if caseless: + value = value.lower() + for i in self[key]: + if caseless: + i = i.lower() + if value in i: + return True + return False + + def match_re(self, expr): + """ + Match the regular expression against each (key, value) pair. For + each pair a string of the following format is matched against: + + "key: value" + """ + for k, v in self.lst: + s = "%s: %s"%(k, v) + if re.search(expr, s): + return True + return False + + def replace(self, pattern, repl, *args, **kwargs): + """ + Replaces a regular expression pattern with repl in both keys and + values. Encoded content will be decoded before replacement, and + re-encoded afterwards. + + Returns the number of replacements made. + """ + nlst, count = [], 0 + for i in self.lst: + k, c = safe_subn(pattern, repl, i[0], *args, **kwargs) + count += c + v, c = safe_subn(pattern, repl, i[1], *args, **kwargs) + count += c + nlst.append([k, v]) + self.lst = nlst + return count + + +class ODictCaseless(ODict): + """ + A variant of ODict with "caseless" keys. This version _preserves_ key + case, but does not consider case when setting or getting items. + """ + def _kconv(self, s): + return s.lower() diff --git a/netlib/protocol.py b/netlib/protocol.py new file mode 100644 index 00000000..55bcf440 --- /dev/null +++ b/netlib/protocol.py @@ -0,0 +1,218 @@ +import string, urlparse + +class ProtocolError(Exception): + def __init__(self, code, msg): + self.code, self.msg = code, msg + + def __str__(self): + return "ProtocolError(%s, %s)"%(self.code, self.msg) + + +def parse_url(url): + """ + Returns a (scheme, host, port, path) tuple, or None on error. + """ + scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) + if not scheme: + return None + if ':' in netloc: + host, port = string.rsplit(netloc, ':', maxsplit=1) + try: + port = int(port) + except ValueError: + return None + else: + host = netloc + if scheme == "https": + port = 443 + else: + port = 80 + path = urlparse.urlunparse(('', '', path, params, query, fragment)) + if not path.startswith("/"): + path = "/" + path + return scheme, host, port, path + + +def read_headers(fp): + """ + Read a set of headers from a file pointer. Stop once a blank line + is reached. Return a ODictCaseless object. + """ + ret = [] + name = '' + while 1: + line = fp.readline() + if not line or line == '\r\n' or line == '\n': + break + if line[0] in ' \t': + # continued header + ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() + else: + i = line.find(':') + # We're being liberal in what we accept, here. + if i > 0: + name = line[:i] + value = line[i+1:].strip() + ret.append([name, value]) + return ret + + +def read_chunked(fp, limit): + content = "" + total = 0 + while 1: + line = fp.readline(128) + if line == "": + raise IOError("Connection closed") + if line == '\r\n' or line == '\n': + continue + try: + length = int(line,16) + except ValueError: + # FIXME: Not strictly correct - this could be from the server, in which + # case we should send a 502. + raise ProtocolError(400, "Invalid chunked encoding length: %s"%line) + if not length: + break + total += length + if limit is not None and total > limit: + msg = "HTTP Body too large."\ + " Limit is %s, chunked content length was at least %s"%(limit, total) + raise ProtocolError(509, msg) + content += fp.read(length) + line = fp.readline(5) + if line != '\r\n': + raise IOError("Malformed chunked body") + while 1: + line = fp.readline() + if line == "": + raise IOError("Connection closed") + if line == '\r\n' or line == '\n': + break + return content + + +def has_chunked_encoding(headers): + for i in headers["transfer-encoding"]: + for j in i.split(","): + if j.lower() == "chunked": + return True + return False + + +def read_http_body(rfile, headers, all, limit): + if has_chunked_encoding(headers): + content = read_chunked(rfile, limit) + elif "content-length" in headers: + try: + l = int(headers["content-length"][0]) + except ValueError: + # FIXME: Not strictly correct - this could be from the server, in which + # case we should send a 502. + raise ProtocolError(400, "Invalid content-length header: %s"%headers["content-length"]) + if limit is not None and l > limit: + raise ProtocolError(509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) + content = rfile.read(l) + elif all: + content = rfile.read(limit if limit else None) + else: + content = "" + return content + + +def parse_http_protocol(s): + if not s.startswith("HTTP/"): + return None + major, minor = s.split('/')[1].split('.') + major = int(major) + minor = int(minor) + return major, minor + + +def parse_init_connect(line): + try: + method, url, protocol = string.split(line) + except ValueError: + return None + if method != 'CONNECT': + return None + try: + host, port = url.split(":") + except ValueError: + return None + port = int(port) + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + return host, port, httpversion + + +def parse_init_proxy(line): + try: + method, url, protocol = string.split(line) + except ValueError: + return None + parts = parse_url(url) + if not parts: + return None + scheme, host, port, path = parts + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + return method, scheme, host, port, path, httpversion + + +def parse_init_http(line): + """ + Returns (method, url, httpversion) + """ + try: + method, url, protocol = string.split(line) + except ValueError: + return None + if not (url.startswith("/") or url == "*"): + return None + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + return method, url, httpversion + + +def request_connection_close(httpversion, headers): + """ + Checks the request to see if the client connection should be closed. + """ + if "connection" in headers: + for value in ",".join(headers['connection']).split(","): + value = value.strip() + if value == "close": + return True + elif value == "keep-alive": + return False + # HTTP 1.1 connections are assumed to be persistent + if httpversion == (1, 1): + return False + return True + + +def response_connection_close(httpversion, headers): + """ + Checks the response to see if the client connection should be closed. + """ + if request_connection_close(httpversion, headers): + return True + elif not has_chunked_encoding(headers) and "content-length" in headers: + return True + return False + + +def read_http_body_request(rfile, wfile, headers, httpversion, limit): + if "expect" in headers: + # FIXME: Should be forwarded upstream + expect = ",".join(headers['expect']) + if expect == "100-continue" and httpversion >= (1, 1): + wfile.write('HTTP/1.1 100 Continue\r\n') + wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION) + wfile.write('\r\n') + del headers['expect'] + return read_http_body(rfile, headers, False, limit) diff --git a/netlib/tcp.py b/netlib/tcp.py new file mode 100644 index 00000000..08ccba09 --- /dev/null +++ b/netlib/tcp.py @@ -0,0 +1,182 @@ +import select, socket, threading, traceback, sys +from OpenSSL import SSL + + +class NetLibError(Exception): pass + + +class FileLike: + def __init__(self, o): + self.o = o + + def __getattr__(self, attr): + return getattr(self.o, attr) + + def flush(self): + pass + + def read(self, length): + result = '' + while len(result) < length: + try: + data = self.o.read(length) + except SSL.ZeroReturnError: + break + if not data: + break + result += data + return result + + def write(self, v): + self.o.sendall(v) + + def readline(self, size = None): + result = '' + bytes_read = 0 + while True: + if size is not None and bytes_read >= size: + break + ch = self.read(1) + bytes_read += 1 + if not ch: + break + else: + result += ch + if ch == '\n': + break + return result + + +class TCPClient: + def __init__(self, ssl, host, port, clientcert): + self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert + self.connection, self.rfile, self.wfile = None, None, None + self.cert = None + self.connect() + + def connect(self): + try: + addr = socket.gethostbyname(self.host) + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if self.ssl: + context = SSL.Context(SSL.SSLv23_METHOD) + if self.clientcert: + context.use_certificate_file(self.clientcert) + server = SSL.Connection(context, server) + server.connect((addr, self.port)) + if self.ssl: + self.cert = server.get_peer_certificate() + self.rfile, self.wfile = FileLike(server), FileLike(server) + else: + self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb') + except socket.error, err: + raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) + self.connection = server + + +class BaseHandler: + rbufsize = -1 + wbufsize = 0 + def __init__(self, connection, client_address, server): + self.connection = connection + self.rfile = self.connection.makefile('rb', self.rbufsize) + self.wfile = self.connection.makefile('wb', self.wbufsize) + + self.client_address = client_address + self.server = server + self.handle() + self.finish() + + def convert_to_ssl(self, cert, key): + ctx = SSL.Context(SSL.SSLv23_METHOD) + ctx.use_privatekey_file(key) + ctx.use_certificate_file(cert) + self.connection = SSL.Connection(ctx, self.connection) + self.connection.set_accept_state() + self.rfile = FileLike(self.connection) + self.wfile = FileLike(self.connection) + + def finish(self): + try: + if not getattr(self.wfile, "closed", False): + self.wfile.flush() + self.connection.close() + self.wfile.close() + self.rfile.close() + except IOError: # pragma: no cover + pass + + def handle(self): # pragma: no cover + raise NotImplementedError + + +class TCPServer: + request_queue_size = 20 + def __init__(self, server_address): + self.server_address = server_address + self.__is_shut_down = threading.Event() + self.__shutdown_request = False + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.socket.bind(self.server_address) + self.server_address = self.socket.getsockname() + self.socket.listen(self.request_queue_size) + self.port = self.socket.getsockname()[1] + + def request_thread(self, request, client_address): + try: + self.handle_connection(request, client_address) + request.close() + except: + self.handle_error(request, client_address) + request.close() + + def serve_forever(self, poll_interval=0.5): + self.__is_shut_down.clear() + try: + while not self.__shutdown_request: + r, w, e = select.select([self.socket], [], [], poll_interval) + if self.socket in r: + try: + request, client_address = self.socket.accept() + except socket.error: + return + try: + t = threading.Thread( + target = self.request_thread, + args = (request, client_address) + ) + t.setDaemon(1) + t.start() + except: + self.handle_error(request, client_address) + request.close() + finally: + self.__shutdown_request = False + self.__is_shut_down.set() + + def shutdown(self): + self.__shutdown_request = True + self.__is_shut_down.wait() + self.handle_shutdown() + + def handle_error(self, request, client_address, fp=sys.stderr): + """ + Called when handle_connection raises an exception. + """ + print >> fp, '-'*40 + print >> fp, "Error processing of request from %s:%s"%client_address + print >> fp, traceback.format_exc() + print >> fp, '-'*40 + + def handle_connection(self, request, client_address): # pragma: no cover + """ + Called after client connection. + """ + raise NotImplementedError + + def handle_shutdown(self): + """ + Called after server shutdown. + """ + pass diff --git a/test/test_odict.py b/test/test_odict.py new file mode 100644 index 00000000..e7453e2d --- /dev/null +++ b/test/test_odict.py @@ -0,0 +1,113 @@ +from netlib import odict +import tutils + + +class TestODict: + def setUp(self): + self.od = odict.ODict() + + def test_str_err(self): + h = odict.ODict() + tutils.raises(ValueError, h.__setitem__, "key", "foo") + + def test_dictToHeader1(self): + self.od.add("one", "uno") + self.od.add("two", "due") + self.od.add("two", "tre") + expected = [ + "one: uno\r\n", + "two: due\r\n", + "two: tre\r\n", + "\r\n" + ] + out = repr(self.od) + for i in expected: + assert out.find(i) >= 0 + + def test_dictToHeader2(self): + self.od["one"] = ["uno"] + expected1 = "one: uno\r\n" + expected2 = "\r\n" + out = repr(self.od) + assert out.find(expected1) >= 0 + assert out.find(expected2) >= 0 + + def test_match_re(self): + h = odict.ODict() + h.add("one", "uno") + h.add("two", "due") + h.add("two", "tre") + assert h.match_re("uno") + assert h.match_re("two: due") + assert not h.match_re("nonono") + + def test_getset_state(self): + self.od.add("foo", 1) + self.od.add("foo", 2) + self.od.add("bar", 3) + state = self.od._get_state() + nd = odict.ODict._from_state(state) + assert nd == self.od + + def test_in_any(self): + self.od["one"] = ["atwoa", "athreea"] + assert self.od.in_any("one", "two") + assert self.od.in_any("one", "three") + assert not self.od.in_any("one", "four") + assert not self.od.in_any("nonexistent", "foo") + assert not self.od.in_any("one", "TWO") + assert self.od.in_any("one", "TWO", True) + + def test_copy(self): + self.od.add("foo", 1) + self.od.add("foo", 2) + self.od.add("bar", 3) + assert self.od == self.od.copy() + + def test_del(self): + self.od.add("foo", 1) + self.od.add("Foo", 2) + self.od.add("bar", 3) + del self.od["foo"] + assert len(self.od.lst) == 2 + + def test_replace(self): + self.od.add("one", "two") + self.od.add("two", "one") + assert self.od.replace("one", "vun") == 2 + assert self.od.lst == [ + ["vun", "two"], + ["two", "vun"], + ] + + def test_get(self): + self.od.add("one", "two") + assert self.od.get("one") == ["two"] + assert self.od.get("two") == None + + +class TestODictCaseless: + def setUp(self): + self.od = odict.ODictCaseless() + + def test_override(self): + o = odict.ODictCaseless() + o.add('T', 'application/x-www-form-urlencoded; charset=UTF-8') + o["T"] = ["foo"] + assert o["T"] == ["foo"] + + def test_case_preservation(self): + self.od["Foo"] = ["1"] + assert "foo" in self.od + assert self.od.items()[0][0] == "Foo" + assert self.od.get("foo") == ["1"] + assert self.od.get("foo", [""]) == ["1"] + assert self.od.get("Foo", [""]) == ["1"] + assert self.od.get("xx", "yy") == "yy" + + def test_del(self): + self.od.add("foo", 1) + self.od.add("Foo", 2) + self.od.add("bar", 3) + del self.od["foo"] + assert len(self.od) == 1 diff --git a/test/test_protocol.py b/test/test_protocol.py new file mode 100644 index 00000000..028faadd --- /dev/null +++ b/test/test_protocol.py @@ -0,0 +1,163 @@ +import cStringIO, textwrap +from netlib import protocol, odict +import tutils + +def test_has_chunked_encoding(): + h = odict.ODictCaseless() + assert not protocol.has_chunked_encoding(h) + h["transfer-encoding"] = ["chunked"] + assert protocol.has_chunked_encoding(h) + + +def test_read_chunked(): + s = cStringIO.StringIO("1\r\na\r\n0\r\n") + tutils.raises(IOError, protocol.read_chunked, s, None) + + s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") + assert protocol.read_chunked(s, None) == "a" + + s = cStringIO.StringIO("\r\n") + tutils.raises(IOError, protocol.read_chunked, s, None) + + s = cStringIO.StringIO("1\r\nfoo") + tutils.raises(IOError, protocol.read_chunked, s, None) + + s = cStringIO.StringIO("foo\r\nfoo") + tutils.raises(protocol.ProtocolError, protocol.read_chunked, s, None) + + +def test_request_connection_close(): + h = odict.ODictCaseless() + assert protocol.request_connection_close((1, 0), h) + assert not protocol.request_connection_close((1, 1), h) + + h["connection"] = ["keep-alive"] + assert not protocol.request_connection_close((1, 1), h) + + +def test_read_http_body(): + h = odict.ODict() + s = cStringIO.StringIO("testing") + assert protocol.read_http_body(s, h, False, None) == "" + + h["content-length"] = ["foo"] + s = cStringIO.StringIO("testing") + tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, None) + + h["content-length"] = [5] + s = cStringIO.StringIO("testing") + assert len(protocol.read_http_body(s, h, False, None)) == 5 + s = cStringIO.StringIO("testing") + tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, 4) + + h = odict.ODict() + s = cStringIO.StringIO("testing") + assert len(protocol.read_http_body(s, h, True, 4)) == 4 + s = cStringIO.StringIO("testing") + assert len(protocol.read_http_body(s, h, True, 100)) == 7 + +def test_parse_http_protocol(): + assert protocol.parse_http_protocol("HTTP/1.1") == (1, 1) + assert protocol.parse_http_protocol("HTTP/0.0") == (0, 0) + assert not protocol.parse_http_protocol("foo/0.0") + + +def test_parse_init_connect(): + assert protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0") + assert not protocol.parse_init_connect("bogus") + assert not protocol.parse_init_connect("GET host.com:443 HTTP/1.0") + assert not protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0") + assert not protocol.parse_init_connect("CONNECT host.com:443 foo/1.0") + + +def test_prase_init_proxy(): + u = "GET http://foo.com:8888/test HTTP/1.1" + m, s, h, po, pa, httpversion = protocol.parse_init_proxy(u) + assert m == "GET" + assert s == "http" + assert h == "foo.com" + assert po == 8888 + assert pa == "/test" + assert httpversion == (1, 1) + + assert not protocol.parse_init_proxy("invalid") + assert not protocol.parse_init_proxy("GET invalid HTTP/1.1") + assert not protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") + + +def test_parse_init_http(): + u = "GET /test HTTP/1.1" + m, u, httpversion= protocol.parse_init_http(u) + assert m == "GET" + assert u == "/test" + assert httpversion == (1, 1) + + assert not protocol.parse_init_http("invalid") + assert not protocol.parse_init_http("GET invalid HTTP/1.1") + assert not protocol.parse_init_http("GET /test foo/1.1") + + +class TestReadHeaders: + def test_read_simple(self): + data = """ + Header: one + Header2: two + \r\n + """ + data = textwrap.dedent(data) + data = data.strip() + s = cStringIO.StringIO(data) + h = protocol.read_headers(s) + assert h == [["Header", "one"], ["Header2", "two"]] + + def test_read_multi(self): + data = """ + Header: one + Header: two + \r\n + """ + data = textwrap.dedent(data) + data = data.strip() + s = cStringIO.StringIO(data) + h = protocol.read_headers(s) + assert h == [["Header", "one"], ["Header", "two"]] + + def test_read_continued(self): + data = """ + Header: one + \ttwo + Header2: three + \r\n + """ + data = textwrap.dedent(data) + data = data.strip() + s = cStringIO.StringIO(data) + h = protocol.read_headers(s) + assert h == [["Header", "one\r\n two"], ["Header2", "three"]] + + +def test_parse_url(): + assert not protocol.parse_url("") + + u = "http://foo.com:8888/test" + s, h, po, pa = protocol.parse_url(u) + assert s == "http" + assert h == "foo.com" + assert po == 8888 + assert pa == "/test" + + s, h, po, pa = protocol.parse_url("http://foo/bar") + assert s == "http" + assert h == "foo" + assert po == 80 + assert pa == "/bar" + + s, h, po, pa = protocol.parse_url("http://foo") + assert pa == "/" + + s, h, po, pa = protocol.parse_url("https://foo") + assert po == 443 + + assert not protocol.parse_url("https://foo:bar") + assert not protocol.parse_url("https://foo:") + diff --git a/test/test_tcp.py b/test/test_tcp.py new file mode 100644 index 00000000..d7d4483e --- /dev/null +++ b/test/test_tcp.py @@ -0,0 +1,93 @@ +import cStringIO, threading, Queue +from netlib import tcp +import tutils + +class ServerThread(threading.Thread): + def __init__(self, server): + self.server = server + threading.Thread.__init__(self) + + def run(self): + self.server.serve_forever() + + def shutdown(self): + self.server.shutdown() + + +class ServerTestBase: + @classmethod + def setupAll(cls): + cls.server = ServerThread(cls.makeserver()) + cls.server.start() + + @classmethod + def teardownAll(cls): + cls.server.shutdown() + + +class THandler(tcp.BaseHandler): + def handle(self): + v = self.rfile.readline() + if v.startswith("echo"): + self.wfile.write(v) + elif v.startswith("error"): + raise ValueError("Testing an error.") + self.wfile.flush() + + +class TServer(tcp.TCPServer): + def __init__(self, addr, q): + tcp.TCPServer.__init__(self, addr) + self.q = q + + def handle_connection(self, request, client_address): + THandler(request, client_address, self) + + def handle_error(self, request, client_address): + s = cStringIO.StringIO() + tcp.TCPServer.handle_error(self, request, client_address, s) + self.q.put(s.getvalue()) + + +class TestServer(ServerTestBase): + @classmethod + def makeserver(cls): + cls.q = Queue.Queue() + s = TServer(("127.0.0.1", 0), cls.q) + cls.port = s.port + return s + + def test_echo(self): + testval = "echo!\n" + c = tcp.TCPClient(False, "127.0.0.1", self.port, None) + c.wfile.write(testval) + c.wfile.flush() + assert c.rfile.readline() == testval + + def test_error(self): + testval = "error!\n" + c = tcp.TCPClient(False, "127.0.0.1", self.port, None) + c.wfile.write(testval) + c.wfile.flush() + assert "Testing an error" in self.q.get() + + +class TestTCPClient: + def test_conerr(self): + tutils.raises(tcp.NetLibError, tcp.TCPClient, False, "127.0.0.1", 0, None) + + +class TestFileLike: + def test_wrap(self): + s = cStringIO.StringIO("foobar\nfoobar") + s = tcp.FileLike(s) + s.flush() + assert s.readline() == "foobar\n" + assert s.readline() == "foobar" + # Test __getattr__ + assert s.isatty + + def test_limit(self): + s = cStringIO.StringIO("foobar\nfoobar") + s = tcp.FileLike(s) + assert s.readline(3) == "foo" diff --git a/test/tutils.py b/test/tutils.py new file mode 100644 index 00000000..c8e06b96 --- /dev/null +++ b/test/tutils.py @@ -0,0 +1,56 @@ +import tempfile, os, shutil +from contextlib import contextmanager +from libpathod import utils + + +@contextmanager +def tmpdir(*args, **kwargs): + orig_workdir = os.getcwd() + temp_workdir = tempfile.mkdtemp(*args, **kwargs) + os.chdir(temp_workdir) + + yield temp_workdir + + os.chdir(orig_workdir) + shutil.rmtree(temp_workdir) + + +def raises(exc, obj, *args, **kwargs): + """ + Assert that a callable raises a specified exception. + + :exc An exception class or a string. If a class, assert that an + exception of this type is raised. If a string, assert that the string + occurs in the string representation of the exception, based on a + case-insenstivie match. + + :obj A callable object. + + :args Arguments to be passsed to the callable. + + :kwargs Arguments to be passed to the callable. + """ + try: + apply(obj, args, kwargs) + except Exception, v: + if isinstance(exc, basestring): + if exc.lower() in str(v).lower(): + return + else: + raise AssertionError( + "Expected %s, but caught %s"%( + repr(str(exc)), v + ) + ) + else: + if isinstance(v, exc): + return + else: + raise AssertionError( + "Expected %s, but caught %s %s"%( + exc.__name__, v.__class__.__name__, str(v) + ) + ) + raise AssertionError("No exception raised.") + +test_data = utils.Data(__name__) -- cgit v1.2.3 From c7e9051cbbee1e76abb24518268d30a24df3a16a Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 19 Jun 2012 10:42:25 +1200 Subject: Import wsgi. --- netlib/wsgi.py | 125 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ test/test_wsgi.py | 98 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 223 insertions(+) create mode 100644 netlib/wsgi.py create mode 100644 test/test_wsgi.py diff --git a/netlib/wsgi.py b/netlib/wsgi.py new file mode 100644 index 00000000..0608245c --- /dev/null +++ b/netlib/wsgi.py @@ -0,0 +1,125 @@ +import cStringIO, urllib, time, sys, traceback +import odict + +def date_time_string(): + """Return the current date and time formatted for a message header.""" + WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] + MONTHS = [None, + 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', + 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] + now = time.time() + year, month, day, hh, mm, ss, wd, y, z = time.gmtime(now) + s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( + WEEKS[wd], + day, MONTHS[month], year, + hh, mm, ss) + return s + + +class WSGIAdaptor: + def __init__(self, app, domain, port, sversion): + self.app, self.domain, self.port, self.sversion = app, domain, port, sversion + + def make_environ(self, request, errsoc): + if '?' in request.path: + path_info, query = request.path.split('?', 1) + else: + path_info = request.path + query = '' + environ = { + 'wsgi.version': (1, 0), + 'wsgi.url_scheme': request.scheme, + 'wsgi.input': cStringIO.StringIO(request.content), + 'wsgi.errors': errsoc, + 'wsgi.multithread': True, + 'wsgi.multiprocess': False, + 'wsgi.run_once': False, + 'SERVER_SOFTWARE': self.sversion, + 'REQUEST_METHOD': request.method, + 'SCRIPT_NAME': '', + 'PATH_INFO': urllib.unquote(path_info), + 'QUERY_STRING': query, + 'CONTENT_TYPE': request.headers.get('Content-Type', [''])[0], + 'CONTENT_LENGTH': request.headers.get('Content-Length', [''])[0], + 'SERVER_NAME': self.domain, + 'SERVER_PORT': self.port, + # FIXME: We need to pick up the protocol read from the request. + 'SERVER_PROTOCOL': "HTTP/1.1", + } + if request.client_conn.address: + environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.client_conn.address + + for key, value in request.headers.items(): + key = 'HTTP_' + key.upper().replace('-', '_') + if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'): + environ[key] = value + return environ + + def error_page(self, soc, headers_sent, s): + """ + Make a best-effort attempt to write an error page. If headers are + already sent, we just bung the error into the page. + """ + c = """ + +

Internal Server Error

+
%s"
+ + """%s + if not headers_sent: + soc.write("HTTP/1.1 500 Internal Server Error\r\n") + soc.write("Content-Type: text/html\r\n") + soc.write("Content-Length: %s\r\n"%len(c)) + soc.write("\r\n") + soc.write(c) + + def serve(self, request, soc): + state = dict( + response_started = False, + headers_sent = False, + status = None, + headers = None + ) + def write(data): + if not state["headers_sent"]: + soc.write("HTTP/1.1 %s\r\n"%state["status"]) + h = state["headers"] + if 'server' not in h: + h["Server"] = [version.NAMEVERSION] + if 'date' not in h: + h["Date"] = [date_time_string()] + soc.write(str(h)) + soc.write("\r\n") + state["headers_sent"] = True + soc.write(data) + soc.flush() + + def start_response(status, headers, exc_info=None): + if exc_info: + try: + if state["headers_sent"]: + raise exc_info[0], exc_info[1], exc_info[2] + finally: + exc_info = None + elif state["status"]: + raise AssertionError('Response already started') + state["status"] = status + state["headers"] = odict.ODictCaseless(headers) + return write + + errs = cStringIO.StringIO() + try: + dataiter = self.app(self.make_environ(request, errs), start_response) + for i in dataiter: + write(i) + if not state["headers_sent"]: + write("") + except Exception, v: + try: + s = traceback.format_exc() + self.error_page(soc, state["headers_sent"], s) + except Exception, v: # pragma: no cover + pass # pragma: no cover + return errs.getvalue() + + diff --git a/test/test_wsgi.py b/test/test_wsgi.py new file mode 100644 index 00000000..c55ab1d8 --- /dev/null +++ b/test/test_wsgi.py @@ -0,0 +1,98 @@ +import cStringIO, sys +import libpry +from netlib import wsgi +import tutils + + +class TestApp: + def __init__(self): + self.called = False + + def __call__(self, environ, start_response): + self.called = True + status = '200 OK' + response_headers = [('Content-type', 'text/plain')] + start_response(status, response_headers) + return ['Hello', ' world!\n'] + + +class uWSGIAdaptor(libpry.AutoTree): + def test_make_environ(self): + w = wsgi.WSGIAdaptor(None, "foo", 80) + tr = tutils.treq() + assert w.make_environ(tr, None) + + tr.path = "/foo?bar=voing" + r = w.make_environ(tr, None) + assert r["QUERY_STRING"] == "bar=voing" + + def test_serve(self): + ta = TestApp() + w = wsgi.WSGIAdaptor(ta, "foo", 80) + r = tutils.treq() + r.host = "foo" + r.port = 80 + + wfile = cStringIO.StringIO() + err = w.serve(r, wfile) + assert ta.called + assert not err + + val = wfile.getvalue() + assert "Hello world" in val + assert "Server:" in val + + def _serve(self, app): + w = wsgi.WSGIAdaptor(app, "foo", 80) + r = tutils.treq() + r.host = "foo" + r.port = 80 + wfile = cStringIO.StringIO() + err = w.serve(r, wfile) + return wfile.getvalue() + + def test_serve_empty_body(self): + def app(environ, start_response): + status = '200 OK' + response_headers = [('Foo', 'bar')] + start_response(status, response_headers) + return [] + assert self._serve(app) + + def test_serve_double_start(self): + def app(environ, start_response): + try: + raise ValueError("foo") + except: + ei = sys.exc_info() + status = '200 OK' + response_headers = [('Content-type', 'text/plain')] + start_response(status, response_headers) + start_response(status, response_headers) + assert "Internal Server Error" in self._serve(app) + + def test_serve_single_err(self): + def app(environ, start_response): + try: + raise ValueError("foo") + except: + ei = sys.exc_info() + status = '200 OK' + response_headers = [('Content-type', 'text/plain')] + start_response(status, response_headers, ei) + assert "Internal Server Error" in self._serve(app) + + def test_serve_double_err(self): + def app(environ, start_response): + try: + raise ValueError("foo") + except: + ei = sys.exc_info() + status = '200 OK' + response_headers = [('Content-type', 'text/plain')] + start_response(status, response_headers) + yield "aaa" + start_response(status, response_headers, ei) + yield "bbb" + assert "Internal Server Error" in self._serve(app) + -- cgit v1.2.3 From ce1ef554561d55a414961993dcaf8f11000d1f22 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 19 Jun 2012 14:23:22 +1200 Subject: Adapt WSGI, convert test suite to nose. --- netlib/wsgi.py | 15 ++++++++++++++- test/test_wsgi.py | 24 +++++++++++++++--------- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 0608245c..3c3a8384 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,6 +1,19 @@ import cStringIO, urllib, time, sys, traceback import odict + +class ClientConn: + def __init__(self, address): + self.address = address + + +class Request: + def __init__(self, client_conn, scheme, method, path, headers, content): + self.scheme, self.method, self.path = scheme, method, path + self.headers, self.content = headers, content + self.client_conn = client_conn + + def date_time_string(): """Return the current date and time formatted for a message header.""" WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] @@ -85,7 +98,7 @@ class WSGIAdaptor: soc.write("HTTP/1.1 %s\r\n"%state["status"]) h = state["headers"] if 'server' not in h: - h["Server"] = [version.NAMEVERSION] + h["Server"] = [self.sversion] if 'date' not in h: h["Date"] = [date_time_string()] soc.write(str(h)) diff --git a/test/test_wsgi.py b/test/test_wsgi.py index c55ab1d8..7763b9e5 100644 --- a/test/test_wsgi.py +++ b/test/test_wsgi.py @@ -1,7 +1,13 @@ import cStringIO, sys import libpry -from netlib import wsgi -import tutils +from netlib import wsgi, odict + + +def treq(): + cc = wsgi.ClientConn(("127.0.0.1", 8888)) + h = odict.ODictCaseless() + h["test"] = ["value"] + return wsgi.Request(cc, "http", "GET", "/", h, "") class TestApp: @@ -16,10 +22,10 @@ class TestApp: return ['Hello', ' world!\n'] -class uWSGIAdaptor(libpry.AutoTree): +class TestWSGI: def test_make_environ(self): - w = wsgi.WSGIAdaptor(None, "foo", 80) - tr = tutils.treq() + w = wsgi.WSGIAdaptor(None, "foo", 80, "version") + tr = treq() assert w.make_environ(tr, None) tr.path = "/foo?bar=voing" @@ -28,8 +34,8 @@ class uWSGIAdaptor(libpry.AutoTree): def test_serve(self): ta = TestApp() - w = wsgi.WSGIAdaptor(ta, "foo", 80) - r = tutils.treq() + w = wsgi.WSGIAdaptor(ta, "foo", 80, "version") + r = treq() r.host = "foo" r.port = 80 @@ -43,8 +49,8 @@ class uWSGIAdaptor(libpry.AutoTree): assert "Server:" in val def _serve(self, app): - w = wsgi.WSGIAdaptor(app, "foo", 80) - r = tutils.treq() + w = wsgi.WSGIAdaptor(app, "foo", 80, "version") + r = treq() r.host = "foo" r.port = 80 wfile = cStringIO.StringIO() -- cgit v1.2.3 From 084be7684d5cb367d4b8995dbf01f177af6113bf Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 20 Jun 2012 10:51:02 +1200 Subject: Close socket on shutdown. --- netlib/tcp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/netlib/tcp.py b/netlib/tcp.py index 08ccba09..92a7e92f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -158,6 +158,7 @@ class TCPServer: def shutdown(self): self.__shutdown_request = True self.__is_shut_down.wait() + self.socket.close() self.handle_shutdown() def handle_error(self, request, client_address, fp=sys.stderr): -- cgit v1.2.3 From b7062007965ebd8c11e94bd28775ac6d6083eedf Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 20 Jun 2012 11:01:40 +1200 Subject: Drop default poll interval to 0.1s. --- netlib/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 92a7e92f..5a942522 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -131,7 +131,7 @@ class TCPServer: self.handle_error(request, client_address) request.close() - def serve_forever(self, poll_interval=0.5): + def serve_forever(self, poll_interval=0.1): self.__is_shut_down.clear() try: while not self.__shutdown_request: -- cgit v1.2.3 From 227e72abf4124cbf55328cd15be917b4af99367f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 23 Jun 2012 13:49:57 +1200 Subject: README, setup.py, version --- README | 12 ++++++-- netlib/version.py | 4 +++ setup.py | 92 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 2 deletions(-) create mode 100644 netlib/version.py create mode 100644 setup.py diff --git a/README b/README index 1c86738c..958a0302 100644 --- a/README +++ b/README @@ -1,2 +1,10 @@ -Netlib is a collection of common utility functions, used by the pathod and -mitmproxy projects. + +Netlib is a collection of network utility classes, used by pathod and mitmproxy +projects. It differs from other projects in some fundamental respects, because +both pathod and mitmproxy often need to violate standards. This means that +protocols are implemented as small, well-contained and flexible functions, and +servers are implemented to allow misbehaviour when needed. + +At this point, I have no plans to make netlib useful beyond mitmproxy and +pathod. Please get in touch if you think parts of netlib might have broader +utility. diff --git a/netlib/version.py b/netlib/version.py new file mode 100644 index 00000000..1c4a4b66 --- /dev/null +++ b/netlib/version.py @@ -0,0 +1,4 @@ +IVERSION = (0, 1) +VERSION = ".".join(str(i) for i in IVERSION) +NAME = "netlib" +NAMEVERSION = NAME + " " + VERSION diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..06ac8aea --- /dev/null +++ b/setup.py @@ -0,0 +1,92 @@ +from distutils.core import setup +import fnmatch, os.path +from netlib import version + +def _fnmatch(name, patternList): + for i in patternList: + if fnmatch.fnmatch(name, i): + return True + return False + + +def _splitAll(path): + parts = [] + h = path + while 1: + if not h: + break + h, t = os.path.split(h) + parts.append(t) + parts.reverse() + return parts + + +def findPackages(path, dataExclude=[]): + """ + Recursively find all packages and data directories rooted at path. Note + that only data _directories_ and their contents are returned - + non-Python files at module scope are not, and should be manually + included. + + dataExclude is a list of fnmatch-compatible expressions for files and + directories that should not be included in pakcage_data. + + Returns a (packages, package_data) tuple, ready to be passed to the + corresponding distutils.core.setup arguments. + """ + packages = [] + datadirs = [] + for root, dirs, files in os.walk(path, topdown=True): + if "__init__.py" in files: + p = _splitAll(root) + packages.append(".".join(p)) + else: + dirs[:] = [] + if packages: + datadirs.append(root) + + # Now we recurse into the data directories + package_data = {} + for i in datadirs: + if not _fnmatch(i, dataExclude): + parts = _splitAll(i) + module = ".".join(parts[:-1]) + acc = package_data.get(module, []) + for root, dirs, files in os.walk(i, topdown=True): + sub = os.path.join(*_splitAll(root)[1:]) + if not _fnmatch(sub, dataExclude): + for fname in files: + path = os.path.join(sub, fname) + if not _fnmatch(path, dataExclude): + acc.append(path) + else: + dirs[:] = [] + package_data[module] = acc + return packages, package_data + + +long_description = file("README").read() +packages, package_data = findPackages("libpathod") +setup( + name = "netlib", + version = version.VERSION, + description = "A collection of network utilities used by pathod and mitmproxy.", + long_description = long_description, + author = "Aldo Cortesi", + author_email = "aldo@corte.si", + url = "http://cortesi.github.com/netlib", + packages = packages, + package_data = package_data, + classifiers = [ + "License :: OSI Approved :: MIT License", + "Development Status :: 3 - Alpha", + "Operating System :: POSIX", + "Programming Language :: Python", + "Topic :: Internet", + "Topic :: Internet :: WWW/HTTP :: HTTP Servers", + "Topic :: Software Development :: Testing", + "Topic :: Software Development :: Testing :: Traffic Generation", + "Topic :: Internet :: WWW/HTTP", + ], + install_requires=["pyasn1>0.1.2", "pyopenssl>=0.12"], +) -- cgit v1.2.3 From 5cf6aeb926e0b3a1cad23a0b169b8dfa8536a22f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 23 Jun 2012 13:56:17 +1200 Subject: protocol.py -> http.py --- netlib/http.py | 218 ++++++++++++++++++++++++++++++++++++++++++++++++++ netlib/protocol.py | 218 -------------------------------------------------- test/test_http.py | 163 +++++++++++++++++++++++++++++++++++++ test/test_protocol.py | 163 ------------------------------------- 4 files changed, 381 insertions(+), 381 deletions(-) create mode 100644 netlib/http.py delete mode 100644 netlib/protocol.py create mode 100644 test/test_http.py delete mode 100644 test/test_protocol.py diff --git a/netlib/http.py b/netlib/http.py new file mode 100644 index 00000000..c676c25c --- /dev/null +++ b/netlib/http.py @@ -0,0 +1,218 @@ +import string, urlparse + +class HttpError(Exception): + def __init__(self, code, msg): + self.code, self.msg = code, msg + + def __str__(self): + return "HttpError(%s, %s)"%(self.code, self.msg) + + +def parse_url(url): + """ + Returns a (scheme, host, port, path) tuple, or None on error. + """ + scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) + if not scheme: + return None + if ':' in netloc: + host, port = string.rsplit(netloc, ':', maxsplit=1) + try: + port = int(port) + except ValueError: + return None + else: + host = netloc + if scheme == "https": + port = 443 + else: + port = 80 + path = urlparse.urlunparse(('', '', path, params, query, fragment)) + if not path.startswith("/"): + path = "/" + path + return scheme, host, port, path + + +def read_headers(fp): + """ + Read a set of headers from a file pointer. Stop once a blank line + is reached. Return a ODictCaseless object. + """ + ret = [] + name = '' + while 1: + line = fp.readline() + if not line or line == '\r\n' or line == '\n': + break + if line[0] in ' \t': + # continued header + ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() + else: + i = line.find(':') + # We're being liberal in what we accept, here. + if i > 0: + name = line[:i] + value = line[i+1:].strip() + ret.append([name, value]) + return ret + + +def read_chunked(fp, limit): + content = "" + total = 0 + while 1: + line = fp.readline(128) + if line == "": + raise IOError("Connection closed") + if line == '\r\n' or line == '\n': + continue + try: + length = int(line,16) + except ValueError: + # FIXME: Not strictly correct - this could be from the server, in which + # case we should send a 502. + raise HttpError(400, "Invalid chunked encoding length: %s"%line) + if not length: + break + total += length + if limit is not None and total > limit: + msg = "HTTP Body too large."\ + " Limit is %s, chunked content length was at least %s"%(limit, total) + raise HttpError(509, msg) + content += fp.read(length) + line = fp.readline(5) + if line != '\r\n': + raise IOError("Malformed chunked body") + while 1: + line = fp.readline() + if line == "": + raise IOError("Connection closed") + if line == '\r\n' or line == '\n': + break + return content + + +def has_chunked_encoding(headers): + for i in headers["transfer-encoding"]: + for j in i.split(","): + if j.lower() == "chunked": + return True + return False + + +def read_http_body(rfile, headers, all, limit): + if has_chunked_encoding(headers): + content = read_chunked(rfile, limit) + elif "content-length" in headers: + try: + l = int(headers["content-length"][0]) + except ValueError: + # FIXME: Not strictly correct - this could be from the server, in which + # case we should send a 502. + raise HttpError(400, "Invalid content-length header: %s"%headers["content-length"]) + if limit is not None and l > limit: + raise HttpError(509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) + content = rfile.read(l) + elif all: + content = rfile.read(limit if limit else None) + else: + content = "" + return content + + +def parse_http_protocol(s): + if not s.startswith("HTTP/"): + return None + major, minor = s.split('/')[1].split('.') + major = int(major) + minor = int(minor) + return major, minor + + +def parse_init_connect(line): + try: + method, url, protocol = string.split(line) + except ValueError: + return None + if method != 'CONNECT': + return None + try: + host, port = url.split(":") + except ValueError: + return None + port = int(port) + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + return host, port, httpversion + + +def parse_init_proxy(line): + try: + method, url, protocol = string.split(line) + except ValueError: + return None + parts = parse_url(url) + if not parts: + return None + scheme, host, port, path = parts + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + return method, scheme, host, port, path, httpversion + + +def parse_init_http(line): + """ + Returns (method, url, httpversion) + """ + try: + method, url, protocol = string.split(line) + except ValueError: + return None + if not (url.startswith("/") or url == "*"): + return None + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + return method, url, httpversion + + +def request_connection_close(httpversion, headers): + """ + Checks the request to see if the client connection should be closed. + """ + if "connection" in headers: + for value in ",".join(headers['connection']).split(","): + value = value.strip() + if value == "close": + return True + elif value == "keep-alive": + return False + # HTTP 1.1 connections are assumed to be persistent + if httpversion == (1, 1): + return False + return True + + +def response_connection_close(httpversion, headers): + """ + Checks the response to see if the client connection should be closed. + """ + if request_connection_close(httpversion, headers): + return True + elif not has_chunked_encoding(headers) and "content-length" in headers: + return True + return False + + +def read_http_body_request(rfile, wfile, headers, httpversion, limit): + if "expect" in headers: + # FIXME: Should be forwarded upstream + expect = ",".join(headers['expect']) + if expect == "100-continue" and httpversion >= (1, 1): + wfile.write('HTTP/1.1 100 Continue\r\n') + wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION) + wfile.write('\r\n') + del headers['expect'] + return read_http_body(rfile, headers, False, limit) diff --git a/netlib/protocol.py b/netlib/protocol.py deleted file mode 100644 index 55bcf440..00000000 --- a/netlib/protocol.py +++ /dev/null @@ -1,218 +0,0 @@ -import string, urlparse - -class ProtocolError(Exception): - def __init__(self, code, msg): - self.code, self.msg = code, msg - - def __str__(self): - return "ProtocolError(%s, %s)"%(self.code, self.msg) - - -def parse_url(url): - """ - Returns a (scheme, host, port, path) tuple, or None on error. - """ - scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) - if not scheme: - return None - if ':' in netloc: - host, port = string.rsplit(netloc, ':', maxsplit=1) - try: - port = int(port) - except ValueError: - return None - else: - host = netloc - if scheme == "https": - port = 443 - else: - port = 80 - path = urlparse.urlunparse(('', '', path, params, query, fragment)) - if not path.startswith("/"): - path = "/" + path - return scheme, host, port, path - - -def read_headers(fp): - """ - Read a set of headers from a file pointer. Stop once a blank line - is reached. Return a ODictCaseless object. - """ - ret = [] - name = '' - while 1: - line = fp.readline() - if not line or line == '\r\n' or line == '\n': - break - if line[0] in ' \t': - # continued header - ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() - else: - i = line.find(':') - # We're being liberal in what we accept, here. - if i > 0: - name = line[:i] - value = line[i+1:].strip() - ret.append([name, value]) - return ret - - -def read_chunked(fp, limit): - content = "" - total = 0 - while 1: - line = fp.readline(128) - if line == "": - raise IOError("Connection closed") - if line == '\r\n' or line == '\n': - continue - try: - length = int(line,16) - except ValueError: - # FIXME: Not strictly correct - this could be from the server, in which - # case we should send a 502. - raise ProtocolError(400, "Invalid chunked encoding length: %s"%line) - if not length: - break - total += length - if limit is not None and total > limit: - msg = "HTTP Body too large."\ - " Limit is %s, chunked content length was at least %s"%(limit, total) - raise ProtocolError(509, msg) - content += fp.read(length) - line = fp.readline(5) - if line != '\r\n': - raise IOError("Malformed chunked body") - while 1: - line = fp.readline() - if line == "": - raise IOError("Connection closed") - if line == '\r\n' or line == '\n': - break - return content - - -def has_chunked_encoding(headers): - for i in headers["transfer-encoding"]: - for j in i.split(","): - if j.lower() == "chunked": - return True - return False - - -def read_http_body(rfile, headers, all, limit): - if has_chunked_encoding(headers): - content = read_chunked(rfile, limit) - elif "content-length" in headers: - try: - l = int(headers["content-length"][0]) - except ValueError: - # FIXME: Not strictly correct - this could be from the server, in which - # case we should send a 502. - raise ProtocolError(400, "Invalid content-length header: %s"%headers["content-length"]) - if limit is not None and l > limit: - raise ProtocolError(509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) - content = rfile.read(l) - elif all: - content = rfile.read(limit if limit else None) - else: - content = "" - return content - - -def parse_http_protocol(s): - if not s.startswith("HTTP/"): - return None - major, minor = s.split('/')[1].split('.') - major = int(major) - minor = int(minor) - return major, minor - - -def parse_init_connect(line): - try: - method, url, protocol = string.split(line) - except ValueError: - return None - if method != 'CONNECT': - return None - try: - host, port = url.split(":") - except ValueError: - return None - port = int(port) - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None - return host, port, httpversion - - -def parse_init_proxy(line): - try: - method, url, protocol = string.split(line) - except ValueError: - return None - parts = parse_url(url) - if not parts: - return None - scheme, host, port, path = parts - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None - return method, scheme, host, port, path, httpversion - - -def parse_init_http(line): - """ - Returns (method, url, httpversion) - """ - try: - method, url, protocol = string.split(line) - except ValueError: - return None - if not (url.startswith("/") or url == "*"): - return None - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None - return method, url, httpversion - - -def request_connection_close(httpversion, headers): - """ - Checks the request to see if the client connection should be closed. - """ - if "connection" in headers: - for value in ",".join(headers['connection']).split(","): - value = value.strip() - if value == "close": - return True - elif value == "keep-alive": - return False - # HTTP 1.1 connections are assumed to be persistent - if httpversion == (1, 1): - return False - return True - - -def response_connection_close(httpversion, headers): - """ - Checks the response to see if the client connection should be closed. - """ - if request_connection_close(httpversion, headers): - return True - elif not has_chunked_encoding(headers) and "content-length" in headers: - return True - return False - - -def read_http_body_request(rfile, wfile, headers, httpversion, limit): - if "expect" in headers: - # FIXME: Should be forwarded upstream - expect = ",".join(headers['expect']) - if expect == "100-continue" and httpversion >= (1, 1): - wfile.write('HTTP/1.1 100 Continue\r\n') - wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION) - wfile.write('\r\n') - del headers['expect'] - return read_http_body(rfile, headers, False, limit) diff --git a/test/test_http.py b/test/test_http.py new file mode 100644 index 00000000..d272f343 --- /dev/null +++ b/test/test_http.py @@ -0,0 +1,163 @@ +import cStringIO, textwrap +from netlib import http, odict +import tutils + +def test_has_chunked_encoding(): + h = odict.ODictCaseless() + assert not http.has_chunked_encoding(h) + h["transfer-encoding"] = ["chunked"] + assert http.has_chunked_encoding(h) + + +def test_read_chunked(): + s = cStringIO.StringIO("1\r\na\r\n0\r\n") + tutils.raises(IOError, http.read_chunked, s, None) + + s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") + assert http.read_chunked(s, None) == "a" + + s = cStringIO.StringIO("\r\n") + tutils.raises(IOError, http.read_chunked, s, None) + + s = cStringIO.StringIO("1\r\nfoo") + tutils.raises(IOError, http.read_chunked, s, None) + + s = cStringIO.StringIO("foo\r\nfoo") + tutils.raises(http.HttpError, http.read_chunked, s, None) + + +def test_request_connection_close(): + h = odict.ODictCaseless() + assert http.request_connection_close((1, 0), h) + assert not http.request_connection_close((1, 1), h) + + h["connection"] = ["keep-alive"] + assert not http.request_connection_close((1, 1), h) + + +def test_read_http_body(): + h = odict.ODict() + s = cStringIO.StringIO("testing") + assert http.read_http_body(s, h, False, None) == "" + + h["content-length"] = ["foo"] + s = cStringIO.StringIO("testing") + tutils.raises(http.HttpError, http.read_http_body, s, h, False, None) + + h["content-length"] = [5] + s = cStringIO.StringIO("testing") + assert len(http.read_http_body(s, h, False, None)) == 5 + s = cStringIO.StringIO("testing") + tutils.raises(http.HttpError, http.read_http_body, s, h, False, 4) + + h = odict.ODict() + s = cStringIO.StringIO("testing") + assert len(http.read_http_body(s, h, True, 4)) == 4 + s = cStringIO.StringIO("testing") + assert len(http.read_http_body(s, h, True, 100)) == 7 + +def test_parse_http_protocol(): + assert http.parse_http_protocol("HTTP/1.1") == (1, 1) + assert http.parse_http_protocol("HTTP/0.0") == (0, 0) + assert not http.parse_http_protocol("foo/0.0") + + +def test_parse_init_connect(): + assert http.parse_init_connect("CONNECT host.com:443 HTTP/1.0") + assert not http.parse_init_connect("bogus") + assert not http.parse_init_connect("GET host.com:443 HTTP/1.0") + assert not http.parse_init_connect("CONNECT host.com443 HTTP/1.0") + assert not http.parse_init_connect("CONNECT host.com:443 foo/1.0") + + +def test_prase_init_proxy(): + u = "GET http://foo.com:8888/test HTTP/1.1" + m, s, h, po, pa, httpversion = http.parse_init_proxy(u) + assert m == "GET" + assert s == "http" + assert h == "foo.com" + assert po == 8888 + assert pa == "/test" + assert httpversion == (1, 1) + + assert not http.parse_init_proxy("invalid") + assert not http.parse_init_proxy("GET invalid HTTP/1.1") + assert not http.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") + + +def test_parse_init_http(): + u = "GET /test HTTP/1.1" + m, u, httpversion= http.parse_init_http(u) + assert m == "GET" + assert u == "/test" + assert httpversion == (1, 1) + + assert not http.parse_init_http("invalid") + assert not http.parse_init_http("GET invalid HTTP/1.1") + assert not http.parse_init_http("GET /test foo/1.1") + + +class TestReadHeaders: + def test_read_simple(self): + data = """ + Header: one + Header2: two + \r\n + """ + data = textwrap.dedent(data) + data = data.strip() + s = cStringIO.StringIO(data) + h = http.read_headers(s) + assert h == [["Header", "one"], ["Header2", "two"]] + + def test_read_multi(self): + data = """ + Header: one + Header: two + \r\n + """ + data = textwrap.dedent(data) + data = data.strip() + s = cStringIO.StringIO(data) + h = http.read_headers(s) + assert h == [["Header", "one"], ["Header", "two"]] + + def test_read_continued(self): + data = """ + Header: one + \ttwo + Header2: three + \r\n + """ + data = textwrap.dedent(data) + data = data.strip() + s = cStringIO.StringIO(data) + h = http.read_headers(s) + assert h == [["Header", "one\r\n two"], ["Header2", "three"]] + + +def test_parse_url(): + assert not http.parse_url("") + + u = "http://foo.com:8888/test" + s, h, po, pa = http.parse_url(u) + assert s == "http" + assert h == "foo.com" + assert po == 8888 + assert pa == "/test" + + s, h, po, pa = http.parse_url("http://foo/bar") + assert s == "http" + assert h == "foo" + assert po == 80 + assert pa == "/bar" + + s, h, po, pa = http.parse_url("http://foo") + assert pa == "/" + + s, h, po, pa = http.parse_url("https://foo") + assert po == 443 + + assert not http.parse_url("https://foo:bar") + assert not http.parse_url("https://foo:") + diff --git a/test/test_protocol.py b/test/test_protocol.py deleted file mode 100644 index 028faadd..00000000 --- a/test/test_protocol.py +++ /dev/null @@ -1,163 +0,0 @@ -import cStringIO, textwrap -from netlib import protocol, odict -import tutils - -def test_has_chunked_encoding(): - h = odict.ODictCaseless() - assert not protocol.has_chunked_encoding(h) - h["transfer-encoding"] = ["chunked"] - assert protocol.has_chunked_encoding(h) - - -def test_read_chunked(): - s = cStringIO.StringIO("1\r\na\r\n0\r\n") - tutils.raises(IOError, protocol.read_chunked, s, None) - - s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") - assert protocol.read_chunked(s, None) == "a" - - s = cStringIO.StringIO("\r\n") - tutils.raises(IOError, protocol.read_chunked, s, None) - - s = cStringIO.StringIO("1\r\nfoo") - tutils.raises(IOError, protocol.read_chunked, s, None) - - s = cStringIO.StringIO("foo\r\nfoo") - tutils.raises(protocol.ProtocolError, protocol.read_chunked, s, None) - - -def test_request_connection_close(): - h = odict.ODictCaseless() - assert protocol.request_connection_close((1, 0), h) - assert not protocol.request_connection_close((1, 1), h) - - h["connection"] = ["keep-alive"] - assert not protocol.request_connection_close((1, 1), h) - - -def test_read_http_body(): - h = odict.ODict() - s = cStringIO.StringIO("testing") - assert protocol.read_http_body(s, h, False, None) == "" - - h["content-length"] = ["foo"] - s = cStringIO.StringIO("testing") - tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, None) - - h["content-length"] = [5] - s = cStringIO.StringIO("testing") - assert len(protocol.read_http_body(s, h, False, None)) == 5 - s = cStringIO.StringIO("testing") - tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, 4) - - h = odict.ODict() - s = cStringIO.StringIO("testing") - assert len(protocol.read_http_body(s, h, True, 4)) == 4 - s = cStringIO.StringIO("testing") - assert len(protocol.read_http_body(s, h, True, 100)) == 7 - -def test_parse_http_protocol(): - assert protocol.parse_http_protocol("HTTP/1.1") == (1, 1) - assert protocol.parse_http_protocol("HTTP/0.0") == (0, 0) - assert not protocol.parse_http_protocol("foo/0.0") - - -def test_parse_init_connect(): - assert protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0") - assert not protocol.parse_init_connect("bogus") - assert not protocol.parse_init_connect("GET host.com:443 HTTP/1.0") - assert not protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0") - assert not protocol.parse_init_connect("CONNECT host.com:443 foo/1.0") - - -def test_prase_init_proxy(): - u = "GET http://foo.com:8888/test HTTP/1.1" - m, s, h, po, pa, httpversion = protocol.parse_init_proxy(u) - assert m == "GET" - assert s == "http" - assert h == "foo.com" - assert po == 8888 - assert pa == "/test" - assert httpversion == (1, 1) - - assert not protocol.parse_init_proxy("invalid") - assert not protocol.parse_init_proxy("GET invalid HTTP/1.1") - assert not protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") - - -def test_parse_init_http(): - u = "GET /test HTTP/1.1" - m, u, httpversion= protocol.parse_init_http(u) - assert m == "GET" - assert u == "/test" - assert httpversion == (1, 1) - - assert not protocol.parse_init_http("invalid") - assert not protocol.parse_init_http("GET invalid HTTP/1.1") - assert not protocol.parse_init_http("GET /test foo/1.1") - - -class TestReadHeaders: - def test_read_simple(self): - data = """ - Header: one - Header2: two - \r\n - """ - data = textwrap.dedent(data) - data = data.strip() - s = cStringIO.StringIO(data) - h = protocol.read_headers(s) - assert h == [["Header", "one"], ["Header2", "two"]] - - def test_read_multi(self): - data = """ - Header: one - Header: two - \r\n - """ - data = textwrap.dedent(data) - data = data.strip() - s = cStringIO.StringIO(data) - h = protocol.read_headers(s) - assert h == [["Header", "one"], ["Header", "two"]] - - def test_read_continued(self): - data = """ - Header: one - \ttwo - Header2: three - \r\n - """ - data = textwrap.dedent(data) - data = data.strip() - s = cStringIO.StringIO(data) - h = protocol.read_headers(s) - assert h == [["Header", "one\r\n two"], ["Header2", "three"]] - - -def test_parse_url(): - assert not protocol.parse_url("") - - u = "http://foo.com:8888/test" - s, h, po, pa = protocol.parse_url(u) - assert s == "http" - assert h == "foo.com" - assert po == 8888 - assert pa == "/test" - - s, h, po, pa = protocol.parse_url("http://foo/bar") - assert s == "http" - assert h == "foo" - assert po == 80 - assert pa == "/bar" - - s, h, po, pa = protocol.parse_url("http://foo") - assert pa == "/" - - s, h, po, pa = protocol.parse_url("https://foo") - assert po == 443 - - assert not protocol.parse_url("https://foo:bar") - assert not protocol.parse_url("https://foo:") - -- cgit v1.2.3 From 1263221ddd06da12f3f1f5f9c3e55858b304ce54 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 23 Jun 2012 15:07:42 +1200 Subject: 100% testcoverage for netlib.http --- netlib/http.py | 91 +++++++++++++++++++++++++++++++++++-------------------- test/test_http.py | 73 ++++++++++++++++++++++++++++++++++++-------- 2 files changed, 118 insertions(+), 46 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index c676c25c..da43d070 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -57,36 +57,40 @@ def read_headers(fp): return ret -def read_chunked(fp, limit): +def read_chunked(code, fp, limit): + """ + Read a chunked HTTP body. + + May raise HttpError. + """ content = "" total = 0 while 1: line = fp.readline(128) if line == "": - raise IOError("Connection closed") - if line == '\r\n' or line == '\n': - continue - try: - length = int(line,16) - except ValueError: - # FIXME: Not strictly correct - this could be from the server, in which - # case we should send a 502. - raise HttpError(400, "Invalid chunked encoding length: %s"%line) - if not length: - break - total += length - if limit is not None and total > limit: - msg = "HTTP Body too large."\ - " Limit is %s, chunked content length was at least %s"%(limit, total) - raise HttpError(509, msg) - content += fp.read(length) - line = fp.readline(5) - if line != '\r\n': - raise IOError("Malformed chunked body") + raise HttpError(code, "Connection closed prematurely") + if line != '\r\n' and line != '\n': + try: + length = int(line, 16) + except ValueError: + # FIXME: Not strictly correct - this could be from the server, in which + # case we should send a 502. + raise HttpError(code, "Invalid chunked encoding length: %s"%line) + if not length: + break + total += length + if limit is not None and total > limit: + msg = "HTTP Body too large."\ + " Limit is %s, chunked content length was at least %s"%(limit, total) + raise HttpError(code, msg) + content += fp.read(length) + line = fp.readline(5) + if line != '\r\n': + raise HttpError(code, "Malformed chunked body") while 1: line = fp.readline() if line == "": - raise IOError("Connection closed") + raise HttpError(code, "Connection closed prematurely") if line == '\r\n' or line == '\n': break return content @@ -100,18 +104,27 @@ def has_chunked_encoding(headers): return False -def read_http_body(rfile, headers, all, limit): +def read_http_body(code, rfile, headers, all, limit): + """ + Read an HTTP body: + + code: The HTTP error code to be used when raising HttpError + rfile: A file descriptor to read from + headers: An ODictCaseless object + all: Should we read all data? + limit: Size limit. + """ if has_chunked_encoding(headers): - content = read_chunked(rfile, limit) + content = read_chunked(code, rfile, limit) elif "content-length" in headers: try: l = int(headers["content-length"][0]) except ValueError: # FIXME: Not strictly correct - this could be from the server, in which # case we should send a 502. - raise HttpError(400, "Invalid content-length header: %s"%headers["content-length"]) + raise HttpError(code, "Invalid content-length header: %s"%headers["content-length"]) if limit is not None and l > limit: - raise HttpError(509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) + raise HttpError(code, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) content = rfile.read(l) elif all: content = rfile.read(limit if limit else None) @@ -121,6 +134,10 @@ def read_http_body(rfile, headers, all, limit): def parse_http_protocol(s): + """ + Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or + None. + """ if not s.startswith("HTTP/"): return None major, minor = s.split('/')[1].split('.') @@ -201,18 +218,26 @@ def response_connection_close(httpversion, headers): """ if request_connection_close(httpversion, headers): return True - elif not has_chunked_encoding(headers) and "content-length" in headers: - return True - return False + elif (not has_chunked_encoding(headers)) and "content-length" in headers: + return False + return True def read_http_body_request(rfile, wfile, headers, httpversion, limit): + """ + Read the HTTP body from a client request. + """ if "expect" in headers: # FIXME: Should be forwarded upstream - expect = ",".join(headers['expect']) - if expect == "100-continue" and httpversion >= (1, 1): + if "100-continue" in headers['expect'] and httpversion >= (1, 1): wfile.write('HTTP/1.1 100 Continue\r\n') - wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION) wfile.write('\r\n') del headers['expect'] - return read_http_body(rfile, headers, False, limit) + return read_http_body(400, rfile, headers, False, limit) + + +def read_http_body_response(rfile, headers, False, limit): + """ + Read the HTTP body from a server response. + """ + return read_http_body(500, rfile, headers, False, limit) diff --git a/test/test_http.py b/test/test_http.py index d272f343..bf525de5 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -2,6 +2,11 @@ import cStringIO, textwrap from netlib import http, odict import tutils +def test_httperror(): + e = http.HttpError(404, "Not found") + assert str(e) + + def test_has_chunked_encoding(): h = odict.ODictCaseless() assert not http.has_chunked_encoding(h) @@ -11,19 +16,25 @@ def test_has_chunked_encoding(): def test_read_chunked(): s = cStringIO.StringIO("1\r\na\r\n0\r\n") - tutils.raises(IOError, http.read_chunked, s, None) + tutils.raises("closed prematurely", http.read_chunked, 500, s, None) s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") - assert http.read_chunked(s, None) == "a" + assert http.read_chunked(500, s, None) == "a" + + s = cStringIO.StringIO("\r\n\r\n1\r\na\r\n0\r\n\r\n") + assert http.read_chunked(500, s, None) == "a" s = cStringIO.StringIO("\r\n") - tutils.raises(IOError, http.read_chunked, s, None) + tutils.raises("closed prematurely", http.read_chunked, 500, s, None) s = cStringIO.StringIO("1\r\nfoo") - tutils.raises(IOError, http.read_chunked, s, None) + tutils.raises("malformed chunked body", http.read_chunked, 500, s, None) s = cStringIO.StringIO("foo\r\nfoo") - tutils.raises(http.HttpError, http.read_chunked, s, None) + tutils.raises(http.HttpError, http.read_chunked, 500, s, None) + + s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") + tutils.raises("too large", http.read_chunked, 500, s, 2) def test_request_connection_close(): @@ -34,27 +45,63 @@ def test_request_connection_close(): h["connection"] = ["keep-alive"] assert not http.request_connection_close((1, 1), h) + h["connection"] = ["close"] + assert http.request_connection_close((1, 1), h) + + +def test_response_connection_close(): + h = odict.ODictCaseless() + assert http.response_connection_close((1, 1), h) + + h["content-length"] = [10] + assert not http.response_connection_close((1, 1), h) + + h["connection"] = ["close"] + assert http.response_connection_close((1, 1), h) + + +def test_read_http_body_response(): + h = odict.ODictCaseless() + h["content-length"] = [7] + s = cStringIO.StringIO("testing") + assert http.read_http_body_response(s, h, False, None) == "testing" + + +def test_read_http_body_request(): + h = odict.ODictCaseless() + h["expect"] = ["100-continue"] + r = cStringIO.StringIO("testing") + w = cStringIO.StringIO() + assert http.read_http_body_request(r, w, h, (1, 1), None) == "" + assert "100 Continue" in w.getvalue() + def test_read_http_body(): - h = odict.ODict() + h = odict.ODictCaseless() s = cStringIO.StringIO("testing") - assert http.read_http_body(s, h, False, None) == "" + assert http.read_http_body(500, s, h, False, None) == "" h["content-length"] = ["foo"] s = cStringIO.StringIO("testing") - tutils.raises(http.HttpError, http.read_http_body, s, h, False, None) + tutils.raises(http.HttpError, http.read_http_body, 500, s, h, False, None) h["content-length"] = [5] s = cStringIO.StringIO("testing") - assert len(http.read_http_body(s, h, False, None)) == 5 + assert len(http.read_http_body(500, s, h, False, None)) == 5 s = cStringIO.StringIO("testing") - tutils.raises(http.HttpError, http.read_http_body, s, h, False, 4) + tutils.raises(http.HttpError, http.read_http_body, 500, s, h, False, 4) - h = odict.ODict() + h = odict.ODictCaseless() s = cStringIO.StringIO("testing") - assert len(http.read_http_body(s, h, True, 4)) == 4 + assert len(http.read_http_body(500, s, h, True, 4)) == 4 s = cStringIO.StringIO("testing") - assert len(http.read_http_body(s, h, True, 100)) == 7 + assert len(http.read_http_body(500, s, h, True, 100)) == 7 + + h = odict.ODictCaseless() + h["transfer-encoding"] = ["chunked"] + s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") + assert http.read_http_body(500, s, h, True, 100) == "aaaaa" + def test_parse_http_protocol(): assert http.parse_http_protocol("HTTP/1.1") == (1, 1) -- cgit v1.2.3 From 171de05d8ea4a31b0f97c38206b44826364d7693 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 23 Jun 2012 18:34:51 +1200 Subject: Add http_status.py --- netlib/http_status.py | 103 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 netlib/http_status.py diff --git a/netlib/http_status.py b/netlib/http_status.py new file mode 100644 index 00000000..9f3f7e15 --- /dev/null +++ b/netlib/http_status.py @@ -0,0 +1,103 @@ + +CONTINUE = 100 +SWITCHING = 101 +OK = 200 +CREATED = 201 +ACCEPTED = 202 +NON_AUTHORITATIVE_INFORMATION = 203 +NO_CONTENT = 204 +RESET_CONTENT = 205 +PARTIAL_CONTENT = 206 +MULTI_STATUS = 207 + +MULTIPLE_CHOICE = 300 +MOVED_PERMANENTLY = 301 +FOUND = 302 +SEE_OTHER = 303 +NOT_MODIFIED = 304 +USE_PROXY = 305 +TEMPORARY_REDIRECT = 307 + +BAD_REQUEST = 400 +UNAUTHORIZED = 401 +PAYMENT_REQUIRED = 402 +FORBIDDEN = 403 +NOT_FOUND = 404 +NOT_ALLOWED = 405 +NOT_ACCEPTABLE = 406 +PROXY_AUTH_REQUIRED = 407 +REQUEST_TIMEOUT = 408 +CONFLICT = 409 +GONE = 410 +LENGTH_REQUIRED = 411 +PRECONDITION_FAILED = 412 +REQUEST_ENTITY_TOO_LARGE = 413 +REQUEST_URI_TOO_LONG = 414 +UNSUPPORTED_MEDIA_TYPE = 415 +REQUESTED_RANGE_NOT_SATISFIABLE = 416 +EXPECTATION_FAILED = 417 + +INTERNAL_SERVER_ERROR = 500 +NOT_IMPLEMENTED = 501 +BAD_GATEWAY = 502 +SERVICE_UNAVAILABLE = 503 +GATEWAY_TIMEOUT = 504 +HTTP_VERSION_NOT_SUPPORTED = 505 +INSUFFICIENT_STORAGE_SPACE = 507 +NOT_EXTENDED = 510 + +RESPONSES = { + # 100 + CONTINUE: "Continue", + SWITCHING: "Switching Protocols", + + # 200 + OK: "OK", + CREATED: "Created", + ACCEPTED: "Accepted", + NON_AUTHORITATIVE_INFORMATION: "Non-Authoritative Information", + NO_CONTENT: "No Content", + RESET_CONTENT: "Reset Content.", + PARTIAL_CONTENT: "Partial Content", + MULTI_STATUS: "Multi-Status", + + # 300 + MULTIPLE_CHOICE: "Multiple Choices", + MOVED_PERMANENTLY: "Moved Permanently", + FOUND: "Found", + SEE_OTHER: "See Other", + NOT_MODIFIED: "Not Modified", + USE_PROXY: "Use Proxy", + # 306 not defined?? + TEMPORARY_REDIRECT: "Temporary Redirect", + + # 400 + BAD_REQUEST: "Bad Request", + UNAUTHORIZED: "Unauthorized", + PAYMENT_REQUIRED: "Payment Required", + FORBIDDEN: "Forbidden", + NOT_FOUND: "Not Found", + NOT_ALLOWED: "Method Not Allowed", + NOT_ACCEPTABLE: "Not Acceptable", + PROXY_AUTH_REQUIRED: "Proxy Authentication Required", + REQUEST_TIMEOUT: "Request Time-out", + CONFLICT: "Conflict", + GONE: "Gone", + LENGTH_REQUIRED: "Length Required", + PRECONDITION_FAILED: "Precondition Failed", + REQUEST_ENTITY_TOO_LARGE: "Request Entity Too Large", + REQUEST_URI_TOO_LONG: "Request-URI Too Long", + UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type", + REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range not satisfiable", + EXPECTATION_FAILED: "Expectation Failed", + + # 500 + INTERNAL_SERVER_ERROR: "Internal Server Error", + NOT_IMPLEMENTED: "Not Implemented", + BAD_GATEWAY: "Bad Gateway", + SERVICE_UNAVAILABLE: "Service Unavailable", + GATEWAY_TIMEOUT: "Gateway Time-out", + HTTP_VERSION_NOT_SUPPORTED: "HTTP Version not supported", + INSUFFICIENT_STORAGE_SPACE: "Insufficient Storage Space", + NOT_EXTENDED: "Not Extended" +} -- cgit v1.2.3 From 0de765f3600bfa977cffb48da1efa26f2e3236f3 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 24 Jun 2012 21:49:23 +1200 Subject: Make read_headers return an ODictCaseless object. --- netlib/http.py | 5 +++-- test/test_http.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index da43d070..1f5f8901 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -1,4 +1,5 @@ import string, urlparse +import odict class HttpError(Exception): def __init__(self, code, msg): @@ -54,7 +55,7 @@ def read_headers(fp): name = line[:i] value = line[i+1:].strip() ret.append([name, value]) - return ret + return odict.ODictCaseless(ret) def read_chunked(code, fp, limit): @@ -107,7 +108,7 @@ def has_chunked_encoding(headers): def read_http_body(code, rfile, headers, all, limit): """ Read an HTTP body: - + code: The HTTP error code to be used when raising HttpError rfile: A file descriptor to read from headers: An ODictCaseless object diff --git a/test/test_http.py b/test/test_http.py index bf525de5..3546fec6 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -155,7 +155,7 @@ class TestReadHeaders: data = data.strip() s = cStringIO.StringIO(data) h = http.read_headers(s) - assert h == [["Header", "one"], ["Header2", "two"]] + assert h.lst == [["Header", "one"], ["Header2", "two"]] def test_read_multi(self): data = """ @@ -167,7 +167,7 @@ class TestReadHeaders: data = data.strip() s = cStringIO.StringIO(data) h = http.read_headers(s) - assert h == [["Header", "one"], ["Header", "two"]] + assert h.lst == [["Header", "one"], ["Header", "two"]] def test_read_continued(self): data = """ @@ -180,7 +180,7 @@ class TestReadHeaders: data = data.strip() s = cStringIO.StringIO(data) h = http.read_headers(s) - assert h == [["Header", "one\r\n two"], ["Header2", "three"]] + assert h.lst == [["Header", "one\r\n two"], ["Header2", "three"]] def test_parse_url(): -- cgit v1.2.3 From 5988b65419d6d498b760876b47e4bd627b2467f6 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 24 Jun 2012 22:45:40 +1200 Subject: Add and unit test http.read_response --- netlib/http.py | 40 ++++++++++++++++++++++++++++++++++++---- test/test_http.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 4 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index 1f5f8901..f0982b6d 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -128,7 +128,7 @@ def read_http_body(code, rfile, headers, all, limit): raise HttpError(code, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) content = rfile.read(l) elif all: - content = rfile.read(limit if limit else None) + content = rfile.read(limit if limit else -1) else: content = "" return content @@ -141,7 +141,10 @@ def parse_http_protocol(s): """ if not s.startswith("HTTP/"): return None - major, minor = s.split('/')[1].split('.') + _, version = s.split('/') + if "." not in version: + return None + major, minor = version.split('.') major = int(major) minor = int(minor) return major, minor @@ -237,8 +240,37 @@ def read_http_body_request(rfile, wfile, headers, httpversion, limit): return read_http_body(400, rfile, headers, False, limit) -def read_http_body_response(rfile, headers, False, limit): +def read_http_body_response(rfile, headers, all, limit): """ Read the HTTP body from a server response. """ - return read_http_body(500, rfile, headers, False, limit) + return read_http_body(500, rfile, headers, all, limit) + + +def read_response(rfile, method, body_size_limit): + line = rfile.readline() + if line == "\r\n" or line == "\n": # Possible leftover from previous message + line = rfile.readline() + if not line: + raise HttpError(502, "Blank server response.") + parts = line.strip().split(" ", 2) + if len(parts) == 2: # handle missing message gracefully + parts.append("") + if not len(parts) == 3: + raise HttpError(502, "Invalid server response: %s."%line) + proto, code, msg = parts + httpversion = parse_http_protocol(proto) + if httpversion is None: + raise HttpError(502, "Invalid HTTP version: %s."%httpversion) + try: + code = int(code) + except ValueError: + raise HttpError(502, "Invalid server response: %s."%line) + headers = read_headers(rfile) + if code >= 100 and code <= 199: + return read_response(rfile, method, body_size_limit) + if method == "HEAD" or code == 204 or code == 304: + content = "" + else: + content = read_http_body_response(rfile, headers, True, body_size_limit) + return httpversion, code, msg, headers, content diff --git a/test/test_http.py b/test/test_http.py index 3546fec6..b7ee6697 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -107,6 +107,7 @@ def test_parse_http_protocol(): assert http.parse_http_protocol("HTTP/1.1") == (1, 1) assert http.parse_http_protocol("HTTP/0.0") == (0, 0) assert not http.parse_http_protocol("foo/0.0") + assert not http.parse_http_protocol("HTTP/x") def test_parse_init_connect(): @@ -183,6 +184,47 @@ class TestReadHeaders: assert h.lst == [["Header", "one\r\n two"], ["Header2", "three"]] +def test_read_response(): + def tst(data, method, limit): + data = textwrap.dedent(data) + r = cStringIO.StringIO(data) + return http.read_response(r, method, limit) + + tutils.raises("blank server response", tst, "", "GET", None) + tutils.raises("invalid server response", tst, "foo", "GET", None) + data = """ + HTTP/1.1 200 OK + """ + assert tst(data, "GET", None) == ((1, 1), 200, 'OK', odict.ODictCaseless(), '') + data = """ + HTTP/1.1 200 + """ + assert tst(data, "GET", None) == ((1, 1), 200, '', odict.ODictCaseless(), '') + data = """ + HTTP/x 200 OK + """ + tutils.raises("invalid http version", tst, data, "GET", None) + data = """ + HTTP/1.1 xx OK + """ + tutils.raises("invalid server response", tst, data, "GET", None) + + data = """ + HTTP/1.1 100 CONTINUE + + HTTP/1.1 200 OK + """ + assert tst(data, "GET", None) == ((1, 1), 200, 'OK', odict.ODictCaseless(), '') + + data = """ + HTTP/1.1 200 OK + + foo + """ + assert tst(data, "GET", None) == ((1, 1), 200, 'OK', odict.ODictCaseless(), 'foo\n') + assert tst(data, "HEAD", None) == ((1, 1), 200, 'OK', odict.ODictCaseless(), '') + + def test_parse_url(): assert not http.parse_url("") -- cgit v1.2.3 From 820ac5152e02108f9d4e2226da1ba4369f67a4df Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 24 Jun 2012 22:57:09 +1200 Subject: WSGI SERVER_PORT should be a string. --- netlib/wsgi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 3c3a8384..755bea5a 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -55,7 +55,7 @@ class WSGIAdaptor: 'CONTENT_TYPE': request.headers.get('Content-Type', [''])[0], 'CONTENT_LENGTH': request.headers.get('Content-Length', [''])[0], 'SERVER_NAME': self.domain, - 'SERVER_PORT': self.port, + 'SERVER_PORT': str(self.port), # FIXME: We need to pick up the protocol read from the request. 'SERVER_PROTOCOL': "HTTP/1.1", } -- cgit v1.2.3 From 7d01d5c7970c2b1b86bc6c98be5dfcaa145b1d53 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 24 Jun 2012 23:13:09 +1200 Subject: Don't read all from server by default. This can cause us to hang waiting for data. More research is needed to establish the right course of action here. --- netlib/http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/http.py b/netlib/http.py index f0982b6d..150995dd 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -272,5 +272,5 @@ def read_response(rfile, method, body_size_limit): if method == "HEAD" or code == 204 or code == 304: content = "" else: - content = read_http_body_response(rfile, headers, True, body_size_limit) + content = read_http_body_response(rfile, headers, False, body_size_limit) return httpversion, code, msg, headers, content -- cgit v1.2.3 From 1662d6d5724cb39080bfad98dd515ea66e157c25 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 24 Jun 2012 23:16:06 +1200 Subject: Repair test suite. --- test/test_http.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_http.py b/test/test_http.py index b7ee6697..206fc4df 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -218,11 +218,12 @@ def test_read_response(): data = """ HTTP/1.1 200 OK + Content-Length: 3 foo """ - assert tst(data, "GET", None) == ((1, 1), 200, 'OK', odict.ODictCaseless(), 'foo\n') - assert tst(data, "HEAD", None) == ((1, 1), 200, 'OK', odict.ODictCaseless(), '') + assert tst(data, "GET", None)[4] == 'foo' + assert tst(data, "HEAD", None)[4] == '' def test_parse_url(): -- cgit v1.2.3 From 8f0754b9c48176aa479dc7701c42b26e115163a5 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 25 Jun 2012 11:00:39 +1200 Subject: SSL tests, plus some self-signed test certificates. --- netlib/tcp.py | 4 ++-- test/data/server.crt | 14 ++++++++++++++ test/data/server.key | 15 +++++++++++++++ test/test_tcp.py | 32 +++++++++++++++++++++++--------- 4 files changed, 54 insertions(+), 11 deletions(-) create mode 100644 test/data/server.crt create mode 100644 test/data/server.key diff --git a/netlib/tcp.py b/netlib/tcp.py index 5a942522..007cf3a5 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -48,8 +48,8 @@ class FileLike: class TCPClient: - def __init__(self, ssl, host, port, clientcert): - self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert + def __init__(self, ssl, host, port, clientcert, sni): + self.ssl, self.host, self.port, self.clientcert, self.sni = ssl, host, port, clientcert, sni self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.connect() diff --git a/test/data/server.crt b/test/data/server.crt new file mode 100644 index 00000000..68f61bac --- /dev/null +++ b/test/data/server.crt @@ -0,0 +1,14 @@ +-----BEGIN CERTIFICATE----- +MIICOzCCAaQCCQDC7f5GsEpo9jANBgkqhkiG9w0BAQUFADBiMQswCQYDVQQGEwJO +WjEOMAwGA1UECBMFT3RhZ28xEDAOBgNVBAcTB0R1bmVkaW4xDzANBgNVBAoTBm5l +dGxpYjEPMA0GA1UECxMGbmV0bGliMQ8wDQYDVQQDEwZuZXRsaWIwHhcNMTIwNjI0 +MjI0MTU0WhcNMjIwNjIyMjI0MTU0WjBiMQswCQYDVQQGEwJOWjEOMAwGA1UECBMF +T3RhZ28xEDAOBgNVBAcTB0R1bmVkaW4xDzANBgNVBAoTBm5ldGxpYjEPMA0GA1UE +CxMGbmV0bGliMQ8wDQYDVQQDEwZuZXRsaWIwgZ8wDQYJKoZIhvcNAQEBBQADgY0A +MIGJAoGBALJSVEl9y3QUSYuXTH0UjBOPQgS0nHmNWej9hjqnA0KWvEnGY+c6yQeP +/rmwswlKw1iVV5o8kRK9Wej88YWQl/hl/xruyeJgGic0+yqY/FcueZxRudwBcWu2 +7+46aEftwLLRF0GwHZxX/HwWME+TcCXGpXGSG2qs921M4iVeBn5hAgMBAAEwDQYJ +KoZIhvcNAQEFBQADgYEAODZCihEv2yr8zmmQZDrfqg2ChxAoOXWF5+W2F/0LAUBf +2bHP+K4XE6BJWmadX1xKngj7SWrhmmTDp1gBAvXURoDaScOkB1iOCOHoIyalscTR +0FvSHKqFF8fgSlfqS6eYaSbXU3zQolvwP+URzIVnGDqgQCWPtjMqLD3Kd5tuwos= +-----END CERTIFICATE----- diff --git a/test/data/server.key b/test/data/server.key new file mode 100644 index 00000000..b1b658ab --- /dev/null +++ b/test/data/server.key @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQCyUlRJfct0FEmLl0x9FIwTj0IEtJx5jVno/YY6pwNClrxJxmPn +OskHj/65sLMJSsNYlVeaPJESvVno/PGFkJf4Zf8a7sniYBonNPsqmPxXLnmcUbnc +AXFrtu/uOmhH7cCy0RdBsB2cV/x8FjBPk3AlxqVxkhtqrPdtTOIlXgZ+YQIDAQAB +AoGAQEpGcSiVTYhy64zk2sOprPOdTa0ALSK1I7cjycmk90D5KXAJXLho+f0ETVZT +dioqO6m8J7NmamcyHznyqcDzyNRqD2hEBDGVRJWmpOjIER/JwWLNNbpeVjsMHV8I +40P5rZMOhBPYlwECSC5NtMwaN472fyGNNze8u37IZKiER/ECQQDe1iY5AG3CgkP3 +tEZB3Vtzcn4PoOr3Utyn1YER34lPqAmeAsWUhmAVEfR3N1HDe1VFD9s2BidhBn1a +/Bgqxz4DAkEAzNw0m+uO0WkD7aEYRBW7SbXCX+3xsbVToIWC1jXFG+XDzSWn++c1 +DMXEElzEJxPDA+FzQUvRTml4P92bTAbGywJAS9H7wWtm7Ubbj33UZfbGdhqfz/uF +109naufXedhgZS0c0JnK1oV+Tc0FLEczV9swIUaK5O/lGDtYDcw3AN84NwJBAIw5 +/1jrOOtm8uVp6+5O4dBmthJsEZEPCZtLSG/Qhoe+EvUN3Zq0fL+tb7USAsKs6ERz +wizj9PWzhDhTPMYhrVkCQGIponZHx6VqiFyLgYUH9+gDTjBhYyI+6yMTYzcRweyL +9Suc2NkS3X2Lp+wCjvVZdwGtStp6Vo8z02b3giIsAIY= +-----END RSA PRIVATE KEY----- diff --git a/test/test_tcp.py b/test/test_tcp.py index d7d4483e..9aebb2f0 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -27,6 +27,11 @@ class ServerTestBase: class THandler(tcp.BaseHandler): def handle(self): + if self.server.ssl: + self.convert_to_ssl( + tutils.test_data.path("data/server.crt"), + tutils.test_data.path("data/server.key"), + ) v = self.rfile.readline() if v.startswith("echo"): self.wfile.write(v) @@ -36,9 +41,9 @@ class THandler(tcp.BaseHandler): class TServer(tcp.TCPServer): - def __init__(self, addr, q): + def __init__(self, addr, ssl, q): tcp.TCPServer.__init__(self, addr) - self.q = q + self.ssl, self.q = ssl, q def handle_connection(self, request, client_address): THandler(request, client_address, self) @@ -53,28 +58,37 @@ class TestServer(ServerTestBase): @classmethod def makeserver(cls): cls.q = Queue.Queue() - s = TServer(("127.0.0.1", 0), cls.q) + s = TServer(("127.0.0.1", 0), False, cls.q) cls.port = s.port return s def test_echo(self): testval = "echo!\n" - c = tcp.TCPClient(False, "127.0.0.1", self.port, None) + c = tcp.TCPClient(False, "127.0.0.1", self.port, None, None) c.wfile.write(testval) c.wfile.flush() assert c.rfile.readline() == testval - def test_error(self): - testval = "error!\n" - c = tcp.TCPClient(False, "127.0.0.1", self.port, None) + +class TestServerSSL(ServerTestBase): + @classmethod + def makeserver(cls): + cls.q = Queue.Queue() + s = TServer(("127.0.0.1", 0), True, cls.q) + cls.port = s.port + return s + + def test_echo(self): + c = tcp.TCPClient(True, "127.0.0.1", self.port, None, None) + testval = "echo!\n" c.wfile.write(testval) c.wfile.flush() - assert "Testing an error" in self.q.get() + assert c.rfile.readline() == testval class TestTCPClient: def test_conerr(self): - tutils.raises(tcp.NetLibError, tcp.TCPClient, False, "127.0.0.1", 0, None) + tutils.raises(tcp.NetLibError, tcp.TCPClient, False, "127.0.0.1", 0, None, None) class TestFileLike: -- cgit v1.2.3 From f3237503a77258d37b67c5716ac178cbfd7ffe1b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 25 Jun 2012 11:23:04 +1200 Subject: Don't connect during __init__ methods for either client or server. This means we now need to do these things explicitly at the caller. --- netlib/tcp.py | 10 +++++----- test/test_tcp.py | 13 +++++++++---- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 007cf3a5..25e83e07 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -48,11 +48,10 @@ class FileLike: class TCPClient: - def __init__(self, ssl, host, port, clientcert, sni): - self.ssl, self.host, self.port, self.clientcert, self.sni = ssl, host, port, clientcert, sni + def __init__(self, ssl, host, port, clientcert): + self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert self.connection, self.rfile, self.wfile = None, None, None self.cert = None - self.connect() def connect(self): try: @@ -75,6 +74,9 @@ class TCPClient: class BaseHandler: + """ + The instantiator is expected to call the handle() and finish() methods. + """ rbufsize = -1 wbufsize = 0 def __init__(self, connection, client_address, server): @@ -84,8 +86,6 @@ class BaseHandler: self.client_address = client_address self.server = server - self.handle() - self.finish() def convert_to_ssl(self, cert, key): ctx = SSL.Context(SSL.SSLv23_METHOD) diff --git a/test/test_tcp.py b/test/test_tcp.py index 9aebb2f0..1bad9a04 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -46,7 +46,9 @@ class TServer(tcp.TCPServer): self.ssl, self.q = ssl, q def handle_connection(self, request, client_address): - THandler(request, client_address, self) + h = THandler(request, client_address, self) + h.handle() + h.finish() def handle_error(self, request, client_address): s = cStringIO.StringIO() @@ -64,7 +66,8 @@ class TestServer(ServerTestBase): def test_echo(self): testval = "echo!\n" - c = tcp.TCPClient(False, "127.0.0.1", self.port, None, None) + c = tcp.TCPClient(False, "127.0.0.1", self.port, None) + c.connect() c.wfile.write(testval) c.wfile.flush() assert c.rfile.readline() == testval @@ -79,7 +82,8 @@ class TestServerSSL(ServerTestBase): return s def test_echo(self): - c = tcp.TCPClient(True, "127.0.0.1", self.port, None, None) + c = tcp.TCPClient(True, "127.0.0.1", self.port, None) + c.connect() testval = "echo!\n" c.wfile.write(testval) c.wfile.flush() @@ -88,7 +92,8 @@ class TestServerSSL(ServerTestBase): class TestTCPClient: def test_conerr(self): - tutils.raises(tcp.NetLibError, tcp.TCPClient, False, "127.0.0.1", 0, None, None) + c = tcp.TCPClient(True, "127.0.0.1", 0, None) + tutils.raises(tcp.NetLibError, c.connect) class TestFileLike: -- cgit v1.2.3 From 47f862ae278c61df9bd1b62ec291a954fc0707ea Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 25 Jun 2012 11:34:10 +1200 Subject: Add a finished flag to BaseHandler, and catch an extra OpenSSL exception. --- netlib/tcp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 25e83e07..91b0c742 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -20,7 +20,7 @@ class FileLike: while len(result) < length: try: data = self.o.read(length) - except SSL.ZeroReturnError: + except (SSL.ZeroReturnError, SSL.SysCallError): break if not data: break @@ -86,6 +86,7 @@ class BaseHandler: self.client_address = client_address self.server = server + self.finished = False def convert_to_ssl(self, cert, key): ctx = SSL.Context(SSL.SSLv23_METHOD) @@ -97,6 +98,7 @@ class BaseHandler: self.wfile = FileLike(self.connection) def finish(self): + self.finished = True try: if not getattr(self.wfile, "closed", False): self.wfile.flush() -- cgit v1.2.3 From 353efec7ce032a447efbba60c5ccea441bc573fb Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 25 Jun 2012 14:42:15 +1200 Subject: Improve TCPClient interface. - Don't pass SSL parameters on instantiation. - Add a convert_to_ssl method analogous to that in TCPServer. --- netlib/tcp.py | 31 ++++++++++++++++--------------- test/test_tcp.py | 7 ++++--- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 91b0c742..3c5c89b7 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -48,29 +48,30 @@ class FileLike: class TCPClient: - def __init__(self, ssl, host, port, clientcert): - self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert + def __init__(self, host, port): + self.host, self.port = host, port self.connection, self.rfile, self.wfile = None, None, None self.cert = None + def convert_to_ssl(self, clientcert=None): + context = SSL.Context(SSL.SSLv23_METHOD) + if clientcert: + context.use_certificate_file(self.clientcert) + self.connection = SSL.Connection(context, self.connection) + self.connection.set_connect_state() + self.cert = self.connection.get_peer_certificate() + self.rfile = FileLike(self.connection) + self.wfile = FileLike(self.connection) + def connect(self): try: addr = socket.gethostbyname(self.host) - server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - if self.ssl: - context = SSL.Context(SSL.SSLv23_METHOD) - if self.clientcert: - context.use_certificate_file(self.clientcert) - server = SSL.Connection(context, server) - server.connect((addr, self.port)) - if self.ssl: - self.cert = server.get_peer_certificate() - self.rfile, self.wfile = FileLike(server), FileLike(server) - else: - self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb') + connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + connection.connect((addr, self.port)) + self.rfile, self.wfile = connection.makefile('rb'), connection.makefile('wb') except socket.error, err: raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) - self.connection = server + self.connection = connection class BaseHandler: diff --git a/test/test_tcp.py b/test/test_tcp.py index 1bad9a04..26286bc4 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -66,7 +66,7 @@ class TestServer(ServerTestBase): def test_echo(self): testval = "echo!\n" - c = tcp.TCPClient(False, "127.0.0.1", self.port, None) + c = tcp.TCPClient("127.0.0.1", self.port) c.connect() c.wfile.write(testval) c.wfile.flush() @@ -82,8 +82,9 @@ class TestServerSSL(ServerTestBase): return s def test_echo(self): - c = tcp.TCPClient(True, "127.0.0.1", self.port, None) + c = tcp.TCPClient("127.0.0.1", self.port) c.connect() + c.convert_to_ssl() testval = "echo!\n" c.wfile.write(testval) c.wfile.flush() @@ -92,7 +93,7 @@ class TestServerSSL(ServerTestBase): class TestTCPClient: def test_conerr(self): - c = tcp.TCPClient(True, "127.0.0.1", 0, None) + c = tcp.TCPClient("127.0.0.1", 0) tutils.raises(tcp.NetLibError, c.connect) -- cgit v1.2.3 From ea457fac2e270c258172be65a0eeb4701ad23d8e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 25 Jun 2012 16:16:01 +1200 Subject: Perform handshake immediately on SSL conversion. Otherwise the handshake happens at first write, which can balls up if either side hangs immediately. --- netlib/tcp.py | 2 ++ test/test_tcp.py | 42 ++++++++++++++++++++++++++++++++---------- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 3c5c89b7..276d3162 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -59,6 +59,7 @@ class TCPClient: context.use_certificate_file(self.clientcert) self.connection = SSL.Connection(context, self.connection) self.connection.set_connect_state() + self.connection.do_handshake() self.cert = self.connection.get_peer_certificate() self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) @@ -95,6 +96,7 @@ class BaseHandler: ctx.use_certificate_file(cert) self.connection = SSL.Connection(ctx, self.connection) self.connection.set_accept_state() + self.connection.do_handshake() self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) diff --git a/test/test_tcp.py b/test/test_tcp.py index 26286bc4..a81632e7 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -25,13 +25,8 @@ class ServerTestBase: cls.server.shutdown() -class THandler(tcp.BaseHandler): +class EchoHandler(tcp.BaseHandler): def handle(self): - if self.server.ssl: - self.convert_to_ssl( - tutils.test_data.path("data/server.crt"), - tutils.test_data.path("data/server.key"), - ) v = self.rfile.readline() if v.startswith("echo"): self.wfile.write(v) @@ -40,13 +35,24 @@ class THandler(tcp.BaseHandler): self.wfile.flush() +class DisconnectHandler(tcp.BaseHandler): + def handle(self): + self.finish() + + class TServer(tcp.TCPServer): - def __init__(self, addr, ssl, q): + def __init__(self, addr, ssl, q, handler): tcp.TCPServer.__init__(self, addr) self.ssl, self.q = ssl, q + self.handler = handler def handle_connection(self, request, client_address): - h = THandler(request, client_address, self) + h = self.handler(request, client_address, self) + if self.ssl: + h.convert_to_ssl( + tutils.test_data.path("data/server.crt"), + tutils.test_data.path("data/server.key"), + ) h.handle() h.finish() @@ -60,7 +66,7 @@ class TestServer(ServerTestBase): @classmethod def makeserver(cls): cls.q = Queue.Queue() - s = TServer(("127.0.0.1", 0), False, cls.q) + s = TServer(("127.0.0.1", 0), False, cls.q, EchoHandler) cls.port = s.port return s @@ -77,7 +83,7 @@ class TestServerSSL(ServerTestBase): @classmethod def makeserver(cls): cls.q = Queue.Queue() - s = TServer(("127.0.0.1", 0), True, cls.q) + s = TServer(("127.0.0.1", 0), True, cls.q, EchoHandler) cls.port = s.port return s @@ -91,6 +97,22 @@ class TestServerSSL(ServerTestBase): assert c.rfile.readline() == testval +class TestSSLDisconnect(ServerTestBase): + @classmethod + def makeserver(cls): + cls.q = Queue.Queue() + s = TServer(("127.0.0.1", 0), True, cls.q, DisconnectHandler) + cls.port = s.port + return s + + def test_echo(self): + c = tcp.TCPClient("127.0.0.1", self.port) + c.connect() + c.convert_to_ssl() + # Excercise SSL.ZeroReturnError + c.rfile.read(10) + + class TestTCPClient: def test_conerr(self): c = tcp.TCPClient("127.0.0.1", 0) -- cgit v1.2.3 From ccf2603ddc9c832f9533eeb3c4ffbbd685b00057 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 26 Jun 2012 09:50:42 +1200 Subject: Add SNI. --- netlib/tcp.py | 23 ++++++++++++++++++++++- test/test_tcp.py | 31 ++++++++++++++++++++++++++++++- 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 276d3162..c8ffefdf 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -53,11 +53,13 @@ class TCPClient: self.connection, self.rfile, self.wfile = None, None, None self.cert = None - def convert_to_ssl(self, clientcert=None): + def convert_to_ssl(self, clientcert=None, sni=None): context = SSL.Context(SSL.SSLv23_METHOD) if clientcert: context.use_certificate_file(self.clientcert) self.connection = SSL.Connection(context, self.connection) + if sni: + self.connection.set_tlsext_host_name(sni) self.connection.set_connect_state() self.connection.do_handshake() self.cert = self.connection.get_peer_certificate() @@ -92,10 +94,12 @@ class BaseHandler: def convert_to_ssl(self, cert, key): ctx = SSL.Context(SSL.SSLv23_METHOD) + ctx.set_tlsext_servername_callback(self.handle_sni) ctx.use_privatekey_file(key) ctx.use_certificate_file(cert) self.connection = SSL.Connection(ctx, self.connection) self.connection.set_accept_state() + # SNI callback happens during do_handshake() self.connection.do_handshake() self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) @@ -111,6 +115,23 @@ class BaseHandler: except IOError: # pragma: no cover pass + def handle_sni(self, connection): + """ + Called if the client has given a server name indication. + + Server name can be retrieved like this: + + connection.get_servername() + + And you can specify the connection keys as follows: + + new_context = Context(TLSv1_METHOD) + new_context.use_privatekey(key) + new_context.use_certificate(cert) + connection.set_context(new_context) + """ + pass + def handle(self): # pragma: no cover raise NotImplementedError diff --git a/test/test_tcp.py b/test/test_tcp.py index a81632e7..a2ee5e36 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -25,7 +25,21 @@ class ServerTestBase: cls.server.shutdown() +class SNIHandler(tcp.BaseHandler): + sni = None + def handle_sni(self, connection): + self.sni = connection.get_servername() + + def handle(self): + self.wfile.write(self.sni) + self.wfile.flush() + + class EchoHandler(tcp.BaseHandler): + sni = None + def handle_sni(self, connection): + self.sni = connection.get_servername() + def handle(self): v = self.rfile.readline() if v.startswith("echo"): @@ -90,13 +104,28 @@ class TestServerSSL(ServerTestBase): def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) c.connect() - c.convert_to_ssl() + c.convert_to_ssl(sni="foo.com") testval = "echo!\n" c.wfile.write(testval) c.wfile.flush() assert c.rfile.readline() == testval +class TestSNI(ServerTestBase): + @classmethod + def makeserver(cls): + cls.q = Queue.Queue() + s = TServer(("127.0.0.1", 0), True, cls.q, SNIHandler) + cls.port = s.port + return s + + def test_echo(self): + c = tcp.TCPClient("127.0.0.1", self.port) + c.connect() + c.convert_to_ssl(sni="foo.com") + assert c.rfile.readline() == "foo.com" + + class TestSSLDisconnect(ServerTestBase): @classmethod def makeserver(cls): -- cgit v1.2.3 From 658c9c0446591e41d6ebdb223c62c00342b83206 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 26 Jun 2012 14:49:23 +1200 Subject: Hunt down a tricky WSGI socket hang. --- netlib/tcp.py | 12 +++++++++--- netlib/wsgi.py | 3 ++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index c8ffefdf..aa923fdd 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -61,7 +61,10 @@ class TCPClient: if sni: self.connection.set_tlsext_host_name(sni) self.connection.set_connect_state() - self.connection.do_handshake() + try: + self.connection.do_handshake() + except SSL.Error, v: + raise NetLibError("SSL handshake error: %s"%str(v)) self.cert = self.connection.get_peer_certificate() self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) @@ -82,7 +85,7 @@ class BaseHandler: The instantiator is expected to call the handle() and finish() methods. """ rbufsize = -1 - wbufsize = 0 + wbufsize = -1 def __init__(self, connection, client_address, server): self.connection = connection self.rfile = self.connection.makefile('rb', self.rbufsize) @@ -100,7 +103,10 @@ class BaseHandler: self.connection = SSL.Connection(ctx, self.connection) self.connection.set_accept_state() # SNI callback happens during do_handshake() - self.connection.do_handshake() + try: + self.connection.do_handshake() + except SSL.Error, v: + raise NetLibError("SSL handshake error: %s"%str(v)) self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 755bea5a..6fe6b6b3 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -104,7 +104,8 @@ class WSGIAdaptor: soc.write(str(h)) soc.write("\r\n") state["headers_sent"] = True - soc.write(data) + if data: + soc.write(data) soc.flush() def start_response(status, headers, exc_info=None): -- cgit v1.2.3 From abe335e57dd2871a6ea6cfe2559f9b29ae0c33bb Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 26 Jun 2012 23:52:35 +1200 Subject: Add a flag to track SSL connection establishment. --- netlib/tcp.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/netlib/tcp.py b/netlib/tcp.py index aa923fdd..9b1fc65e 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -52,6 +52,7 @@ class TCPClient: self.host, self.port = host, port self.connection, self.rfile, self.wfile = None, None, None self.cert = None + self.ssl_established = False def convert_to_ssl(self, clientcert=None, sni=None): context = SSL.Context(SSL.SSLv23_METHOD) @@ -68,6 +69,7 @@ class TCPClient: self.cert = self.connection.get_peer_certificate() self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) + self.ssl_established = True def connect(self): try: @@ -94,6 +96,7 @@ class BaseHandler: self.client_address = client_address self.server = server self.finished = False + self.ssl_established = False def convert_to_ssl(self, cert, key): ctx = SSL.Context(SSL.SSLv23_METHOD) @@ -109,6 +112,7 @@ class BaseHandler: raise NetLibError("SSL handshake error: %s"%str(v)) self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) + self.ssl_established = True def finish(self): self.finished = True -- cgit v1.2.3 From d0fd8385e60ea6149d9ff6876fb5b4343187b23a Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 27 Jun 2012 12:11:55 +1200 Subject: Fix termiantion error in file read. --- netlib/tcp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 9b1fc65e..0ab7f0e4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -17,7 +17,7 @@ class FileLike: def read(self, length): result = '' - while len(result) < length: + while length > 0: try: data = self.o.read(length) except (SSL.ZeroReturnError, SSL.SysCallError): @@ -25,6 +25,7 @@ class FileLike: if not data: break result += data + length -= len(data) return result def write(self, v): -- cgit v1.2.3 From 5d4c7829bfdda8c0a5fd28896fd925d63221b929 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 27 Jun 2012 16:24:22 +1200 Subject: Minor refactoring. --- netlib/http.py | 3 +++ netlib/tcp.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/netlib/http.py b/netlib/http.py index 150995dd..9c72c601 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -248,6 +248,9 @@ def read_http_body_response(rfile, headers, all, limit): def read_response(rfile, method, body_size_limit): + """ + Return an (httpversion, code, msg, headers, content) tuple. + """ line = rfile.readline() if line == "\r\n" or line == "\n": # Possible leftover from previous message line = rfile.readline() diff --git a/netlib/tcp.py b/netlib/tcp.py index 9b1fc65e..49c8b7a2 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -156,8 +156,8 @@ class TCPServer: self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.socket.bind(self.server_address) self.server_address = self.socket.getsockname() + self.port = self.server_address[1] self.socket.listen(self.request_queue_size) - self.port = self.socket.getsockname()[1] def request_thread(self, request, client_address): try: -- cgit v1.2.3 From f7fcb1c80b2874df05db4603549c6a24d12e58c0 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 27 Jun 2012 16:42:00 +1200 Subject: Add certutils to netlib. --- netlib/certutils.py | 219 +++++++++++++++++++++++++++++++++++++++++++++++++ test/data/dercert | Bin 0 -> 1838 bytes test/data/text_cert | 145 ++++++++++++++++++++++++++++++++ test/data/text_cert_2 | 39 +++++++++ test/test_certutils.py | 72 ++++++++++++++++ 5 files changed, 475 insertions(+) create mode 100644 netlib/certutils.py create mode 100644 test/data/dercert create mode 100644 test/data/text_cert create mode 100644 test/data/text_cert_2 create mode 100644 test/test_certutils.py diff --git a/netlib/certutils.py b/netlib/certutils.py new file mode 100644 index 00000000..31b1fa08 --- /dev/null +++ b/netlib/certutils.py @@ -0,0 +1,219 @@ +import os, ssl, hashlib, socket, time, datetime +from pyasn1.type import univ, constraint, char, namedtype, tag +from pyasn1.codec.der.decoder import decode +import OpenSSL + +CERT_SLEEP_TIME = 1 +CERT_EXPIRY = str(365 * 3) + + +def create_ca(): + key = OpenSSL.crypto.PKey() + key.generate_key(OpenSSL.crypto.TYPE_RSA, 1024) + ca = OpenSSL.crypto.X509() + ca.set_serial_number(int(time.time()*10000)) + ca.set_version(2) + ca.get_subject().CN = "mitmproxy" + ca.get_subject().O = "mitmproxy" + ca.gmtime_adj_notBefore(0) + ca.gmtime_adj_notAfter(24 * 60 * 60 * 720) + ca.set_issuer(ca.get_subject()) + ca.set_pubkey(key) + ca.add_extensions([ + OpenSSL.crypto.X509Extension("basicConstraints", True, + "CA:TRUE"), + OpenSSL.crypto.X509Extension("nsCertType", True, + "sslCA"), + OpenSSL.crypto.X509Extension("extendedKeyUsage", True, + "serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" + ), + OpenSSL.crypto.X509Extension("keyUsage", False, + "keyCertSign, cRLSign"), + OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash", + subject=ca), + ]) + ca.sign(key, "sha1") + return key, ca + + +def dummy_ca(path): + dirname = os.path.dirname(path) + if not os.path.exists(dirname): + os.makedirs(dirname) + if path.endswith(".pem"): + basename, _ = os.path.splitext(path) + else: + basename = path + + key, ca = create_ca() + + # Dump the CA plus private key + f = open(path, "w") + f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) + f.close() + + # Dump the certificate in PEM format + f = open(os.path.join(dirname, basename + "-cert.pem"), "w") + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) + f.close() + + # Create a .cer file with the same contents for Android + f = open(os.path.join(dirname, basename + "-cert.cer"), "w") + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) + f.close() + + # Dump the certificate in PKCS12 format for Windows devices + f = open(os.path.join(dirname, basename + "-cert.p12"), "w") + p12 = OpenSSL.crypto.PKCS12() + p12.set_certificate(ca) + p12.set_privatekey(key) + f.write(p12.export()) + f.close() + return True + + +def dummy_cert(certdir, ca, commonname, sans): + """ + certdir: Certificate directory. + ca: Path to the certificate authority file, or None. + commonname: Common name for the generated certificate. + + Returns cert path if operation succeeded, None if not. + """ + namehash = hashlib.sha256(commonname).hexdigest() + certpath = os.path.join(certdir, namehash + ".pem") + if os.path.exists(certpath): + return certpath + + ss = [] + for i in sans: + ss.append("DNS: %s"%i) + ss = ", ".join(ss) + + if ca: + raw = file(ca, "r").read() + ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) + key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) + else: + key, ca = create_ca() + + req = OpenSSL.crypto.X509Req() + subj = req.get_subject() + subj.CN = commonname + req.set_pubkey(ca.get_pubkey()) + req.sign(key, "sha1") + if ss: + req.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) + + cert = OpenSSL.crypto.X509() + cert.gmtime_adj_notBefore(0) + cert.gmtime_adj_notAfter(60 * 60 * 24 * 30) + cert.set_issuer(ca.get_subject()) + cert.set_subject(req.get_subject()) + cert.set_serial_number(int(time.time()*10000)) + if ss: + cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) + cert.set_pubkey(req.get_pubkey()) + cert.sign(key, "sha1") + + f = open(certpath, "w") + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)) + f.close() + + return certpath + + +class _GeneralName(univ.Choice): + # We are only interested in dNSNames. We use a default handler to ignore + # other types. + componentType = namedtype.NamedTypes( + namedtype.NamedType('dNSName', char.IA5String().subtype( + implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) + ) + ), + ) + + +class _GeneralNames(univ.SequenceOf): + componentType = _GeneralName() + sizeSpec = univ.SequenceOf.sizeSpec + constraint.ValueSizeConstraint(1, 1024) + + +class SSLCert: + def __init__(self, pemtxt): + """ + Returns a (common name, [subject alternative names]) tuple. + """ + self.cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pemtxt) + + @classmethod + def from_der(klass, der): + pem = ssl.DER_cert_to_PEM_cert(der) + return klass(pem) + + def digest(self, name): + return self.cert.digest(name) + + @property + def issuer(self): + return self.cert.get_issuer().get_components() + + @property + def notbefore(self): + t = self.cert.get_notBefore() + return datetime.datetime.strptime(t, "%Y%m%d%H%M%SZ") + + @property + def notafter(self): + t = self.cert.get_notAfter() + return datetime.datetime.strptime(t, "%Y%m%d%H%M%SZ") + + @property + def has_expired(self): + return self.cert.has_expired() + + @property + def subject(self): + return self.cert.get_subject().get_components() + + @property + def serial(self): + return self.cert.get_serial_number() + + @property + def keyinfo(self): + pk = self.cert.get_pubkey() + types = { + OpenSSL.crypto.TYPE_RSA: "RSA", + OpenSSL.crypto.TYPE_DSA: "DSA", + } + return ( + types.get(pk.type(), "UNKNOWN"), + pk.bits() + ) + + @property + def cn(self): + cn = None + for i in self.subject: + if i[0] == "CN": + cn = i[1] + return cn + + @property + def altnames(self): + altnames = [] + for i in range(self.cert.get_extension_count()): + ext = self.cert.get_extension(i) + if ext.get_short_name() == "subjectAltName": + dec = decode(ext.get_data(), asn1Spec=_GeneralNames()) + for i in dec[0]: + altnames.append(i[0].asOctets()) + return altnames + + +def get_remote_cert(host, port): # pragma: no cover + addr = socket.gethostbyname(host) + s = ssl.get_server_certificate((addr, port)) + return SSLCert(s) diff --git a/test/data/dercert b/test/data/dercert new file mode 100644 index 00000000..370252af Binary files /dev/null and b/test/data/dercert differ diff --git a/test/data/text_cert b/test/data/text_cert new file mode 100644 index 00000000..36ca33b9 --- /dev/null +++ b/test/data/text_cert @@ -0,0 +1,145 @@ +-----BEGIN CERTIFICATE----- +MIIadTCCGd6gAwIBAgIGR09PUAFtMA0GCSqGSIb3DQEBBQUAMEYxCzAJBgNVBAYT +AlVTMRMwEQYDVQQKEwpHb29nbGUgSW5jMSIwIAYDVQQDExlHb29nbGUgSW50ZXJu +ZXQgQXV0aG9yaXR5MB4XDTEyMDExNzEyNTUwNFoXDTEzMDExNzEyNTUwNFowTDEL +MAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExEzARBgNVBAoTCkdvb2ds +ZSBJbmMxEzARBgNVBAMTCmdvb2dsZS5jb20wgZ8wDQYJKoZIhvcNAQEBBQADgY0A +MIGJAoGBALofcxR2fud5cyFIeld9pj2vGB5GH0y9tmAYa5t33xbJguKKX/el3tXA +KMNiT1SZzu8ELJ1Ey0GcBAgHA9jVPQd0LGdbEtNIxjblAsWAD/FZlSt8X87h7C5w +2JSefOani0qgQqU6sTdsaCUGZ+Eu7D0lBfT5/Vnl2vV+zI3YmDlpAgMBAAGjghhm +MIIYYjAdBgNVHQ4EFgQUL3+JeC/oL9jZhTp3F550LautzV8wHwYDVR0jBBgwFoAU +v8Aw6/VDET5nup6R+/xq2uNrEiQwWwYDVR0fBFQwUjBQoE6gTIZKaHR0cDovL3d3 +dy5nc3RhdGljLmNvbS9Hb29nbGVJbnRlcm5ldEF1dGhvcml0eS9Hb29nbGVJbnRl +cm5ldEF1dGhvcml0eS5jcmwwZgYIKwYBBQUHAQEEWjBYMFYGCCsGAQUFBzAChkpo +dHRwOi8vd3d3LmdzdGF0aWMuY29tL0dvb2dsZUludGVybmV0QXV0aG9yaXR5L0dv +b2dsZUludGVybmV0QXV0aG9yaXR5LmNydDCCF1kGA1UdEQSCF1AwghdMggpnb29n +bGUuY29tggwqLmdvb2dsZS5jb22CCyouZ29vZ2xlLmFjggsqLmdvb2dsZS5hZIIL +Ki5nb29nbGUuYWWCCyouZ29vZ2xlLmFmggsqLmdvb2dsZS5hZ4ILKi5nb29nbGUu +YW2CCyouZ29vZ2xlLmFzggsqLmdvb2dsZS5hdIILKi5nb29nbGUuYXqCCyouZ29v +Z2xlLmJhggsqLmdvb2dsZS5iZYILKi5nb29nbGUuYmaCCyouZ29vZ2xlLmJnggsq +Lmdvb2dsZS5iaYILKi5nb29nbGUuYmqCCyouZ29vZ2xlLmJzggsqLmdvb2dsZS5i +eYILKi5nb29nbGUuY2GCDCouZ29vZ2xlLmNhdIILKi5nb29nbGUuY2OCCyouZ29v +Z2xlLmNkggsqLmdvb2dsZS5jZoILKi5nb29nbGUuY2eCCyouZ29vZ2xlLmNoggsq +Lmdvb2dsZS5jaYILKi5nb29nbGUuY2yCCyouZ29vZ2xlLmNtggsqLmdvb2dsZS5j +boIOKi5nb29nbGUuY28uYW+CDiouZ29vZ2xlLmNvLmJ3gg4qLmdvb2dsZS5jby5j +a4IOKi5nb29nbGUuY28uY3KCDiouZ29vZ2xlLmNvLmh1gg4qLmdvb2dsZS5jby5p +ZIIOKi5nb29nbGUuY28uaWyCDiouZ29vZ2xlLmNvLmltgg4qLmdvb2dsZS5jby5p +boIOKi5nb29nbGUuY28uamWCDiouZ29vZ2xlLmNvLmpwgg4qLmdvb2dsZS5jby5r +ZYIOKi5nb29nbGUuY28ua3KCDiouZ29vZ2xlLmNvLmxzgg4qLmdvb2dsZS5jby5t +YYIOKi5nb29nbGUuY28ubXqCDiouZ29vZ2xlLmNvLm56gg4qLmdvb2dsZS5jby50 +aIIOKi5nb29nbGUuY28udHqCDiouZ29vZ2xlLmNvLnVngg4qLmdvb2dsZS5jby51 +a4IOKi5nb29nbGUuY28udXqCDiouZ29vZ2xlLmNvLnZlgg4qLmdvb2dsZS5jby52 +aYIOKi5nb29nbGUuY28uemGCDiouZ29vZ2xlLmNvLnptgg4qLmdvb2dsZS5jby56 +d4IPKi5nb29nbGUuY29tLmFmgg8qLmdvb2dsZS5jb20uYWeCDyouZ29vZ2xlLmNv +bS5haYIPKi5nb29nbGUuY29tLmFygg8qLmdvb2dsZS5jb20uYXWCDyouZ29vZ2xl +LmNvbS5iZIIPKi5nb29nbGUuY29tLmJogg8qLmdvb2dsZS5jb20uYm6CDyouZ29v +Z2xlLmNvbS5ib4IPKi5nb29nbGUuY29tLmJygg8qLmdvb2dsZS5jb20uYnmCDyou +Z29vZ2xlLmNvbS5ieoIPKi5nb29nbGUuY29tLmNugg8qLmdvb2dsZS5jb20uY2+C +DyouZ29vZ2xlLmNvbS5jdYIPKi5nb29nbGUuY29tLmN5gg8qLmdvb2dsZS5jb20u +ZG+CDyouZ29vZ2xlLmNvbS5lY4IPKi5nb29nbGUuY29tLmVngg8qLmdvb2dsZS5j +b20uZXSCDyouZ29vZ2xlLmNvbS5maoIPKi5nb29nbGUuY29tLmdlgg8qLmdvb2ds +ZS5jb20uZ2iCDyouZ29vZ2xlLmNvbS5naYIPKi5nb29nbGUuY29tLmdygg8qLmdv +b2dsZS5jb20uZ3SCDyouZ29vZ2xlLmNvbS5oa4IPKi5nb29nbGUuY29tLmlxgg8q +Lmdvb2dsZS5jb20uam2CDyouZ29vZ2xlLmNvbS5qb4IPKi5nb29nbGUuY29tLmto +gg8qLmdvb2dsZS5jb20ua3eCDyouZ29vZ2xlLmNvbS5sYoIPKi5nb29nbGUuY29t +Lmx5gg8qLmdvb2dsZS5jb20ubXSCDyouZ29vZ2xlLmNvbS5teIIPKi5nb29nbGUu +Y29tLm15gg8qLmdvb2dsZS5jb20ubmGCDyouZ29vZ2xlLmNvbS5uZoIPKi5nb29n +bGUuY29tLm5ngg8qLmdvb2dsZS5jb20ubmmCDyouZ29vZ2xlLmNvbS5ucIIPKi5n +b29nbGUuY29tLm5ygg8qLmdvb2dsZS5jb20ub22CDyouZ29vZ2xlLmNvbS5wYYIP +Ki5nb29nbGUuY29tLnBlgg8qLmdvb2dsZS5jb20ucGiCDyouZ29vZ2xlLmNvbS5w +a4IPKi5nb29nbGUuY29tLnBsgg8qLmdvb2dsZS5jb20ucHKCDyouZ29vZ2xlLmNv +bS5weYIPKi5nb29nbGUuY29tLnFhgg8qLmdvb2dsZS5jb20ucnWCDyouZ29vZ2xl +LmNvbS5zYYIPKi5nb29nbGUuY29tLnNigg8qLmdvb2dsZS5jb20uc2eCDyouZ29v +Z2xlLmNvbS5zbIIPKi5nb29nbGUuY29tLnN2gg8qLmdvb2dsZS5jb20udGqCDyou +Z29vZ2xlLmNvbS50boIPKi5nb29nbGUuY29tLnRygg8qLmdvb2dsZS5jb20udHeC +DyouZ29vZ2xlLmNvbS51YYIPKi5nb29nbGUuY29tLnV5gg8qLmdvb2dsZS5jb20u +dmOCDyouZ29vZ2xlLmNvbS52ZYIPKi5nb29nbGUuY29tLnZuggsqLmdvb2dsZS5j +doILKi5nb29nbGUuY3qCCyouZ29vZ2xlLmRlggsqLmdvb2dsZS5kaoILKi5nb29n +bGUuZGuCCyouZ29vZ2xlLmRtggsqLmdvb2dsZS5keoILKi5nb29nbGUuZWWCCyou +Z29vZ2xlLmVzggsqLmdvb2dsZS5maYILKi5nb29nbGUuZm2CCyouZ29vZ2xlLmZy +ggsqLmdvb2dsZS5nYYILKi5nb29nbGUuZ2WCCyouZ29vZ2xlLmdnggsqLmdvb2ds +ZS5nbIILKi5nb29nbGUuZ22CCyouZ29vZ2xlLmdwggsqLmdvb2dsZS5ncoILKi5n +b29nbGUuZ3mCCyouZ29vZ2xlLmhrggsqLmdvb2dsZS5oboILKi5nb29nbGUuaHKC +CyouZ29vZ2xlLmh0ggsqLmdvb2dsZS5odYILKi5nb29nbGUuaWWCCyouZ29vZ2xl +Lmltgg0qLmdvb2dsZS5pbmZvggsqLmdvb2dsZS5pcYILKi5nb29nbGUuaXOCCyou +Z29vZ2xlLml0gg4qLmdvb2dsZS5pdC5hb4ILKi5nb29nbGUuamWCCyouZ29vZ2xl +Lmpvgg0qLmdvb2dsZS5qb2JzggsqLmdvb2dsZS5qcIILKi5nb29nbGUua2eCCyou +Z29vZ2xlLmtpggsqLmdvb2dsZS5reoILKi5nb29nbGUubGGCCyouZ29vZ2xlLmxp +ggsqLmdvb2dsZS5sa4ILKi5nb29nbGUubHSCCyouZ29vZ2xlLmx1ggsqLmdvb2ds +ZS5sdoILKi5nb29nbGUubWSCCyouZ29vZ2xlLm1lggsqLmdvb2dsZS5tZ4ILKi5n +b29nbGUubWuCCyouZ29vZ2xlLm1sggsqLmdvb2dsZS5tboILKi5nb29nbGUubXOC +CyouZ29vZ2xlLm11ggsqLmdvb2dsZS5tdoILKi5nb29nbGUubXeCCyouZ29vZ2xl +Lm5lgg4qLmdvb2dsZS5uZS5qcIIMKi5nb29nbGUubmV0ggsqLmdvb2dsZS5ubIIL +Ki5nb29nbGUubm+CCyouZ29vZ2xlLm5yggsqLmdvb2dsZS5udYIPKi5nb29nbGUu +b2ZmLmFpggsqLmdvb2dsZS5wa4ILKi5nb29nbGUucGyCCyouZ29vZ2xlLnBuggsq +Lmdvb2dsZS5wc4ILKi5nb29nbGUucHSCCyouZ29vZ2xlLnJvggsqLmdvb2dsZS5y +c4ILKi5nb29nbGUucnWCCyouZ29vZ2xlLnJ3ggsqLmdvb2dsZS5zY4ILKi5nb29n +bGUuc2WCCyouZ29vZ2xlLnNoggsqLmdvb2dsZS5zaYILKi5nb29nbGUuc2uCCyou +Z29vZ2xlLnNtggsqLmdvb2dsZS5zboILKi5nb29nbGUuc2+CCyouZ29vZ2xlLnN0 +ggsqLmdvb2dsZS50ZIILKi5nb29nbGUudGeCCyouZ29vZ2xlLnRrggsqLmdvb2ds +ZS50bIILKi5nb29nbGUudG2CCyouZ29vZ2xlLnRuggsqLmdvb2dsZS50b4ILKi5n +b29nbGUudHCCCyouZ29vZ2xlLnR0ggsqLmdvb2dsZS51c4ILKi5nb29nbGUudXqC +CyouZ29vZ2xlLnZnggsqLmdvb2dsZS52dYILKi5nb29nbGUud3OCCWdvb2dsZS5h +Y4IJZ29vZ2xlLmFkgglnb29nbGUuYWWCCWdvb2dsZS5hZoIJZ29vZ2xlLmFnggln +b29nbGUuYW2CCWdvb2dsZS5hc4IJZ29vZ2xlLmF0gglnb29nbGUuYXqCCWdvb2ds +ZS5iYYIJZ29vZ2xlLmJlgglnb29nbGUuYmaCCWdvb2dsZS5iZ4IJZ29vZ2xlLmJp +gglnb29nbGUuYmqCCWdvb2dsZS5ic4IJZ29vZ2xlLmJ5gglnb29nbGUuY2GCCmdv +b2dsZS5jYXSCCWdvb2dsZS5jY4IJZ29vZ2xlLmNkgglnb29nbGUuY2aCCWdvb2ds +ZS5jZ4IJZ29vZ2xlLmNogglnb29nbGUuY2mCCWdvb2dsZS5jbIIJZ29vZ2xlLmNt +gglnb29nbGUuY26CDGdvb2dsZS5jby5hb4IMZ29vZ2xlLmNvLmJ3ggxnb29nbGUu +Y28uY2uCDGdvb2dsZS5jby5jcoIMZ29vZ2xlLmNvLmh1ggxnb29nbGUuY28uaWSC +DGdvb2dsZS5jby5pbIIMZ29vZ2xlLmNvLmltggxnb29nbGUuY28uaW6CDGdvb2ds +ZS5jby5qZYIMZ29vZ2xlLmNvLmpwggxnb29nbGUuY28ua2WCDGdvb2dsZS5jby5r +coIMZ29vZ2xlLmNvLmxzggxnb29nbGUuY28ubWGCDGdvb2dsZS5jby5teoIMZ29v +Z2xlLmNvLm56ggxnb29nbGUuY28udGiCDGdvb2dsZS5jby50eoIMZ29vZ2xlLmNv +LnVnggxnb29nbGUuY28udWuCDGdvb2dsZS5jby51eoIMZ29vZ2xlLmNvLnZlggxn +b29nbGUuY28udmmCDGdvb2dsZS5jby56YYIMZ29vZ2xlLmNvLnptggxnb29nbGUu +Y28ueneCDWdvb2dsZS5jb20uYWaCDWdvb2dsZS5jb20uYWeCDWdvb2dsZS5jb20u +YWmCDWdvb2dsZS5jb20uYXKCDWdvb2dsZS5jb20uYXWCDWdvb2dsZS5jb20uYmSC +DWdvb2dsZS5jb20uYmiCDWdvb2dsZS5jb20uYm6CDWdvb2dsZS5jb20uYm+CDWdv +b2dsZS5jb20uYnKCDWdvb2dsZS5jb20uYnmCDWdvb2dsZS5jb20uYnqCDWdvb2ds +ZS5jb20uY26CDWdvb2dsZS5jb20uY2+CDWdvb2dsZS5jb20uY3WCDWdvb2dsZS5j +b20uY3mCDWdvb2dsZS5jb20uZG+CDWdvb2dsZS5jb20uZWOCDWdvb2dsZS5jb20u +ZWeCDWdvb2dsZS5jb20uZXSCDWdvb2dsZS5jb20uZmqCDWdvb2dsZS5jb20uZ2WC +DWdvb2dsZS5jb20uZ2iCDWdvb2dsZS5jb20uZ2mCDWdvb2dsZS5jb20uZ3KCDWdv +b2dsZS5jb20uZ3SCDWdvb2dsZS5jb20uaGuCDWdvb2dsZS5jb20uaXGCDWdvb2ds +ZS5jb20uam2CDWdvb2dsZS5jb20uam+CDWdvb2dsZS5jb20ua2iCDWdvb2dsZS5j +b20ua3eCDWdvb2dsZS5jb20ubGKCDWdvb2dsZS5jb20ubHmCDWdvb2dsZS5jb20u +bXSCDWdvb2dsZS5jb20ubXiCDWdvb2dsZS5jb20ubXmCDWdvb2dsZS5jb20ubmGC +DWdvb2dsZS5jb20ubmaCDWdvb2dsZS5jb20ubmeCDWdvb2dsZS5jb20ubmmCDWdv +b2dsZS5jb20ubnCCDWdvb2dsZS5jb20ubnKCDWdvb2dsZS5jb20ub22CDWdvb2ds +ZS5jb20ucGGCDWdvb2dsZS5jb20ucGWCDWdvb2dsZS5jb20ucGiCDWdvb2dsZS5j +b20ucGuCDWdvb2dsZS5jb20ucGyCDWdvb2dsZS5jb20ucHKCDWdvb2dsZS5jb20u +cHmCDWdvb2dsZS5jb20ucWGCDWdvb2dsZS5jb20ucnWCDWdvb2dsZS5jb20uc2GC +DWdvb2dsZS5jb20uc2KCDWdvb2dsZS5jb20uc2eCDWdvb2dsZS5jb20uc2yCDWdv +b2dsZS5jb20uc3aCDWdvb2dsZS5jb20udGqCDWdvb2dsZS5jb20udG6CDWdvb2ds +ZS5jb20udHKCDWdvb2dsZS5jb20udHeCDWdvb2dsZS5jb20udWGCDWdvb2dsZS5j +b20udXmCDWdvb2dsZS5jb20udmOCDWdvb2dsZS5jb20udmWCDWdvb2dsZS5jb20u +dm6CCWdvb2dsZS5jdoIJZ29vZ2xlLmN6gglnb29nbGUuZGWCCWdvb2dsZS5kaoIJ +Z29vZ2xlLmRrgglnb29nbGUuZG2CCWdvb2dsZS5keoIJZ29vZ2xlLmVlgglnb29n +bGUuZXOCCWdvb2dsZS5maYIJZ29vZ2xlLmZtgglnb29nbGUuZnKCCWdvb2dsZS5n +YYIJZ29vZ2xlLmdlgglnb29nbGUuZ2eCCWdvb2dsZS5nbIIJZ29vZ2xlLmdtggln +b29nbGUuZ3CCCWdvb2dsZS5ncoIJZ29vZ2xlLmd5gglnb29nbGUuaGuCCWdvb2ds +ZS5oboIJZ29vZ2xlLmhygglnb29nbGUuaHSCCWdvb2dsZS5odYIJZ29vZ2xlLmll +gglnb29nbGUuaW2CC2dvb2dsZS5pbmZvgglnb29nbGUuaXGCCWdvb2dsZS5pc4IJ +Z29vZ2xlLml0ggxnb29nbGUuaXQuYW+CCWdvb2dsZS5qZYIJZ29vZ2xlLmpvggtn +b29nbGUuam9ic4IJZ29vZ2xlLmpwgglnb29nbGUua2eCCWdvb2dsZS5raYIJZ29v +Z2xlLmt6gglnb29nbGUubGGCCWdvb2dsZS5saYIJZ29vZ2xlLmxrgglnb29nbGUu +bHSCCWdvb2dsZS5sdYIJZ29vZ2xlLmx2gglnb29nbGUubWSCCWdvb2dsZS5tZYIJ +Z29vZ2xlLm1ngglnb29nbGUubWuCCWdvb2dsZS5tbIIJZ29vZ2xlLm1ugglnb29n +bGUubXOCCWdvb2dsZS5tdYIJZ29vZ2xlLm12gglnb29nbGUubXeCCWdvb2dsZS5u +ZYIMZ29vZ2xlLm5lLmpwggpnb29nbGUubmV0gglnb29nbGUubmyCCWdvb2dsZS5u +b4IJZ29vZ2xlLm5ygglnb29nbGUubnWCDWdvb2dsZS5vZmYuYWmCCWdvb2dsZS5w +a4IJZ29vZ2xlLnBsgglnb29nbGUucG6CCWdvb2dsZS5wc4IJZ29vZ2xlLnB0ggln +b29nbGUucm+CCWdvb2dsZS5yc4IJZ29vZ2xlLnJ1gglnb29nbGUucneCCWdvb2ds +ZS5zY4IJZ29vZ2xlLnNlgglnb29nbGUuc2iCCWdvb2dsZS5zaYIJZ29vZ2xlLnNr +gglnb29nbGUuc22CCWdvb2dsZS5zboIJZ29vZ2xlLnNvgglnb29nbGUuc3SCCWdv +b2dsZS50ZIIJZ29vZ2xlLnRngglnb29nbGUudGuCCWdvb2dsZS50bIIJZ29vZ2xl +LnRtgglnb29nbGUudG6CCWdvb2dsZS50b4IJZ29vZ2xlLnRwgglnb29nbGUudHSC +CWdvb2dsZS51c4IJZ29vZ2xlLnV6gglnb29nbGUudmeCCWdvb2dsZS52dYIJZ29v +Z2xlLndzMA0GCSqGSIb3DQEBBQUAA4GBAJmZ9RyqpUzrP0UcJnHXoLu/AjIEsIvZ +Y9hq/9bLry8InfmvERYHr4hNetkOYlW0FeDZtCpWxdPUgJjmWgKAK6j0goOFavTV +GptkL8gha4p1QUsdLkd36/cvBXeBYSle787veo46N1k4V6Uv2gaDVkre786CNsHv +Q6MYZ5ClQ+kS +-----END CERTIFICATE----- + diff --git a/test/data/text_cert_2 b/test/data/text_cert_2 new file mode 100644 index 00000000..ffe8faae --- /dev/null +++ b/test/data/text_cert_2 @@ -0,0 +1,39 @@ +-----BEGIN CERTIFICATE----- +MIIGujCCBaKgAwIBAgIDAQlEMA0GCSqGSIb3DQEBBQUAMIGMMQswCQYDVQQGEwJJ +TDEWMBQGA1UEChMNU3RhcnRDb20gTHRkLjErMCkGA1UECxMiU2VjdXJlIERpZ2l0 +YWwgQ2VydGlmaWNhdGUgU2lnbmluZzE4MDYGA1UEAxMvU3RhcnRDb20gQ2xhc3Mg +MSBQcmltYXJ5IEludGVybWVkaWF0ZSBTZXJ2ZXIgQ0EwHhcNMTAwMTExMTkyNzM2 +WhcNMTEwMTEyMDkxNDU1WjCBtDEgMB4GA1UEDRMXMTI2ODMyLU1DeExzWTZUbjFn +bTdvOTAxCzAJBgNVBAYTAk5aMR4wHAYDVQQKExVQZXJzb25hIE5vdCBWYWxpZGF0 +ZWQxKTAnBgNVBAsTIFN0YXJ0Q29tIEZyZWUgQ2VydGlmaWNhdGUgTWVtYmVyMRgw +FgYDVQQDEw93d3cuaW5vZGUuY28ubnoxHjAcBgkqhkiG9w0BCQEWD2ppbUBpbm9k +ZS5jby5uejCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAL6ghWlGhqg+ +V0P58R3SvLRiO9OrdekDxzmQbKwQcc05frnF5Z9vT6ga7YOuXVeXxhYCAo0nr6KI ++y/Lx+QHvP5W0nKbs+svzUQErq2ZZFwhh1e1LbVccrNwkHUzKOq0TTaVdU4k8kDQ +zzYF9tTZb+G5Hv1BJjpwYwe8P4cAiPJPrFFOKTySzHqiYsXlx+vR1l1e3zKavhd+ +LVSoLWWXb13yKODq6vnuiHjUJXl8CfVlBhoGotXU4JR5cbuGoW/8+rkwEdX+YoCv +VCqgdx9IkRFB6uWfN6ocUiFvhA0eknO+ewuVfRLiIaSDB8pNyUWVqu4ngFWtWO1O +YZg0I/32BkcCAwEAAaOCAvkwggL1MAkGA1UdEwQCMAAwCwYDVR0PBAQDAgOoMBMG +A1UdJQQMMAoGCCsGAQUFBwMBMB0GA1UdDgQWBBQfaL2Rj6r8iRlBTgppgE7ZZ5WT +UzAfBgNVHSMEGDAWgBTrQjTQmLCrn/Qbawj3zGQu7w4sRTAnBgNVHREEIDAegg93 +d3cuaW5vZGUuY28ubnqCC2lub2RlLmNvLm56MIIBQgYDVR0gBIIBOTCCATUwggEx +BgsrBgEEAYG1NwECATCCASAwLgYIKwYBBQUHAgEWImh0dHA6Ly93d3cuc3RhcnRz +c2wuY29tL3BvbGljeS5wZGYwNAYIKwYBBQUHAgEWKGh0dHA6Ly93d3cuc3RhcnRz +c2wuY29tL2ludGVybWVkaWF0ZS5wZGYwgbcGCCsGAQUFBwICMIGqMBQWDVN0YXJ0 +Q29tIEx0ZC4wAwIBARqBkUxpbWl0ZWQgTGlhYmlsaXR5LCBzZWUgc2VjdGlvbiAq +TGVnYWwgTGltaXRhdGlvbnMqIG9mIHRoZSBTdGFydENvbSBDZXJ0aWZpY2F0aW9u +IEF1dGhvcml0eSBQb2xpY3kgYXZhaWxhYmxlIGF0IGh0dHA6Ly93d3cuc3RhcnRz +c2wuY29tL3BvbGljeS5wZGYwYQYDVR0fBFowWDAqoCigJoYkaHR0cDovL3d3dy5z +dGFydHNzbC5jb20vY3J0MS1jcmwuY3JsMCqgKKAmhiRodHRwOi8vY3JsLnN0YXJ0 +c3NsLmNvbS9jcnQxLWNybC5jcmwwgY4GCCsGAQUFBwEBBIGBMH8wOQYIKwYBBQUH +MAGGLWh0dHA6Ly9vY3NwLnN0YXJ0c3NsLmNvbS9zdWIvY2xhc3MxL3NlcnZlci9j +YTBCBggrBgEFBQcwAoY2aHR0cDovL3d3dy5zdGFydHNzbC5jb20vY2VydHMvc3Vi +LmNsYXNzMS5zZXJ2ZXIuY2EuY3J0MCMGA1UdEgQcMBqGGGh0dHA6Ly93d3cuc3Rh +cnRzc2wuY29tLzANBgkqhkiG9w0BAQUFAAOCAQEAivWID0KT8q1EzWzy+BecsFry +hQhuLFfAsPkHqpNd9OfkRStGBuJlLX+9DQ9TzjqutdY2buNBuDn71buZK+Y5fmjr +28rAT6+WMd+KnCl5WLT5IOS6Z9s3cec5TFQbmOGlepSS9Q6Ts9KsXOHHQvDkQeDq +OV2UqdgXIAyFm5efSL9JXPXntRausNu2s8F2B2rRJe4jPfnUy2LvY8OW1YvjUA++ +vpdWRdfUbJQp55mRfaYMPRnyUm30lAI27QaxgQPFOqDeZUm5llb5eFG/B3f87uhg ++Y1oEykbEvZrIFN4hithioQ0tb+57FKkkG2sW3uemNiQw2qrEo/GAMb1cI50Rg== +-----END CERTIFICATE----- + diff --git a/test/test_certutils.py b/test/test_certutils.py new file mode 100644 index 00000000..5229fc2a --- /dev/null +++ b/test/test_certutils.py @@ -0,0 +1,72 @@ +import os +from netlib import certutils +import tutils + + +def test_dummy_ca(): + with tutils.tmpdir() as d: + path = os.path.join(d, "foo/cert.cnf") + assert certutils.dummy_ca(path) + assert os.path.exists(path) + + path = os.path.join(d, "foo/cert2.pem") + assert certutils.dummy_ca(path) + assert os.path.exists(path) + assert os.path.exists(os.path.join(d, "foo/cert2-cert.pem")) + assert os.path.exists(os.path.join(d, "foo/cert2-cert.p12")) + + +class TestDummyCert: + def test_with_ca(self): + with tutils.tmpdir() as d: + cacert = os.path.join(d, "foo/cert.cnf") + assert certutils.dummy_ca(cacert) + p = certutils.dummy_cert( + os.path.join(d, "foo"), + cacert, + "foo.com", + ["one.com", "two.com", "*.three.com"] + ) + assert os.path.exists(p) + + # Short-circuit + assert certutils.dummy_cert( + os.path.join(d, "foo"), + cacert, + "foo.com", + [] + ) + + def test_no_ca(self): + with tutils.tmpdir() as d: + p = certutils.dummy_cert( + d, + None, + "foo.com", + [] + ) + assert os.path.exists(p) + + +class TestSSLCert: + def test_simple(self): + c = certutils.SSLCert(file(tutils.test_data.path("data/text_cert"), "r").read()) + assert c.cn == "google.com" + assert len(c.altnames) == 436 + + c = certutils.SSLCert(file(tutils.test_data.path("data/text_cert_2"), "r").read()) + assert c.cn == "www.inode.co.nz" + assert len(c.altnames) == 2 + assert c.digest("sha1") + assert c.notbefore + assert c.notafter + assert c.subject + assert c.keyinfo == ("RSA", 2048) + assert c.serial + assert c.issuer + c.has_expired + + def test_der(self): + d = file(tutils.test_data.path("data/dercert")).read() + s = certutils.SSLCert.from_der(d) + assert s.cn -- cgit v1.2.3 From b0ef9ad07ba4b805f3130237dcf9207434c33d84 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 27 Jun 2012 22:11:58 +1200 Subject: Refactor certutils.SSLCert API. --- netlib/certutils.py | 31 ++++++++++++++++++------------- test/test_certutils.py | 4 ++-- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 31b1fa08..6c9a5c57 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -141,49 +141,54 @@ class _GeneralNames(univ.SequenceOf): class SSLCert: - def __init__(self, pemtxt): + def __init__(self, cert): """ Returns a (common name, [subject alternative names]) tuple. """ - self.cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pemtxt) + self.x509 = cert + + @classmethod + def from_pem(klass, txt): + x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) + return klass(x509) @classmethod def from_der(klass, der): pem = ssl.DER_cert_to_PEM_cert(der) - return klass(pem) + return klass.from_pem(pem) def digest(self, name): - return self.cert.digest(name) + return self.x509.digest(name) @property def issuer(self): - return self.cert.get_issuer().get_components() + return self.x509.get_issuer().get_components() @property def notbefore(self): - t = self.cert.get_notBefore() + t = self.x509.get_notBefore() return datetime.datetime.strptime(t, "%Y%m%d%H%M%SZ") @property def notafter(self): - t = self.cert.get_notAfter() + t = self.x509.get_notAfter() return datetime.datetime.strptime(t, "%Y%m%d%H%M%SZ") @property def has_expired(self): - return self.cert.has_expired() + return self.x509.has_expired() @property def subject(self): - return self.cert.get_subject().get_components() + return self.x509.get_subject().get_components() @property def serial(self): - return self.cert.get_serial_number() + return self.x509.get_serial_number() @property def keyinfo(self): - pk = self.cert.get_pubkey() + pk = self.x509.get_pubkey() types = { OpenSSL.crypto.TYPE_RSA: "RSA", OpenSSL.crypto.TYPE_DSA: "DSA", @@ -204,8 +209,8 @@ class SSLCert: @property def altnames(self): altnames = [] - for i in range(self.cert.get_extension_count()): - ext = self.cert.get_extension(i) + for i in range(self.x509.get_extension_count()): + ext = self.x509.get_extension(i) if ext.get_short_name() == "subjectAltName": dec = decode(ext.get_data(), asn1Spec=_GeneralNames()) for i in dec[0]: diff --git a/test/test_certutils.py b/test/test_certutils.py index 5229fc2a..85dce600 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -50,11 +50,11 @@ class TestDummyCert: class TestSSLCert: def test_simple(self): - c = certutils.SSLCert(file(tutils.test_data.path("data/text_cert"), "r").read()) + c = certutils.SSLCert.from_pem(file(tutils.test_data.path("data/text_cert"), "r").read()) assert c.cn == "google.com" assert len(c.altnames) == 436 - c = certutils.SSLCert(file(tutils.test_data.path("data/text_cert_2"), "r").read()) + c = certutils.SSLCert.from_pem(file(tutils.test_data.path("data/text_cert_2"), "r").read()) assert c.cn == "www.inode.co.nz" assert len(c.altnames) == 2 assert c.digest("sha1") -- cgit v1.2.3 From a1491a6ae037b7874dd71de11f5cd43e10aa46e7 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 28 Jun 2012 08:15:55 +1200 Subject: Add a get_remote_cert method to tcp client. --- netlib/certutils.py | 10 ++++++---- netlib/tcp.py | 1 + test/test_tcp.py | 5 ++++- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 6c9a5c57..180e1ac0 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -2,6 +2,7 @@ import os, ssl, hashlib, socket, time, datetime from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode import OpenSSL +import tcp CERT_SLEEP_TIME = 1 CERT_EXPIRY = str(365 * 3) @@ -218,7 +219,8 @@ class SSLCert: return altnames -def get_remote_cert(host, port): # pragma: no cover - addr = socket.gethostbyname(host) - s = ssl.get_server_certificate((addr, port)) - return SSLCert(s) +def get_remote_cert(host, port, sni): + c = tcp.TCPClient(host, port) + c.connect() + c.convert_to_ssl(sni=sni) + return c.cert diff --git a/netlib/tcp.py b/netlib/tcp.py index ef3298d5..6c5b4976 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,5 +1,6 @@ import select, socket, threading, traceback, sys from OpenSSL import SSL +import certutils class NetLibError(Exception): pass diff --git a/test/test_tcp.py b/test/test_tcp.py index a2ee5e36..969daf1e 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -1,5 +1,5 @@ import cStringIO, threading, Queue -from netlib import tcp +from netlib import tcp, certutils import tutils class ServerThread(threading.Thread): @@ -110,6 +110,9 @@ class TestServerSSL(ServerTestBase): c.wfile.flush() assert c.rfile.readline() == testval + def test_get_remote_cert(self): + assert certutils.get_remote_cert("127.0.0.1", self.port, None).digest("sha1") + class TestSNI(ServerTestBase): @classmethod -- cgit v1.2.3 From 92c7d38bd343a0436d73c0a984fe111996e15059 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 28 Jun 2012 09:56:58 +1200 Subject: Handle obscure termination scenario, where interpreter exits before thread termination. --- netlib/tcp.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 0ab7f0e4..f02be550 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -117,14 +117,11 @@ class BaseHandler: def finish(self): self.finished = True - try: - if not getattr(self.wfile, "closed", False): - self.wfile.flush() - self.connection.close() - self.wfile.close() - self.rfile.close() - except IOError: # pragma: no cover - pass + if not getattr(self.wfile, "closed", False): + self.wfile.flush() + self.connection.close() + self.wfile.close() + self.rfile.close() def handle_sni(self, connection): """ @@ -165,8 +162,15 @@ class TCPServer: self.handle_connection(request, client_address) request.close() except: - self.handle_error(request, client_address) - request.close() + try: + self.handle_error(request, client_address) + request.close() + # Why a blanket except here? In some circumstances, a thread can + # persist until the interpreter exits. When this happens, all modules + # and builtins are set to None, and things balls up indeterminate + # ways. + except: + pass def serve_forever(self, poll_interval=0.1): self.__is_shut_down.clear() -- cgit v1.2.3 From 3f9aad53ab9b567ddc89848c54234d667a846db8 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 28 Jun 2012 10:59:03 +1200 Subject: Return a certutils.SSLCert object from get_remote_cert. --- netlib/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index a265ef7a..b3fc2212 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -68,7 +68,7 @@ class TCPClient: self.connection.do_handshake() except SSL.Error, v: raise NetLibError("SSL handshake error: %s"%str(v)) - self.cert = self.connection.get_peer_certificate() + self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) self.ssl_established = True -- cgit v1.2.3 From 7480f87cd721de6ca9d0cdb7c9437bdb58b16ba0 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 28 Jun 2012 14:56:21 +1200 Subject: Add utility function for converstion to PEM. --- netlib/certutils.py | 3 +++ setup.py | 2 +- test/test_certutils.py | 1 + 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 180e1ac0..dcd54053 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -158,6 +158,9 @@ class SSLCert: pem = ssl.DER_cert_to_PEM_cert(der) return klass.from_pem(pem) + def to_pem(self): + return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, self.x509) + def digest(self, name): return self.x509.digest(name) diff --git a/setup.py b/setup.py index 06ac8aea..e0dff0ff 100644 --- a/setup.py +++ b/setup.py @@ -66,7 +66,7 @@ def findPackages(path, dataExclude=[]): long_description = file("README").read() -packages, package_data = findPackages("libpathod") +packages, package_data = findPackages("netlib") setup( name = "netlib", version = version.VERSION, diff --git a/test/test_certutils.py b/test/test_certutils.py index 85dce600..8f95be67 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -64,6 +64,7 @@ class TestSSLCert: assert c.keyinfo == ("RSA", 2048) assert c.serial assert c.issuer + assert c.to_pem() c.has_expired def test_der(self): -- cgit v1.2.3 From 67669a2a578157782a621fa1ac5531bbb2db8029 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 30 Jun 2012 10:52:28 +1200 Subject: Allow control of buffer size for TCPClient, improve error messages. --- netlib/http.py | 6 +++--- netlib/tcp.py | 5 ++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index 9c72c601..acd9d85e 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -260,15 +260,15 @@ def read_response(rfile, method, body_size_limit): if len(parts) == 2: # handle missing message gracefully parts.append("") if not len(parts) == 3: - raise HttpError(502, "Invalid server response: %s."%line) + raise HttpError(502, "Invalid server response: %s"%repr(line)) proto, code, msg = parts httpversion = parse_http_protocol(proto) if httpversion is None: - raise HttpError(502, "Invalid HTTP version: %s."%httpversion) + raise HttpError(502, "Invalid HTTP version: %s"%repr(httpversion)) try: code = int(code) except ValueError: - raise HttpError(502, "Invalid server response: %s."%line) + raise HttpError(502, "Invalid server response: %s"%repr(line)) headers = read_headers(rfile) if code >= 100 and code <= 199: return read_response(rfile, method, body_size_limit) diff --git a/netlib/tcp.py b/netlib/tcp.py index b3fc2212..bb0a00b9 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -50,6 +50,8 @@ class FileLike: class TCPClient: + rbufsize = -1 + wbufsize = -1 def __init__(self, host, port): self.host, self.port = host, port self.connection, self.rfile, self.wfile = None, None, None @@ -78,7 +80,8 @@ class TCPClient: addr = socket.gethostbyname(self.host) connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) connection.connect((addr, self.port)) - self.rfile, self.wfile = connection.makefile('rb'), connection.makefile('wb') + self.rfile = connection.makefile('rb', self.rbufsize) + self.wfile = connection.makefile('wb', self.wbufsize) except socket.error, err: raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) self.connection = connection -- cgit v1.2.3 From 96af5c16a065a8167d167ed1d4dc9e0a77566e25 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 4 Jul 2012 21:30:07 +1200 Subject: Expose SSL options, use TLSv1 by default for client connections. --- netlib/tcp.py | 46 ++++++++++++++++++++++++++++++++++++++++++---- test/test_tcp.py | 25 ++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 5 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index bb0a00b9..54148172 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -2,6 +2,37 @@ import select, socket, threading, traceback, sys from OpenSSL import SSL import certutils +SSLv2_METHOD = SSL.SSLv2_METHOD +SSLv3_METHOD = SSL.SSLv3_METHOD +SSLv23_METHOD = SSL.SSLv23_METHOD +TLSv1_METHOD = SSL.TLSv1_METHOD + +OP_ALL = SSL.OP_ALL +OP_CIPHER_SERVER_PREFERENCE = SSL.OP_CIPHER_SERVER_PREFERENCE +OP_COOKIE_EXCHANGE = SSL.OP_COOKIE_EXCHANGE +OP_DONT_INSERT_EMPTY_FRAGMENTS = SSL.OP_DONT_INSERT_EMPTY_FRAGMENTS +OP_EPHEMERAL_RSA = SSL.OP_EPHEMERAL_RSA +OP_MICROSOFT_BIG_SSLV3_BUFFER = SSL.OP_MICROSOFT_BIG_SSLV3_BUFFER +OP_MICROSOFT_SESS_ID_BUG = SSL.OP_MICROSOFT_SESS_ID_BUG +OP_MSIE_SSLV2_RSA_PADDING = SSL.OP_MSIE_SSLV2_RSA_PADDING +OP_NETSCAPE_CA_DN_BUG = SSL.OP_NETSCAPE_CA_DN_BUG +OP_NETSCAPE_CHALLENGE_BUG = SSL.OP_NETSCAPE_CHALLENGE_BUG +OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG = SSL.OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG +OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG = SSL.OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG +OP_NO_QUERY_MTU = SSL.OP_NO_QUERY_MTU +OP_NO_SSLv2 = SSL.OP_NO_SSLv2 +OP_NO_SSLv3 = SSL.OP_NO_SSLv3 +OP_NO_TICKET = SSL.OP_NO_TICKET +OP_NO_TLSv1 = SSL.OP_NO_TLSv1 +OP_PKCS1_CHECK_1 = SSL.OP_PKCS1_CHECK_1 +OP_PKCS1_CHECK_2 = SSL.OP_PKCS1_CHECK_2 +OP_SINGLE_DH_USE = SSL.OP_SINGLE_DH_USE +OP_SSLEAY_080_CLIENT_DH_BUG = SSL.OP_SSLEAY_080_CLIENT_DH_BUG +OP_SSLREF2_REUSE_CERT_TYPE_BUG = SSL.OP_SSLREF2_REUSE_CERT_TYPE_BUG +OP_TLS_BLOCK_PADDING_BUG = SSL.OP_TLS_BLOCK_PADDING_BUG +OP_TLS_D5_BUG = SSL.OP_TLS_D5_BUG +OP_TLS_ROLLBACK_BUG = SSL.OP_TLS_ROLLBACK_BUG + class NetLibError(Exception): pass @@ -58,8 +89,10 @@ class TCPClient: self.cert = None self.ssl_established = False - def convert_to_ssl(self, clientcert=None, sni=None): - context = SSL.Context(SSL.SSLv23_METHOD) + def convert_to_ssl(self, clientcert=None, sni=None, method=TLSv1_METHOD, options=None): + context = SSL.Context(method) + if not options is None: + ctx.set_options(options) if clientcert: context.use_certificate_file(self.clientcert) self.connection = SSL.Connection(context, self.connection) @@ -103,8 +136,13 @@ class BaseHandler: self.finished = False self.ssl_established = False - def convert_to_ssl(self, cert, key): - ctx = SSL.Context(SSL.SSLv23_METHOD) + def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None): + """ + method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD + """ + ctx = SSL.Context(method) + if not options is None: + ctx.set_options(options) ctx.set_tlsext_servername_callback(self.handle_sni) ctx.use_privatekey_file(key) ctx.use_certificate_file(cert) diff --git a/test/test_tcp.py b/test/test_tcp.py index 969daf1e..b9f274ae 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -55,17 +55,26 @@ class DisconnectHandler(tcp.BaseHandler): class TServer(tcp.TCPServer): - def __init__(self, addr, ssl, q, handler): + def __init__(self, addr, ssl, q, handler, v3_only=False): tcp.TCPServer.__init__(self, addr) self.ssl, self.q = ssl, q + self.v3_only = v3_only self.handler = handler def handle_connection(self, request, client_address): h = self.handler(request, client_address, self) if self.ssl: + if self.v3_only: + method = tcp.SSLv3_METHOD + options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1 + else: + method = tcp.SSLv23_METHOD + options = None h.convert_to_ssl( tutils.test_data.path("data/server.crt"), tutils.test_data.path("data/server.key"), + method = method, + options = options, ) h.handle() h.finish() @@ -114,6 +123,20 @@ class TestServerSSL(ServerTestBase): assert certutils.get_remote_cert("127.0.0.1", self.port, None).digest("sha1") +class TestSSLv3Only(ServerTestBase): + @classmethod + def makeserver(cls): + cls.q = Queue.Queue() + s = TServer(("127.0.0.1", 0), True, cls.q, EchoHandler, True) + cls.port = s.port + return s + + def test_failure(self): + c = tcp.TCPClient("127.0.0.1", self.port) + c.connect() + tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com", method=tcp.TLSv1_METHOD) + + class TestSNI(ServerTestBase): @classmethod def makeserver(cls): -- cgit v1.2.3 From 20cc1b6aa4488d9b230469ba57b6a92380bfeeca Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 5 Jul 2012 09:37:43 +1200 Subject: Refactor TCP test suite. --- netlib/tcp.py | 2 +- test/test_tcp.py | 30 +++++++++--------------------- 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 54148172..0af3d463 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -209,7 +209,7 @@ class TCPServer: request.close() # Why a blanket except here? In some circumstances, a thread can # persist until the interpreter exits. When this happens, all modules - # and builtins are set to None, and things balls up indeterminate + # and builtins are set to None, and things balls up in indeterminate # ways. except: pass diff --git a/test/test_tcp.py b/test/test_tcp.py index b9f274ae..359890d5 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -17,7 +17,10 @@ class ServerThread(threading.Thread): class ServerTestBase: @classmethod def setupAll(cls): - cls.server = ServerThread(cls.makeserver()) + cls.q = Queue.Queue() + s = cls.makeserver() + cls.port = s.port + cls.server = ServerThread(s) cls.server.start() @classmethod @@ -88,10 +91,7 @@ class TServer(tcp.TCPServer): class TestServer(ServerTestBase): @classmethod def makeserver(cls): - cls.q = Queue.Queue() - s = TServer(("127.0.0.1", 0), False, cls.q, EchoHandler) - cls.port = s.port - return s + return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler) def test_echo(self): testval = "echo!\n" @@ -105,10 +105,7 @@ class TestServer(ServerTestBase): class TestServerSSL(ServerTestBase): @classmethod def makeserver(cls): - cls.q = Queue.Queue() - s = TServer(("127.0.0.1", 0), True, cls.q, EchoHandler) - cls.port = s.port - return s + return TServer(("127.0.0.1", 0), True, cls.q, EchoHandler) def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -126,10 +123,7 @@ class TestServerSSL(ServerTestBase): class TestSSLv3Only(ServerTestBase): @classmethod def makeserver(cls): - cls.q = Queue.Queue() - s = TServer(("127.0.0.1", 0), True, cls.q, EchoHandler, True) - cls.port = s.port - return s + return TServer(("127.0.0.1", 0), True, cls.q, EchoHandler, True) def test_failure(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -140,10 +134,7 @@ class TestSSLv3Only(ServerTestBase): class TestSNI(ServerTestBase): @classmethod def makeserver(cls): - cls.q = Queue.Queue() - s = TServer(("127.0.0.1", 0), True, cls.q, SNIHandler) - cls.port = s.port - return s + return TServer(("127.0.0.1", 0), True, cls.q, SNIHandler) def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -155,10 +146,7 @@ class TestSNI(ServerTestBase): class TestSSLDisconnect(ServerTestBase): @classmethod def makeserver(cls): - cls.q = Queue.Queue() - s = TServer(("127.0.0.1", 0), True, cls.q, DisconnectHandler) - cls.port = s.port - return s + return TServer(("127.0.0.1", 0), True, cls.q, DisconnectHandler) def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) -- cgit v1.2.3 From ba7437abcbf3db11e227cae5e5c1d2df5975c77c Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 8 Jul 2012 23:50:38 +1200 Subject: Add an exception to indicate remote disconnects. --- netlib/tcp.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 0af3d463..281a0438 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -36,6 +36,8 @@ OP_TLS_ROLLBACK_BUG = SSL.OP_TLS_ROLLBACK_BUG class NetLibError(Exception): pass +class NetLibDisconnect(Exception): pass + class FileLike: def __init__(self, o): @@ -61,7 +63,10 @@ class FileLike: return result def write(self, v): - self.o.sendall(v) + try: + return self.o.sendall(v) + except SSL.SysCallError: + raise NetLibDisconnect() def readline(self, size = None): result = '' @@ -159,11 +164,15 @@ class BaseHandler: def finish(self): self.finished = True - if not getattr(self.wfile, "closed", False): - self.wfile.flush() - self.connection.close() - self.wfile.close() - self.rfile.close() + try: + if not getattr(self.wfile, "closed", False): + self.wfile.flush() + self.wfile.close() + self.rfile.close() + self.connection.close() + except socket.error: + # Remote has disconnected + pass def handle_sni(self, connection): """ -- cgit v1.2.3 From 721e2c8277123a99abf6299ee4703109c57675db Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 10 Jul 2012 16:22:45 +1200 Subject: Somewhat nicer handling of errors after thread termination. --- netlib/tcp.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 281a0438..53ad8a05 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -213,15 +213,8 @@ class TCPServer: self.handle_connection(request, client_address) request.close() except: - try: - self.handle_error(request, client_address) - request.close() - # Why a blanket except here? In some circumstances, a thread can - # persist until the interpreter exits. When this happens, all modules - # and builtins are set to None, and things balls up in indeterminate - # ways. - except: - pass + self.handle_error(request, client_address) + request.close() def serve_forever(self, poll_interval=0.1): self.__is_shut_down.clear() @@ -257,10 +250,14 @@ class TCPServer: """ Called when handle_connection raises an exception. """ - print >> fp, '-'*40 - print >> fp, "Error processing of request from %s:%s"%client_address - print >> fp, traceback.format_exc() - print >> fp, '-'*40 + # If a thread has persisted after interpreter exit, the module might be + # none. + if traceback: + exc = traceback.format_exc() + print >> fp, '-'*40 + print >> fp, "Error in processing of request from %s:%s"%client_address + print >> fp, exc + print >> fp, '-'*40 def handle_connection(self, request, client_address): # pragma: no cover """ -- cgit v1.2.3 From 4fdc2179e25926d531ea8c4a5d6fc78ce75cd6ff Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 10 Jul 2012 16:34:39 +1200 Subject: Don't write empty values. --- netlib/tcp.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 53ad8a05..6ba58d86 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -63,10 +63,11 @@ class FileLike: return result def write(self, v): - try: - return self.o.sendall(v) - except SSL.SysCallError: - raise NetLibDisconnect() + if v: + try: + return self.o.sendall(v) + except SSL.SysCallError: + raise NetLibDisconnect() def readline(self, size = None): result = '' -- cgit v1.2.3 From 1227369db31bff39707091f562b0ad946d14728a Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 11 Jul 2012 07:16:45 +1200 Subject: Signal errors back to caller in WSGI .serve() --- netlib/wsgi.py | 1 + 1 file changed, 1 insertion(+) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 6fe6b6b3..4fa2c537 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -131,6 +131,7 @@ class WSGIAdaptor: except Exception, v: try: s = traceback.format_exc() + errs.write(s) self.error_page(soc, state["headers_sent"], s) except Exception, v: # pragma: no cover pass # pragma: no cover -- cgit v1.2.3 From 9ab7842c81e8b34cd99d5f3e8e98282729d85344 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 11 Jul 2012 11:09:41 +0200 Subject: fix relative certdir --- netlib/certutils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index dcd54053..3effe610 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -43,8 +43,9 @@ def dummy_ca(path): os.makedirs(dirname) if path.endswith(".pem"): basename, _ = os.path.splitext(path) + basename = os.path.basename(basename) else: - basename = path + basename = os.path.basename(basename) key, ca = create_ca() -- cgit v1.2.3 From 63d789109a7ef0bb18e01fdf63851db86aef23bd Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 20 Jul 2012 14:43:51 +1200 Subject: close() methods for clients and servers. --- netlib/tcp.py | 34 +++++++++++++++++++++++++++++++--- test/test_tcp.py | 34 +++++++++++++++++++++++++++++++++- 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 6ba58d86..b7f2b3bc 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -66,7 +66,7 @@ class FileLike: if v: try: return self.o.sendall(v) - except SSL.SysCallError: + except SSL.Error: raise NetLibDisconnect() def readline(self, size = None): @@ -125,6 +125,20 @@ class TCPClient: raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) self.connection = connection + def close(self): + """ + Does a hard close of the socket, i.e. a shutdown, followed by a close. + """ + try: + if self.ssl_established: + self.connection.shutdown() + else: + self.connection.shutdown(socket.SHUT_RDWR) + self.connection.close() + except (socket.error, SSL.Error): + # Socket probably already closed + pass + class BaseHandler: """ @@ -170,7 +184,7 @@ class BaseHandler: self.wfile.flush() self.wfile.close() self.rfile.close() - self.connection.close() + self.close() except socket.error: # Remote has disconnected pass @@ -195,6 +209,20 @@ class BaseHandler: def handle(self): # pragma: no cover raise NotImplementedError + def close(self): + """ + Does a hard close of the socket, i.e. a shutdown, followed by a close. + """ + try: + if self.ssl_established: + self.connection.shutdown() + else: + self.connection.shutdown(socket.SHUT_RDWR) + self.connection.close() + except (socket.error, SSL.Error): + # Socket probably already closed + pass + class TCPServer: request_queue_size = 20 @@ -252,7 +280,7 @@ class TCPServer: Called when handle_connection raises an exception. """ # If a thread has persisted after interpreter exit, the module might be - # none. + # none. if traceback: exc = traceback.format_exc() print >> fp, '-'*40 diff --git a/test/test_tcp.py b/test/test_tcp.py index 359890d5..cb27c63b 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -54,7 +54,7 @@ class EchoHandler(tcp.BaseHandler): class DisconnectHandler(tcp.BaseHandler): def handle(self): - self.finish() + self.close() class TServer(tcp.TCPServer): @@ -102,6 +102,20 @@ class TestServer(ServerTestBase): assert c.rfile.readline() == testval +class TestDisconnect(ServerTestBase): + @classmethod + def makeserver(cls): + return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler) + + def test_echo(self): + testval = "echo!\n" + c = tcp.TCPClient("127.0.0.1", self.port) + c.connect() + c.wfile.write(testval) + c.wfile.flush() + assert c.rfile.readline() == testval + + class TestServerSSL(ServerTestBase): @classmethod def makeserver(cls): @@ -154,6 +168,24 @@ class TestSSLDisconnect(ServerTestBase): c.convert_to_ssl() # Excercise SSL.ZeroReturnError c.rfile.read(10) + c.close() + tutils.raises(tcp.NetLibDisconnect, c.wfile.write, "foo") + tutils.raises(Queue.Empty, self.q.get_nowait) + + +class TestDisconnect(ServerTestBase): + @classmethod + def makeserver(cls): + return TServer(("127.0.0.1", 0), False, cls.q, DisconnectHandler) + + def test_echo(self): + c = tcp.TCPClient("127.0.0.1", self.port) + c.connect() + # Excercise SSL.ZeroReturnError + c.rfile.read(10) + c.wfile.write("foo") + c.close() + c.close() class TestTCPClient: -- cgit v1.2.3 From a1a1663c0fc3a1e76637a0ef3997da697ea97cfe Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 20 Jul 2012 14:45:58 +1200 Subject: Fix cert path. --- netlib/certutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 3effe610..1f61132e 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -45,7 +45,7 @@ def dummy_ca(path): basename, _ = os.path.splitext(path) basename = os.path.basename(basename) else: - basename = os.path.basename(basename) + basename = os.path.basename(path) key, ca = create_ca() -- cgit v1.2.3 From ba53d2e4caa34df883a2cd6322d607426c97201b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 20 Jul 2012 15:15:07 +1200 Subject: Set ssl_established right after the connection object is changed. --- netlib/tcp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index b7f2b3bc..3aee4c74 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -102,6 +102,7 @@ class TCPClient: if clientcert: context.use_certificate_file(self.clientcert) self.connection = SSL.Connection(context, self.connection) + self.ssl_established = True if sni: self.connection.set_tlsext_host_name(sni) self.connection.set_connect_state() @@ -112,7 +113,6 @@ class TCPClient: self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) - self.ssl_established = True def connect(self): try: @@ -167,6 +167,7 @@ class BaseHandler: ctx.use_privatekey_file(key) ctx.use_certificate_file(cert) self.connection = SSL.Connection(ctx, self.connection) + self.ssl_established = True self.connection.set_accept_state() # SNI callback happens during do_handshake() try: @@ -175,7 +176,6 @@ class BaseHandler: raise NetLibError("SSL handshake error: %s"%str(v)) self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) - self.ssl_established = True def finish(self): self.finished = True -- cgit v1.2.3 From 2387d2e8ed7d94e42b1ac02a4ea73f54e4c63ab8 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 21 Jul 2012 16:10:54 +1200 Subject: Timeout for TCP clients. --- netlib/tcp.py | 36 ++++++++++++++++++++++++++++-------- test/test_tcp.py | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 9 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 3aee4c74..8771e789 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,4 +1,4 @@ -import select, socket, threading, traceback, sys +import select, socket, threading, traceback, sys, time from OpenSSL import SSL import certutils @@ -35,8 +35,8 @@ OP_TLS_ROLLBACK_BUG = SSL.OP_TLS_ROLLBACK_BUG class NetLibError(Exception): pass - class NetLibDisconnect(Exception): pass +class NetLibTimeout(Exception): pass class FileLike: @@ -47,15 +47,25 @@ class FileLike: return getattr(self.o, attr) def flush(self): - pass + if hasattr(self.o, "flush"): + self.o.flush() def read(self, length): result = '' + start = time.time() while length > 0: try: data = self.o.read(length) except (SSL.ZeroReturnError, SSL.SysCallError): break + except SSL.WantReadError: + if (time.time() - start) < self.o.gettimeout(): + time.sleep(0.1) + continue + else: + raise NetLibTimeout + except socket.timeout: + raise NetLibTimeout if not data: break result += data @@ -65,7 +75,11 @@ class FileLike: def write(self, v): if v: try: - return self.o.sendall(v) + if hasattr(self.o, "sendall"): + return self.o.sendall(v) + else: + r = self.o.write(v) + return r except SSL.Error: raise NetLibDisconnect() @@ -119,12 +133,18 @@ class TCPClient: addr = socket.gethostbyname(self.host) connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) connection.connect((addr, self.port)) - self.rfile = connection.makefile('rb', self.rbufsize) - self.wfile = connection.makefile('wb', self.wbufsize) + self.rfile = FileLike(connection.makefile('rb', self.rbufsize)) + self.wfile = FileLike(connection.makefile('wb', self.wbufsize)) except socket.error, err: raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) self.connection = connection + def settimeout(self, n): + self.connection.settimeout(n) + + def gettimeout(self): + self.connection.gettimeout() + def close(self): """ Does a hard close of the socket, i.e. a shutdown, followed by a close. @@ -148,8 +168,8 @@ class BaseHandler: wbufsize = -1 def __init__(self, connection, client_address, server): self.connection = connection - self.rfile = self.connection.makefile('rb', self.rbufsize) - self.wfile = self.connection.makefile('wb', self.wbufsize) + self.rfile = FileLike(self.connection.makefile('rb', self.rbufsize)) + self.wfile = FileLike(self.connection.makefile('wb', self.wbufsize)) self.client_address = client_address self.server = server diff --git a/test/test_tcp.py b/test/test_tcp.py index cb27c63b..d6235b01 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -1,4 +1,4 @@ -import cStringIO, threading, Queue +import cStringIO, threading, Queue, time from netlib import tcp, certutils import tutils @@ -57,6 +57,12 @@ class DisconnectHandler(tcp.BaseHandler): self.close() +class HangHandler(tcp.BaseHandler): + def handle(self): + while 1: + time.sleep(1) + + class TServer(tcp.TCPServer): def __init__(self, addr, ssl, q, handler, v3_only=False): tcp.TCPServer.__init__(self, addr) @@ -188,6 +194,31 @@ class TestDisconnect(ServerTestBase): c.close() +class TestTimeOut(ServerTestBase): + @classmethod + def makeserver(cls): + return TServer(("127.0.0.1", 0), False, cls.q, HangHandler) + + def test_timeout_client(self): + c = tcp.TCPClient("127.0.0.1", self.port) + c.connect() + c.settimeout(0.1) + tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) + + +class TestSSLTimeOut(ServerTestBase): + @classmethod + def makeserver(cls): + return TServer(("127.0.0.1", 0), True, cls.q, HangHandler) + + def test_timeout_client(self): + c = tcp.TCPClient("127.0.0.1", self.port) + c.connect() + c.convert_to_ssl() + c.settimeout(0.1) + tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) + + class TestTCPClient: def test_conerr(self): c = tcp.TCPClient("127.0.0.1", 0) -- cgit v1.2.3 From 29f907ecf98468a89b5a7575b539938dc6741a8e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 21 Jul 2012 17:27:23 +1200 Subject: Handle HTTP versions malformed due to non-integer major/minor numbers. --- netlib/http.py | 7 +++++-- test/test_http.py | 2 ++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index acd9d85e..88e66ce4 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -145,8 +145,11 @@ def parse_http_protocol(s): if "." not in version: return None major, minor = version.split('.') - major = int(major) - minor = int(minor) + try: + major = int(major) + minor = int(minor) + except ValueError: + return None return major, minor diff --git a/test/test_http.py b/test/test_http.py index 206fc4df..0174a4aa 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -106,6 +106,8 @@ def test_read_http_body(): def test_parse_http_protocol(): assert http.parse_http_protocol("HTTP/1.1") == (1, 1) assert http.parse_http_protocol("HTTP/0.0") == (0, 0) + assert not http.parse_http_protocol("HTTP/a.1") + assert not http.parse_http_protocol("HTTP/1.a") assert not http.parse_http_protocol("foo/0.0") assert not http.parse_http_protocol("HTTP/x") -- cgit v1.2.3 From b2c491fe3936b04b0c8e6349775bf53063c170a6 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 21 Jul 2012 17:50:21 +1200 Subject: Handle socket disconnects on reads. --- netlib/tcp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/netlib/tcp.py b/netlib/tcp.py index 8771e789..ac4fab95 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -66,6 +66,8 @@ class FileLike: raise NetLibTimeout except socket.timeout: raise NetLibTimeout + except socket.error: + raise NetLibDisconnect if not data: break result += data -- cgit v1.2.3 From 619f3c6edce50a6e83b817d43ee0357cc763dd3d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 21 Jul 2012 20:51:05 +1200 Subject: Handle unexpected SSL connection termination in readline. --- netlib/tcp.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index ac4fab95..a68b608b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -56,7 +56,7 @@ class FileLike: while length > 0: try: data = self.o.read(length) - except (SSL.ZeroReturnError, SSL.SysCallError): + except SSL.ZeroReturnError: break except SSL.WantReadError: if (time.time() - start) < self.o.gettimeout(): @@ -68,6 +68,8 @@ class FileLike: raise NetLibTimeout except socket.error: raise NetLibDisconnect + except SSL.SysCallError, v: + raise NetLibDisconnect if not data: break result += data @@ -82,7 +84,7 @@ class FileLike: else: r = self.o.write(v) return r - except SSL.Error: + except (SSL.Error, socket.error): raise NetLibDisconnect() def readline(self, size = None): @@ -91,7 +93,10 @@ class FileLike: while True: if size is not None and bytes_read >= size: break - ch = self.read(1) + try: + ch = self.read(1) + except NetLibDisconnect: + break bytes_read += 1 if not ch: break -- cgit v1.2.3 From ed64b0e79699681bd5db3ff2823c47a424fbc3e1 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 22 Jul 2012 12:35:16 +1200 Subject: Fix http_protocol parsing crash discovered with pathoc fuzzing. --- netlib/http.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index 88e66ce4..9d6db003 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -141,10 +141,10 @@ def parse_http_protocol(s): """ if not s.startswith("HTTP/"): return None - _, version = s.split('/') + _, version = s.split('/', 1) if "." not in version: return None - major, minor = version.split('.') + major, minor = version.split('.', 1) try: major = int(major) minor = int(minor) -- cgit v1.2.3 From eb88cea3c74a253d3a08d010bfd328aa845c6d5b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 23 Jul 2012 23:20:32 +1200 Subject: Catch an amazingly subtle SSL connection corruption bug. Closing a set of pseudo-file descriptors in the wrong order caused junk data to be written to the SSL stream. An apparent bug in OpenSSL then lets this corrupt the _next_ SSL connection. --- netlib/tcp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index a68b608b..66a26872 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -209,9 +209,9 @@ class BaseHandler: try: if not getattr(self.wfile, "closed", False): self.wfile.flush() + self.close() self.wfile.close() self.rfile.close() - self.close() except socket.error: # Remote has disconnected pass @@ -245,10 +245,10 @@ class BaseHandler: self.connection.shutdown() else: self.connection.shutdown(socket.SHUT_RDWR) - self.connection.close() - except (socket.error, SSL.Error): + except (socket.error, SSL.Error), v: # Socket probably already closed pass + self.connection.close() class TCPServer: -- cgit v1.2.3 From 91752990d5863526745e5c31cfb4b7459d11047e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 24 Jul 2012 11:39:49 +1200 Subject: Handle HTTP responses that have a body but no content-length or transfer encoding We check if the server sent a connection:close header, and read till the socket closes. Closes #2 --- netlib/http.py | 37 +++++++++++++++++++++++-------------- netlib/tcp.py | 11 ++++++++--- test/test_http.py | 23 ++++++++++++++++++++++- test/test_tcp.py | 6 ++++++ 4 files changed, 59 insertions(+), 18 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index 9d6db003..980d3f62 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -97,12 +97,21 @@ def read_chunked(code, fp, limit): return content -def has_chunked_encoding(headers): - for i in headers["transfer-encoding"]: +def get_header_tokens(headers, key): + """ + Retrieve all tokens for a header key. A number of different headers + follow a pattern where each header line can containe comma-separated + tokens, and headers can be set multiple times. + """ + toks = [] + for i in headers[key]: for j in i.split(","): - if j.lower() == "chunked": - return True - return False + toks.append(j.strip()) + return toks + + +def has_chunked_encoding(headers): + return "chunked" in [i.lower() for i in get_header_tokens(headers, "transfer-encoding")] def read_http_body(code, rfile, headers, all, limit): @@ -207,12 +216,11 @@ def request_connection_close(httpversion, headers): Checks the request to see if the client connection should be closed. """ if "connection" in headers: - for value in ",".join(headers['connection']).split(","): - value = value.strip() - if value == "close": - return True - elif value == "keep-alive": - return False + toks = get_header_tokens(headers, "connection") + if "close" in toks: + return True + elif "keep-alive" in toks: + return False # HTTP 1.1 connections are assumed to be persistent if httpversion == (1, 1): return False @@ -243,10 +251,11 @@ def read_http_body_request(rfile, wfile, headers, httpversion, limit): return read_http_body(400, rfile, headers, False, limit) -def read_http_body_response(rfile, headers, all, limit): +def read_http_body_response(rfile, headers, limit): """ Read the HTTP body from a server response. """ + all = "close" in get_header_tokens(headers, "connection") return read_http_body(500, rfile, headers, all, limit) @@ -267,7 +276,7 @@ def read_response(rfile, method, body_size_limit): proto, code, msg = parts httpversion = parse_http_protocol(proto) if httpversion is None: - raise HttpError(502, "Invalid HTTP version: %s"%repr(httpversion)) + raise HttpError(502, "Invalid HTTP version in line: %s"%repr(proto)) try: code = int(code) except ValueError: @@ -278,5 +287,5 @@ def read_response(rfile, method, body_size_limit): if method == "HEAD" or code == 204 or code == 304: content = "" else: - content = read_http_body_response(rfile, headers, False, body_size_limit) + content = read_http_body_response(rfile, headers, body_size_limit) return httpversion, code, msg, headers, content diff --git a/netlib/tcp.py b/netlib/tcp.py index 66a26872..7d3705da 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -40,6 +40,7 @@ class NetLibTimeout(Exception): pass class FileLike: + BLOCKSIZE = 1024 * 32 def __init__(self, o): self.o = o @@ -51,11 +52,14 @@ class FileLike: self.o.flush() def read(self, length): + """ + If length is None, we read until connection closes. + """ result = '' start = time.time() - while length > 0: + while length == -1 or length > 0: try: - data = self.o.read(length) + data = self.o.read(self.BLOCKSIZE if length == -1 else length) except SSL.ZeroReturnError: break except SSL.WantReadError: @@ -73,7 +77,8 @@ class FileLike: if not data: break result += data - length -= len(data) + if length != -1: + length -= len(data) return result def write(self, v): diff --git a/test/test_http.py b/test/test_http.py index 0174a4aa..0b83e65a 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -64,7 +64,28 @@ def test_read_http_body_response(): h = odict.ODictCaseless() h["content-length"] = [7] s = cStringIO.StringIO("testing") - assert http.read_http_body_response(s, h, False, None) == "testing" + assert http.read_http_body_response(s, h, None) == "testing" + + + h = odict.ODictCaseless() + s = cStringIO.StringIO("testing") + assert not http.read_http_body_response(s, h, None) + + h = odict.ODictCaseless() + h["connection"] = ["close"] + s = cStringIO.StringIO("testing") + assert http.read_http_body_response(s, h, None) == "testing" + + +def test_get_header_tokens(): + h = odict.ODictCaseless() + assert http.get_header_tokens(h, "foo") == [] + h["foo"] = ["bar"] + assert http.get_header_tokens(h, "foo") == ["bar"] + h["foo"] = ["bar, voing"] + assert http.get_header_tokens(h, "foo") == ["bar", "voing"] + h["foo"] = ["bar, voing", "oink"] + assert http.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] def test_read_http_body_request(): diff --git a/test/test_tcp.py b/test/test_tcp.py index d6235b01..67c56a37 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -239,3 +239,9 @@ class TestFileLike: s = cStringIO.StringIO("foobar\nfoobar") s = tcp.FileLike(s) assert s.readline(3) == "foo" + + def test_limitless(self): + s = cStringIO.StringIO("f"*(50*1024)) + s = tcp.FileLike(s) + ret = s.read(-1) + assert len(ret) == 50 * 1024 -- cgit v1.2.3 From 728ef107a00e7d6cef0c7d826f39a89197ddb732 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 24 Jul 2012 14:55:54 +1200 Subject: Ignore SAN entries that we don't understand. --- netlib/certutils.py | 6 +++++- test/data/text_cert_weird1 | 31 +++++++++++++++++++++++++++++++ test/test_certutils.py | 5 +++++ 3 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 test/data/text_cert_weird1 diff --git a/netlib/certutils.py b/netlib/certutils.py index 1f61132e..f55a096b 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -1,6 +1,7 @@ import os, ssl, hashlib, socket, time, datetime from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode +from pyasn1.error import PyAsn1Error import OpenSSL import tcp @@ -217,7 +218,10 @@ class SSLCert: for i in range(self.x509.get_extension_count()): ext = self.x509.get_extension(i) if ext.get_short_name() == "subjectAltName": - dec = decode(ext.get_data(), asn1Spec=_GeneralNames()) + try: + dec = decode(ext.get_data(), asn1Spec=_GeneralNames()) + except PyAsn1Error: + continue for i in dec[0]: altnames.append(i[0].asOctets()) return altnames diff --git a/test/data/text_cert_weird1 b/test/data/text_cert_weird1 new file mode 100644 index 00000000..72b09dcb --- /dev/null +++ b/test/data/text_cert_weird1 @@ -0,0 +1,31 @@ +-----BEGIN CERTIFICATE----- +MIIFNDCCBBygAwIBAgIEDFJFNzANBgkqhkiG9w0BAQUFADCBjDELMAkGA1UEBhMC +REUxHjAcBgNVBAoTFVVuaXZlcnNpdGFldCBNdWVuc3RlcjE6MDgGA1UEAxMxWmVy +dGlmaXppZXJ1bmdzc3RlbGxlIFVuaXZlcnNpdGFldCBNdWVuc3RlciAtIEcwMjEh +MB8GCSqGSIb3DQEJARYSY2FAdW5pLW11ZW5zdGVyLmRlMB4XDTA4MDUyMDEyNDQy +NFoXDTEzMDUxOTEyNDQyNFowezELMAkGA1UEBhMCREUxHjAcBgNVBAoTFVVuaXZl +cnNpdGFldCBNdWVuc3RlcjEuMCwGA1UECxMlWmVudHJ1bSBmdWVyIEluZm9ybWF0 +aW9uc3ZlcmFyYmVpdHVuZzEcMBoGA1UEAxMTd3d3LnVuaS1tdWVuc3Rlci5kZTCC +ASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMM0WlCj0ew+tyZ1GurBOqFn +AlChKk4S1F9oDzvp3FwOON4H8YFET7p9ZnoWtkfXSlGNMjekqy67dFlLt1sLusSo +tjNdaOrDLYmnGEgnYAT0RFBvErzIybJoD/Vu3NXyhes+L94R9mEMCwYXmSvG51H9 +c5CvguXBofMchDLCM/U6AYpwu3sST5orV3S1Rsa9sndj8sKJAcw195PYwl6EiEBb +M36ltDBlTYEUAg3Z+VSzB09J3U4vSvguVkDCz+szZh5RG3xlN9mlNfzhf4lHrNgV +0BRbKypa5Uuf81wbMcMMqTxKq+A9ysObpn9J3pNUym+Tn2oqHzGgvwZYB4tzXqUC +AwEAAaOCAawwggGoMAkGA1UdEwQCMAAwCwYDVR0PBAQDAgTwMBMGA1UdJQQMMAoG +CCsGAQUFBwMBMB0GA1UdDgQWBBQ3RFo8awewUTq5TpOFf3jOCEKihzAfBgNVHSME +GDAWgBS+nlGiyZJ8u2CL5rBoZHdaUhmhADAjBgNVHREEHDAagRh3d3dhZG1pbkB1 +bmktbXVlbnN0ZXIuZGUwewYDVR0fBHQwcjA3oDWgM4YxaHR0cDovL2NkcDEucGNh +LmRmbi5kZS93d3UtY2EvcHViL2NybC9nX2NhY3JsLmNybDA3oDWgM4YxaHR0cDov +L2NkcDIucGNhLmRmbi5kZS93d3UtY2EvcHViL2NybC9nX2NhY3JsLmNybDCBlgYI +KwYBBQUHAQEEgYkwgYYwQQYIKwYBBQUHMAKGNWh0dHA6Ly9jZHAxLnBjYS5kZm4u +ZGUvd3d1LWNhL3B1Yi9jYWNlcnQvZ19jYWNlcnQuY3J0MEEGCCsGAQUFBzAChjVo +dHRwOi8vY2RwMi5wY2EuZGZuLmRlL3d3dS1jYS9wdWIvY2FjZXJ0L2dfY2FjZXJ0 +LmNydDANBgkqhkiG9w0BAQUFAAOCAQEAFfNpagtcKUSDKss7TcqjYn99FQ4FtWjE +pGmzYL2zX2wsdCGoVQlGkieL9slbQVEUAnBuqM1LPzUNNe9kZpOPV3Rdhq4y8vyS +xkx3G1v5aGxfPUe8KM8yKIOHRqYefNronHJM0fw7KyjQ73xgbIEgkW+kNXaMLcrb +EPC36O2Zna8GP9FQxJRLgcfQCcYdRKGVn0EtRSkz2ym5Rbh/hrmJBbbC2yJGGMI0 +Vu5A9piK0EZPekZIUmhMQynD9QcMfWhTEFr7YZfx9ktxKDW4spnu7YrgICfZNcCm +tfxmnEAFt6a47u9P0w9lpY8+Sx9MNFfTePym+HP4TYha9bIBes+XnA== +-----END CERTIFICATE----- + diff --git a/test/test_certutils.py b/test/test_certutils.py index 8f95be67..9b8e7085 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -67,6 +67,11 @@ class TestSSLCert: assert c.to_pem() c.has_expired + def test_err_broken_sans(self): + c = certutils.SSLCert.from_pem(file(tutils.test_data.path("data/text_cert_weird1"), "r").read()) + # This breaks unless we ignore a decoding error. + c.altnames + def test_der(self): d = file(tutils.test_data.path("data/dercert")).read() s = certutils.SSLCert.from_der(d) -- cgit v1.2.3 From 4fb5d15f1480dd6ca86578aca2d0784bfef31dac Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 29 Jul 2012 15:53:42 +1200 Subject: Bump version. --- README | 1 + netlib/version.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/README b/README index 958a0302..972e03b8 100644 --- a/README +++ b/README @@ -8,3 +8,4 @@ servers are implemented to allow misbehaviour when needed. At this point, I have no plans to make netlib useful beyond mitmproxy and pathod. Please get in touch if you think parts of netlib might have broader utility. + diff --git a/netlib/version.py b/netlib/version.py index 1c4a4b66..20460ad5 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,4 @@ -IVERSION = (0, 1) +IVERSION = (0, 2) VERSION = ".".join(str(i) for i in IVERSION) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION -- cgit v1.2.3 From eafa5566c27ec321131a9d83d85dab512aae7a37 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 30 Jul 2012 11:30:31 +1200 Subject: Handle disconnects on flush. --- netlib/tcp.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 7d3705da..e7bc79a8 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -48,8 +48,11 @@ class FileLike: return getattr(self.o, attr) def flush(self): - if hasattr(self.o, "flush"): - self.o.flush() + try: + if hasattr(self.o, "flush"): + self.o.flush() + except socket.error, v: + raise NetLibDisconnect(str(v)) def read(self, length): """ -- cgit v1.2.3 From 1c21a28e6423edf3b903191610b45345720e0458 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 30 Jul 2012 12:50:35 +1200 Subject: read_headers: handle some crashes, return None on invalid data. --- netlib/http.py | 10 ++++++++-- test/test_http.py | 40 ++++++++++++++++++++++++++++------------ 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index 980d3f62..b71eb72d 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -36,8 +36,8 @@ def parse_url(url): def read_headers(fp): """ - Read a set of headers from a file pointer. Stop once a blank line - is reached. Return a ODictCaseless object. + Read a set of headers from a file pointer. Stop once a blank line is + reached. Return a ODictCaseless object, or None if headers are invalid. """ ret = [] name = '' @@ -46,6 +46,8 @@ def read_headers(fp): if not line or line == '\r\n' or line == '\n': break if line[0] in ' \t': + if not ret: + return None # continued header ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() else: @@ -55,6 +57,8 @@ def read_headers(fp): name = line[:i] value = line[i+1:].strip() ret.append([name, value]) + else: + return None return odict.ODictCaseless(ret) @@ -282,6 +286,8 @@ def read_response(rfile, method, body_size_limit): except ValueError: raise HttpError(502, "Invalid server response: %s"%repr(line)) headers = read_headers(rfile) + if headers is None: + raise HttpError(502, "Invalid headers.") if code >= 100 and code <= 199: return read_response(rfile, method, body_size_limit) if method == "HEAD" or code == 204 or code == 304: diff --git a/test/test_http.py b/test/test_http.py index 0b83e65a..a6161fbc 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -169,16 +169,20 @@ def test_parse_init_http(): class TestReadHeaders: + def _read(self, data, verbatim=False): + if not verbatim: + data = textwrap.dedent(data) + data = data.strip() + s = cStringIO.StringIO(data) + return http.read_headers(s) + def test_read_simple(self): data = """ Header: one Header2: two \r\n """ - data = textwrap.dedent(data) - data = data.strip() - s = cStringIO.StringIO(data) - h = http.read_headers(s) + h = self._read(data) assert h.lst == [["Header", "one"], ["Header2", "two"]] def test_read_multi(self): @@ -187,10 +191,7 @@ class TestReadHeaders: Header: two \r\n """ - data = textwrap.dedent(data) - data = data.strip() - s = cStringIO.StringIO(data) - h = http.read_headers(s) + h = self._read(data) assert h.lst == [["Header", "one"], ["Header", "two"]] def test_read_continued(self): @@ -200,12 +201,19 @@ class TestReadHeaders: Header2: three \r\n """ - data = textwrap.dedent(data) - data = data.strip() - s = cStringIO.StringIO(data) - h = http.read_headers(s) + h = self._read(data) assert h.lst == [["Header", "one\r\n two"], ["Header2", "three"]] + def test_read_continued_err(self): + data = "\tfoo: bar\r\n" + assert self._read(data, True) is None + + def test_read_err(self): + data = """ + foo + """ + assert self._read(data) is None + def test_read_response(): def tst(data, method, limit): @@ -248,6 +256,14 @@ def test_read_response(): assert tst(data, "GET", None)[4] == 'foo' assert tst(data, "HEAD", None)[4] == '' + data = """ + HTTP/1.1 200 OK + \tContent-Length: 3 + + foo + """ + tutils.raises("invalid headers", tst, data, "GET", None) + def test_parse_url(): assert not http.parse_url("") -- cgit v1.2.3 From 877a3e206263edbd8a973689b08f8c004de0225f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 18 Aug 2012 18:14:13 +1200 Subject: Add a get_first convenience function to ODict. --- netlib/odict.py | 6 ++++++ test/test_odict.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/netlib/odict.py b/netlib/odict.py index afc33caa..629fcade 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -80,6 +80,12 @@ class ODict: else: return d + def get_first(self, k, d=None): + if k in self: + return self[k][0] + else: + return d + def items(self): return self.lst[:] diff --git a/test/test_odict.py b/test/test_odict.py index e7453e2d..f27f6f8b 100644 --- a/test/test_odict.py +++ b/test/test_odict.py @@ -85,6 +85,12 @@ class TestODict: assert self.od.get("one") == ["two"] assert self.od.get("two") == None + def test_get_first(self): + self.od.add("one", "two") + self.od.add("one", "three") + assert self.od.get_first("one") == "two" + assert self.od.get_first("two") == None + class TestODictCaseless: def setUp(self): -- cgit v1.2.3 From 33557245bf2212c08cd645bcf21a73b773646607 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 23 Aug 2012 12:57:22 +1200 Subject: v0.2.1 --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/version.py b/netlib/version.py index 20460ad5..614b87a1 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,4 @@ -IVERSION = (0, 2) +IVERSION = (0, 2, 1) VERSION = ".".join(str(i) for i in IVERSION) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION -- cgit v1.2.3 From 1c80c2fdd7dd9873abc7b0a74936dab7beda7c5c Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 1 Sep 2012 23:04:44 +1200 Subject: Add a collection of standard User-Agent strings. These will be used in both mitmproxy and pathod. --- netlib/http_uastrings.py | 77 +++++++++++++++++++++++++++++++++++++++++++++ test/test_http_uastrings.py | 7 +++++ 2 files changed, 84 insertions(+) create mode 100644 netlib/http_uastrings.py create mode 100644 test/test_http_uastrings.py diff --git a/netlib/http_uastrings.py b/netlib/http_uastrings.py new file mode 100644 index 00000000..826c31a5 --- /dev/null +++ b/netlib/http_uastrings.py @@ -0,0 +1,77 @@ +""" + A small collection of useful user-agent header strings. These should be + kept reasonably current to reflect common usage. +""" + +# A collection of (name, shortcut, string) tuples. + +UASTRINGS = [ + ( + "android", + "a", + "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02" + ), + + ( + "blackberry", + "l", + "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+" + ), + + ( + "bingbot", + "b", + "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)" + ), + + ( + "chrome", + "c", + "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1" + ), + + ( + "firefox", + "f", + "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1" + ), + + ( + "googlebot", + "g", + "Googlebot/2.1 (+http://www.googlebot.com/bot.html)" + ), + + ( + "ie9", + "i", + "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US))" + ), + + ( + "ipad", + "p", + "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko ) Version/5.1 Mobile/9B176 Safari/7534.48.3" + ), + + ( + "iphone", + "h", + "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5", + ), + + ( + "safari", + "s", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10" + ) +] + + +def get_by_shortcut(s): + """ + Retrieve a user agent entry by shortcut. + """ + for i in UASTRINGS: + if s == i[1]: + return i diff --git a/test/test_http_uastrings.py b/test/test_http_uastrings.py new file mode 100644 index 00000000..c70b7048 --- /dev/null +++ b/test/test_http_uastrings.py @@ -0,0 +1,7 @@ +from netlib import http_uastrings + + +def test_get_shortcut(): + assert http_uastrings.get_by_shortcut("c")[0] == "chrome" + assert not http_uastrings.get_by_shortcut("_") + -- cgit v1.2.3 From 8a6cca530c5293aa2b77edd3bf928540ec771928 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 24 Sep 2012 10:47:41 +1200 Subject: Don't create fresh FileLike objects when converting to SSL --- netlib/tcp.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index e7bc79a8..0fed7380 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -44,6 +44,9 @@ class FileLike: def __init__(self, o): self.o = o + def set_descriptor(self, o): + self.o = o + def __getattr__(self, attr): return getattr(self.o, attr) @@ -140,8 +143,8 @@ class TCPClient: except SSL.Error, v: raise NetLibError("SSL handshake error: %s"%str(v)) self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) - self.rfile = FileLike(self.connection) - self.wfile = FileLike(self.connection) + self.rfile.set_descriptor(self.connection) + self.wfile.set_descriptor(self.connection) def connect(self): try: @@ -209,8 +212,8 @@ class BaseHandler: self.connection.do_handshake() except SSL.Error, v: raise NetLibError("SSL handshake error: %s"%str(v)) - self.rfile = FileLike(self.connection) - self.wfile = FileLike(self.connection) + self.rfile.set_descriptor(self.connection) + self.wfile.set_descriptor(self.connection) def finish(self): self.finished = True -- cgit v1.2.3 From 3a21e28bf13b5710639337fdc29741e9b6b71405 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 24 Sep 2012 11:10:21 +1200 Subject: Split FileLike into Writer and Reader, and add logging functionality. --- netlib/tcp.py | 69 +++++++++++++++++++++++++++++++++++++++++++------------- test/test_tcp.py | 34 +++++++++++++++++++++++++--- 2 files changed, 84 insertions(+), 19 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 0fed7380..e1318435 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -39,10 +39,11 @@ class NetLibDisconnect(Exception): pass class NetLibTimeout(Exception): pass -class FileLike: +class _FileLike: BLOCKSIZE = 1024 * 32 def __init__(self, o): self.o = o + self._log = None def set_descriptor(self, o): self.o = o @@ -50,6 +51,37 @@ class FileLike: def __getattr__(self, attr): return getattr(self.o, attr) + def start_log(self): + """ + Starts or resets the log. + + This will store all bytes read or written. + """ + self._log = [] + + def stop_log(self): + """ + Stops the log. + """ + self._log = None + + def is_logging(self): + return self._log is not None + + def get_log(self): + """ + Returns the log as a string. + """ + if not self.is_logging(): + raise ValueError("Not logging!") + return "".join(self._log) + + def add_log(self, v): + if self.is_logging(): + self._log.append(v) + + +class Writer(_FileLike): def flush(self): try: if hasattr(self.o, "flush"): @@ -57,6 +89,21 @@ class FileLike: except socket.error, v: raise NetLibDisconnect(str(v)) + def write(self, v): + if v: + try: + if hasattr(self.o, "sendall"): + self.add_log(v) + return self.o.sendall(v) + else: + r = self.o.write(v) + self.add_log(v[:r]) + return r + except (SSL.Error, socket.error): + raise NetLibDisconnect() + + +class Reader(_FileLike): def read(self, length): """ If length is None, we read until connection closes. @@ -85,19 +132,9 @@ class FileLike: result += data if length != -1: length -= len(data) + self.add_log(result) return result - def write(self, v): - if v: - try: - if hasattr(self.o, "sendall"): - return self.o.sendall(v) - else: - r = self.o.write(v) - return r - except (SSL.Error, socket.error): - raise NetLibDisconnect() - def readline(self, size = None): result = '' bytes_read = 0 @@ -151,8 +188,8 @@ class TCPClient: addr = socket.gethostbyname(self.host) connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) connection.connect((addr, self.port)) - self.rfile = FileLike(connection.makefile('rb', self.rbufsize)) - self.wfile = FileLike(connection.makefile('wb', self.wbufsize)) + self.rfile = Reader(connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except socket.error, err: raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) self.connection = connection @@ -186,8 +223,8 @@ class BaseHandler: wbufsize = -1 def __init__(self, connection, client_address, server): self.connection = connection - self.rfile = FileLike(self.connection.makefile('rb', self.rbufsize)) - self.wfile = FileLike(self.connection.makefile('wb', self.wbufsize)) + self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) self.client_address = client_address self.server = server diff --git a/test/test_tcp.py b/test/test_tcp.py index 67c56a37..9d581939 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -228,8 +228,8 @@ class TestTCPClient: class TestFileLike: def test_wrap(self): s = cStringIO.StringIO("foobar\nfoobar") - s = tcp.FileLike(s) s.flush() + s = tcp.Reader(s) assert s.readline() == "foobar\n" assert s.readline() == "foobar" # Test __getattr__ @@ -237,11 +237,39 @@ class TestFileLike: def test_limit(self): s = cStringIO.StringIO("foobar\nfoobar") - s = tcp.FileLike(s) + s = tcp.Reader(s) assert s.readline(3) == "foo" def test_limitless(self): s = cStringIO.StringIO("f"*(50*1024)) - s = tcp.FileLike(s) + s = tcp.Reader(s) ret = s.read(-1) assert len(ret) == 50 * 1024 + + def test_readlog(self): + s = cStringIO.StringIO("foobar\nfoobar") + s = tcp.Reader(s) + assert not s.is_logging() + s.start_log() + assert s.is_logging() + s.readline() + assert s.get_log() == "foobar\n" + s.read(1) + assert s.get_log() == "foobar\nf" + s.start_log() + assert s.get_log() == "" + s.read(1) + assert s.get_log() == "o" + s.stop_log() + tutils.raises(ValueError, s.get_log) + + def test_writelog(self): + s = cStringIO.StringIO() + s = tcp.Writer(s) + s.start_log() + assert s.is_logging() + s.write("x") + assert s.get_log() == "x" + s.write("x") + assert s.get_log() == "xx" + -- cgit v1.2.3 From b308824193342c11c88b8bad2645a5b09efcf48f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 24 Sep 2012 11:21:48 +1200 Subject: Create netlib.utils, move cleanBin and hexdump from libmproxy.utils. --- netlib/utils.py | 36 ++++++++++++++++++++++++++++++++++++ test/test_utils.py | 13 +++++++++++++ 2 files changed, 49 insertions(+) create mode 100644 netlib/utils.py create mode 100644 test/test_utils.py diff --git a/netlib/utils.py b/netlib/utils.py new file mode 100644 index 00000000..ea749545 --- /dev/null +++ b/netlib/utils.py @@ -0,0 +1,36 @@ + +def cleanBin(s, fixspacing=False): + """ + Cleans binary data to make it safe to display. If fixspacing is True, + tabs, newlines and so forth will be maintained, if not, they will be + replaced with a placeholder. + """ + parts = [] + for i in s: + o = ord(i) + if (o > 31 and o < 127): + parts.append(i) + elif i in "\n\r\t" and not fixspacing: + parts.append(i) + else: + parts.append(".") + return "".join(parts) + + +def hexdump(s): + """ + Returns a set of tuples: + (offset, hex, str) + """ + parts = [] + for i in range(0, len(s), 16): + o = "%.10x"%i + part = s[i:i+16] + x = " ".join("%.2x"%ord(i) for i in part) + if len(part) < 16: + x += " " + x += " ".join(" " for i in range(16 - len(part))) + parts.append( + (o, x, cleanBin(part, True)) + ) + return parts diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 00000000..61820a81 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,13 @@ +from netlib import utils + + +def test_hexdump(): + assert utils.hexdump("one\0"*10) + + +def test_cleanBin(): + assert utils.cleanBin("one") == "one" + assert utils.cleanBin("\00ne") == ".ne" + assert utils.cleanBin("\nne") == "\nne" + assert utils.cleanBin("\nne", True) == ".ne" + -- cgit v1.2.3 From 064b4c80018d9b76c2bedc010ab45c8b9ea7faa3 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 27 Sep 2012 10:59:46 +1200 Subject: Make cleanBin escape carriage returns. We get confusing output on terminals if we leave \r unescaped. --- netlib/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/utils.py b/netlib/utils.py index ea749545..7621a1dc 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -10,7 +10,7 @@ def cleanBin(s, fixspacing=False): o = ord(i) if (o > 31 and o < 127): parts.append(i) - elif i in "\n\r\t" and not fixspacing: + elif i in "\n\t" and not fixspacing: parts.append(i) else: parts.append(".") -- cgit v1.2.3 From 15679e010d99def2fb7efd1de5533099a12772ca Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 1 Oct 2012 11:30:02 +1300 Subject: Add a settimeout method to tcp.BaseHandler. --- netlib/tcp.py | 3 +++ test/test_tcp.py | 37 +++++++++++++++++++++++++++++++++---- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index e1318435..414c1237 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -284,6 +284,9 @@ class BaseHandler: def handle(self): # pragma: no cover raise NotImplementedError + def settimeout(self, n): + self.connection.settimeout(n) + def close(self): """ Does a hard close of the socket, i.e. a shutdown, followed by a close. diff --git a/test/test_tcp.py b/test/test_tcp.py index 9d581939..c833ce07 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -28,6 +28,11 @@ class ServerTestBase: cls.server.shutdown() + @property + def last_handler(self): + return self.server.server.last_handler + + class SNIHandler(tcp.BaseHandler): sni = None def handle_sni(self, connection): @@ -63,15 +68,27 @@ class HangHandler(tcp.BaseHandler): time.sleep(1) +class TimeoutHandler(tcp.BaseHandler): + def handle(self): + self.timeout = False + self.settimeout(0.01) + try: + self.rfile.read(10) + except tcp.NetLibTimeout: + self.timeout = True + + class TServer(tcp.TCPServer): - def __init__(self, addr, ssl, q, handler, v3_only=False): + def __init__(self, addr, ssl, q, handler_klass, v3_only=False): tcp.TCPServer.__init__(self, addr) self.ssl, self.q = ssl, q self.v3_only = v3_only - self.handler = handler + self.handler_klass = handler_klass + self.last_handler = None def handle_connection(self, request, client_address): - h = self.handler(request, client_address, self) + h = self.handler_klass(request, client_address, self) + self.last_handler = h if self.ssl: if self.v3_only: method = tcp.SSLv3_METHOD @@ -194,12 +211,24 @@ class TestDisconnect(ServerTestBase): c.close() +class TestServerTimeOut(ServerTestBase): + @classmethod + def makeserver(cls): + return TServer(("127.0.0.1", 0), False, cls.q, TimeoutHandler) + + def test_timeout(self): + c = tcp.TCPClient("127.0.0.1", self.port) + c.connect() + time.sleep(0.3) + assert self.last_handler.timeout + + class TestTimeOut(ServerTestBase): @classmethod def makeserver(cls): return TServer(("127.0.0.1", 0), False, cls.q, HangHandler) - def test_timeout_client(self): + def test_timeout(self): c = tcp.TCPClient("127.0.0.1", self.port) c.connect() c.settimeout(0.1) -- cgit v1.2.3 From 77869634e20ae5a2646d7455e499866e9cfafbab Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 9 Oct 2012 16:25:15 +1300 Subject: Limit reads to block length. --- netlib/tcp.py | 8 ++++++-- test/test_tcp.py | 11 +++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 414c1237..f8f877de 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -106,13 +106,17 @@ class Writer(_FileLike): class Reader(_FileLike): def read(self, length): """ - If length is None, we read until connection closes. + If length is -1, we read until connection closes. """ result = '' start = time.time() while length == -1 or length > 0: + if length == -1 or length > self.BLOCKSIZE: + rlen = self.BLOCKSIZE + else: + rlen = length try: - data = self.o.read(self.BLOCKSIZE if length == -1 else length) + data = self.o.read(rlen) except SSL.ZeroReturnError: break except SSL.WantReadError: diff --git a/test/test_tcp.py b/test/test_tcp.py index c833ce07..5a12da91 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -255,6 +255,17 @@ class TestTCPClient: class TestFileLike: + def test_blocksize(self): + s = cStringIO.StringIO("1234567890abcdefghijklmnopqrstuvwxyz") + s = tcp.Reader(s) + s.BLOCKSIZE = 2 + assert s.read(1) == "1" + assert s.read(2) == "23" + assert s.read(3) == "456" + assert s.read(4) == "7890" + d = s.read(-1) + assert d.startswith("abc") and d.endswith("xyz") + def test_wrap(self): s = cStringIO.StringIO("foobar\nfoobar") s.flush() -- cgit v1.2.3 From 6517d9e717883bc3cd0eb361e2aa0f58259cae60 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 14 Oct 2012 09:03:23 +1300 Subject: More info on disconnect exception. --- netlib/tcp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index f8f877de..7656e398 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -99,8 +99,8 @@ class Writer(_FileLike): r = self.o.write(v) self.add_log(v[:r]) return r - except (SSL.Error, socket.error): - raise NetLibDisconnect() + except (SSL.Error, socket.error), v: + raise NetLibDisconnect(str(v)) class Reader(_FileLike): -- cgit v1.2.3 From f8e10bd6ae1adba0897669bb8b90b9180150350a Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 31 Oct 2012 22:24:45 +1300 Subject: Bump version. --- README | 14 +++++--------- netlib/version.py | 2 +- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/README b/README index 972e03b8..58a61c50 100644 --- a/README +++ b/README @@ -1,11 +1,7 @@ -Netlib is a collection of network utility classes, used by pathod and mitmproxy -projects. It differs from other projects in some fundamental respects, because -both pathod and mitmproxy often need to violate standards. This means that -protocols are implemented as small, well-contained and flexible functions, and -servers are implemented to allow misbehaviour when needed. - -At this point, I have no plans to make netlib useful beyond mitmproxy and -pathod. Please get in touch if you think parts of netlib might have broader -utility. +Netlib is a collection of network utility classes, used by the pathod and +mitmproxy projects. It differs from other projects in some fundamental +respects, because both pathod and mitmproxy often need to violate standards. +This means that protocols are implemented as small, well-contained and flexible +functions, and are designed to allow misbehaviour when needed. diff --git a/netlib/version.py b/netlib/version.py index 614b87a1..30a4c0f9 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,4 @@ -IVERSION = (0, 2, 1) +IVERSION = (0, 2, 2) VERSION = ".".join(str(i) for i in IVERSION) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION -- cgit v1.2.3 From 043d05bcdeae482ca1d9b80375a1922e54896a6b Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 5 Dec 2012 04:03:39 +0100 Subject: add __iter__ for odict --- netlib/odict.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/netlib/odict.py b/netlib/odict.py index 629fcade..bddb3877 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -22,6 +22,9 @@ class ODict: def __eq__(self, other): return self.lst == other.lst + def __iter__(self): + return self.lst.__iter__() + def __getitem__(self, k): """ Returns a list of values matching key. -- cgit v1.2.3 From 082f398b8fdb8176c94271470df21f6e8f3faff6 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 3 Jan 2013 13:54:54 +1300 Subject: Add getcertnames, a tool for retrieving the CN and SANs from a remote server. --- tools/getcertnames | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100755 tools/getcertnames diff --git a/tools/getcertnames b/tools/getcertnames new file mode 100755 index 00000000..f39fc635 --- /dev/null +++ b/tools/getcertnames @@ -0,0 +1,16 @@ +#!/usr/bin/env python +import sys +sys.path.insert(0, "../../") +from netlib import certutils + +if len(sys.argv) > 2: + port = int(sys.argv[2]) +else: + port = 443 + +cert = certutils.get_remote_cert(sys.argv[1], port, None) +print "CN:", cert.cn +if cert.altnames: + print "SANs:", + for i in cert.altnames: + print "\t", i -- cgit v1.2.3 From ddc08efde1a5132734f1f06481a97e484cc368e3 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 4 Jan 2013 14:23:52 +1300 Subject: Minor cleanup of http.parse_init* methods. --- netlib/http.py | 39 ++++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index b71eb72d..3f730a1a 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -166,36 +166,43 @@ def parse_http_protocol(s): return major, minor -def parse_init_connect(line): +def parse_init(line): try: method, url, protocol = string.split(line) except ValueError: return None - if method != 'CONNECT': + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + return method, url, httpversion + + +def parse_init_connect(line): + v = parse_init(line) + if not v: + return None + method, url, httpversion = v + + if method.upper() != 'CONNECT': return None try: host, port = url.split(":") except ValueError: return None port = int(port) - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None return host, port, httpversion def parse_init_proxy(line): - try: - method, url, protocol = string.split(line) - except ValueError: + v = parse_init(line) + if not v: return None + method, url, httpversion = v + parts = parse_url(url) if not parts: return None scheme, host, port, path = parts - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None return method, scheme, host, port, path, httpversion @@ -203,15 +210,13 @@ def parse_init_http(line): """ Returns (method, url, httpversion) """ - try: - method, url, protocol = string.split(line) - except ValueError: + v = parse_init(line) + if not v: return None + method, url, httpversion = v + if not (url.startswith("/") or url == "*"): return None - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None return method, url, httpversion -- cgit v1.2.3 From d3b46feb6011c106b42d297b1a4807d187991345 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 5 Jan 2013 20:06:55 +1300 Subject: Handle non-integer port error in parse_init_connect correctly --- netlib/http.py | 5 ++++- test/test_http.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/netlib/http.py b/netlib/http.py index 3f730a1a..076baf87 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -189,7 +189,10 @@ def parse_init_connect(line): host, port = url.split(":") except ValueError: return None - port = int(port) + try: + port = int(port) + except ValueError: + return None return host, port, httpversion diff --git a/test/test_http.py b/test/test_http.py index a6161fbc..ed16fb4a 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -139,6 +139,7 @@ def test_parse_init_connect(): assert not http.parse_init_connect("GET host.com:443 HTTP/1.0") assert not http.parse_init_connect("CONNECT host.com443 HTTP/1.0") assert not http.parse_init_connect("CONNECT host.com:443 foo/1.0") + assert not http.parse_init_connect("CONNECT host.com:foo HTTP/1.0") def test_prase_init_proxy(): -- cgit v1.2.3 From a9a4064ff94abdddabc22789a9c32f0cb02c55cb Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 5 Jan 2013 20:08:48 +1300 Subject: Unit test for ODict.__iter__ --- test/test_odict.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/test_odict.py b/test/test_odict.py index f27f6f8b..d59ed67e 100644 --- a/test/test_odict.py +++ b/test/test_odict.py @@ -58,6 +58,11 @@ class TestODict: assert not self.od.in_any("one", "TWO") assert self.od.in_any("one", "TWO", True) + def test_iter(self): + assert not [i for i in self.od] + self.od.add("foo", 1) + assert [i for i in self.od] + def test_copy(self): self.od.add("foo", 1) self.od.add("foo", 2) -- cgit v1.2.3 From 72032d7fe75fae1bc1318cf0390e55af6a93ff4d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 6 Jan 2013 01:15:53 +1300 Subject: Basic certificate store implementation and cert utils API cleanup. --- netlib/certutils.py | 72 +++++++++++++++++++++++++++++++++++++------------- test/test_certutils.py | 46 +++++++++++++++++--------------- 2 files changed, 78 insertions(+), 40 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index f55a096b..51fd9da9 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -1,4 +1,4 @@ -import os, ssl, hashlib, socket, time, datetime +import os, ssl, hashlib, socket, time, datetime, tempfile, shutil from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error @@ -76,30 +76,24 @@ def dummy_ca(path): return True -def dummy_cert(certdir, ca, commonname, sans): +def dummy_cert(fp, ca, commonname, sans): """ - certdir: Certificate directory. + Generates and writes a certificate to fp. + ca: Path to the certificate authority file, or None. commonname: Common name for the generated certificate. + sans: A list of Subject Alternate Names. Returns cert path if operation succeeded, None if not. """ - namehash = hashlib.sha256(commonname).hexdigest() - certpath = os.path.join(certdir, namehash + ".pem") - if os.path.exists(certpath): - return certpath - ss = [] for i in sans: ss.append("DNS: %s"%i) ss = ", ".join(ss) - if ca: - raw = file(ca, "r").read() - ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) - key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) - else: - key, ca = create_ca() + raw = file(ca, "r").read() + ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) + key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) req = OpenSSL.crypto.X509Req() subj = req.get_subject() @@ -110,7 +104,7 @@ def dummy_cert(certdir, ca, commonname, sans): req.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) cert = OpenSSL.crypto.X509() - cert.gmtime_adj_notBefore(0) + cert.gmtime_adj_notBefore() cert.gmtime_adj_notAfter(60 * 60 * 24 * 30) cert.set_issuer(ca.get_subject()) cert.set_subject(req.get_subject()) @@ -120,11 +114,51 @@ def dummy_cert(certdir, ca, commonname, sans): cert.set_pubkey(req.get_pubkey()) cert.sign(key, "sha1") - f = open(certpath, "w") - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)) - f.close() + fp.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)) + fp.close() - return certpath + + +class CertStore: + """ + Implements an on-disk certificate store. + """ + def __init__(self, certdir=None): + """ + certdir: The certificate store directory. If None, a temporary + directory will be created, and destroyed when the .cleanup() method + is called. + """ + if certdir: + self.remove = False + self.certdir = certdir + else: + self.remove = True + self.certdir = tempfile.mkdtemp(prefix="certstore") + + def get_cert(self, commonname, sans, cacert=False): + """ + Returns the path to a certificate. + + commonname: Common name for the generated certificate. Must be a + valid, plain-ASCII, IDNA-encoded domain name. + + sans: A list of Subject Alternate Names. + + cacert: An optional path to a CA certificate. If specified, the + cert is created if it does not exist, else return None. + """ + certpath = os.path.join(self.certdir, commonname + ".pem") + if os.path.exists(certpath): + return certpath + elif cacert: + f = open(certpath, "w") + dummy_cert(f, cacert, commonname, sans) + return certpath + + def cleanup(self): + if self.remove: + shutil.rmtree(self.certdir) class _GeneralName(univ.Choice): diff --git a/test/test_certutils.py b/test/test_certutils.py index 9b8e7085..9b917dc6 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -16,36 +16,40 @@ def test_dummy_ca(): assert os.path.exists(os.path.join(d, "foo/cert2-cert.p12")) +class TestCertStore: + def test_create_explicit(self): + with tutils.tmpdir() as d: + ca = os.path.join(d, "ca") + assert certutils.dummy_ca(ca) + c = certutils.CertStore(d) + c.cleanup() + assert os.path.exists(d) + + def test_create_tmp(self): + with tutils.tmpdir() as d: + ca = os.path.join(d, "ca") + assert certutils.dummy_ca(ca) + c = certutils.CertStore() + assert not c.get_cert("foo.com", []) + assert c.get_cert("foo.com", [], ca) + assert c.get_cert("foo.com", [], ca) + c.cleanup() + + class TestDummyCert: def test_with_ca(self): with tutils.tmpdir() as d: - cacert = os.path.join(d, "foo/cert.cnf") + cacert = os.path.join(d, "cacert") assert certutils.dummy_ca(cacert) - p = certutils.dummy_cert( - os.path.join(d, "foo"), + p = os.path.join(d, "foo") + certutils.dummy_cert( + file(p, "w"), cacert, "foo.com", ["one.com", "two.com", "*.three.com"] ) - assert os.path.exists(p) - - # Short-circuit - assert certutils.dummy_cert( - os.path.join(d, "foo"), - cacert, - "foo.com", - [] - ) + assert file(p).read() - def test_no_ca(self): - with tutils.tmpdir() as d: - p = certutils.dummy_cert( - d, - None, - "foo.com", - [] - ) - assert os.path.exists(p) class TestSSLCert: -- cgit v1.2.3 From 91834ea78f36e1e89d4f19ecdddef83b0286b4d4 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 6 Jan 2013 01:16:58 +1300 Subject: Generate certificates with a commencement date an hour in the past. This helps smooth over small discrepancies in client and server times, where it's possible for a certificate to seem to be "in the future" to the client. --- netlib/certutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 51fd9da9..87d9d5d8 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -104,7 +104,7 @@ def dummy_cert(fp, ca, commonname, sans): req.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) cert = OpenSSL.crypto.X509() - cert.gmtime_adj_notBefore() + cert.gmtime_adj_notBefore(-3600) cert.gmtime_adj_notAfter(60 * 60 * 24 * 30) cert.set_issuer(ca.get_subject()) cert.set_subject(req.get_subject()) -- cgit v1.2.3 From e4acace8ea741af798523d6ff1d148d129f23582 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 6 Jan 2013 01:34:39 +1300 Subject: Sanity-check certstore common names. --- netlib/certutils.py | 16 ++++++++++++++++ test/test_certutils.py | 9 +++++++++ 2 files changed, 25 insertions(+) diff --git a/netlib/certutils.py b/netlib/certutils.py index 87d9d5d8..3fd57b2b 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -136,6 +136,18 @@ class CertStore: self.remove = True self.certdir = tempfile.mkdtemp(prefix="certstore") + def check_domain(self, commonname): + try: + commonname.decode("idna") + commonname.decode("ascii") + except: + return False + if ".." in commonname: + return False + if "/" in commonname: + return False + return True + def get_cert(self, commonname, sans, cacert=False): """ Returns the path to a certificate. @@ -147,7 +159,11 @@ class CertStore: cacert: An optional path to a CA certificate. If specified, the cert is created if it does not exist, else return None. + + Return None if the certificate could not be found or generated. """ + if not self.check_domain(commonname): + return None certpath = os.path.join(self.certdir, commonname + ".pem") if os.path.exists(certpath): return certpath diff --git a/test/test_certutils.py b/test/test_certutils.py index 9b917dc6..582fb9c4 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -35,6 +35,15 @@ class TestCertStore: assert c.get_cert("foo.com", [], ca) c.cleanup() + def test_check_domain(self): + c = certutils.CertStore() + assert c.check_domain("foo") + assert c.check_domain("\x01foo") + assert not c.check_domain("\xfefoo") + assert not c.check_domain("xn--\0") + assert not c.check_domain("foo..foo") + assert not c.check_domain("foo/foo") + class TestDummyCert: def test_with_ca(self): -- cgit v1.2.3 From 10457e876ad6db9c66973c925b7e65f2a16ffbca Mon Sep 17 00:00:00 2001 From: Israel Nir Date: Thu, 10 Jan 2013 15:51:37 +0200 Subject: adding read timestamp to enable better resolution of when certain reads were performed (timestamp is updated when the first byte is available on the network) --- netlib/tcp.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/netlib/tcp.py b/netlib/tcp.py index 7656e398..76fb7ca0 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -44,6 +44,7 @@ class _FileLike: def __init__(self, o): self.o = o self._log = None + self.timestamp = None def set_descriptor(self, o): self.o = o @@ -80,6 +81,8 @@ class _FileLike: if self.is_logging(): self._log.append(v) + def reset_timestamp(self): + self.timestamp = None class Writer(_FileLike): def flush(self): @@ -131,6 +134,7 @@ class Reader(_FileLike): raise NetLibDisconnect except SSL.SysCallError, v: raise NetLibDisconnect + self.timestamp = self.timestamp or time.time() if not data: break result += data -- cgit v1.2.3 From 04048b4c73f477f11d41788366eddffaae6bbb20 Mon Sep 17 00:00:00 2001 From: Rouli Date: Wed, 16 Jan 2013 22:30:19 +0200 Subject: renaming the timestamp in preparation of other timestamps that will be added later, adding tests --- netlib/tcp.py | 8 ++++---- test/test_tcp.py | 24 ++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 76fb7ca0..9c5cfa64 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -44,7 +44,7 @@ class _FileLike: def __init__(self, o): self.o = o self._log = None - self.timestamp = None + self.first_byte_timestamp = None def set_descriptor(self, o): self.o = o @@ -81,8 +81,8 @@ class _FileLike: if self.is_logging(): self._log.append(v) - def reset_timestamp(self): - self.timestamp = None + def reset_timestamps(self): + self.first_byte_timestamp = None class Writer(_FileLike): def flush(self): @@ -134,7 +134,7 @@ class Reader(_FileLike): raise NetLibDisconnect except SSL.SysCallError, v: raise NetLibDisconnect - self.timestamp = self.timestamp or time.time() + self.first_byte_timestamp = self.first_byte_timestamp or time.time() if not data: break result += data diff --git a/test/test_tcp.py b/test/test_tcp.py index 5a12da91..d27a678a 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -313,3 +313,27 @@ class TestFileLike: s.write("x") assert s.get_log() == "xx" + def test_reset_timestamps(self): + s = cStringIO.StringIO("foobar\nfoobar") + s = tcp.Reader(s) + s.first_byte_timestamp = 500 + s.reset_timestamps() + assert not s.first_byte_timestamp + + def test_first_byte_timestamp_updated_on_read(self): + s = cStringIO.StringIO("foobar\nfoobar") + s = tcp.Reader(s) + s.read(1) + assert s.first_byte_timestamp + expected = s.first_byte_timestamp + s.read(5) + assert s.first_byte_timestamp == expected + + def test_first_byte_timestamp_updated_on_readline(self): + s = cStringIO.StringIO("foobar\nfoobar\nfoobar") + s = tcp.Reader(s) + s.readline() + assert s.first_byte_timestamp + expected = s.first_byte_timestamp + s.readline() + assert s.first_byte_timestamp == expected -- cgit v1.2.3 From 1499529e62e6d2892a6908472398854094af89fb Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 18 Jan 2013 17:07:35 +1300 Subject: Fix client cert typo. --- netlib/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 9c5cfa64..afb7e059 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -177,7 +177,7 @@ class TCPClient: if not options is None: ctx.set_options(options) if clientcert: - context.use_certificate_file(self.clientcert) + context.use_certificate_file(clientcert) self.connection = SSL.Connection(context, self.connection) self.ssl_established = True if sni: -- cgit v1.2.3 From 00d20abdd4863d15fdda826615dab264c8e14d4a Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 20 Jan 2013 22:13:38 +1300 Subject: Beef up client certificate handling substantially. --- netlib/certutils.py | 6 +++--- netlib/tcp.py | 10 +++++++++- test/data/clientcert/.gitignore | 3 +++ test/data/clientcert/client.cnf | 5 +++++ test/data/clientcert/client.pem | 42 +++++++++++++++++++++++++++++++++++++++++ test/data/clientcert/make | 8 ++++++++ test/test_tcp.py | 22 +++++++++++++++++++++ 7 files changed, 92 insertions(+), 4 deletions(-) create mode 100644 test/data/clientcert/.gitignore create mode 100644 test/data/clientcert/client.cnf create mode 100644 test/data/clientcert/client.pem create mode 100755 test/data/clientcert/make diff --git a/netlib/certutils.py b/netlib/certutils.py index 3fd57b2b..e1407936 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -256,11 +256,11 @@ class SSLCert: @property def cn(self): - cn = None + c = None for i in self.subject: if i[0] == "CN": - cn = i[1] - return cn + c = i[1] + return c @property def altnames(self): diff --git a/netlib/tcp.py b/netlib/tcp.py index afb7e059..4b547d1f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -173,10 +173,14 @@ class TCPClient: self.ssl_established = False def convert_to_ssl(self, clientcert=None, sni=None, method=TLSv1_METHOD, options=None): + """ + clientcert: Path to a file containing both client cert and private key. + """ context = SSL.Context(method) if not options is None: ctx.set_options(options) if clientcert: + context.use_privatekey_file(clientcert) context.use_certificate_file(clientcert) self.connection = SSL.Connection(context, self.connection) self.ssl_established = True @@ -238,6 +242,7 @@ class BaseHandler: self.server = server self.finished = False self.ssl_established = False + self.clientcert = None def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None): """ @@ -246,13 +251,16 @@ class BaseHandler: ctx = SSL.Context(method) if not options is None: ctx.set_options(options) + # SNI callback happens during do_handshake() ctx.set_tlsext_servername_callback(self.handle_sni) ctx.use_privatekey_file(key) ctx.use_certificate_file(cert) + def ver(*args): + self.clientcert = certutils.SSLCert(args[1]) + ctx.set_verify(SSL.VERIFY_PEER, ver) self.connection = SSL.Connection(ctx, self.connection) self.ssl_established = True self.connection.set_accept_state() - # SNI callback happens during do_handshake() try: self.connection.do_handshake() except SSL.Error, v: diff --git a/test/data/clientcert/.gitignore b/test/data/clientcert/.gitignore new file mode 100644 index 00000000..07bc53d2 --- /dev/null +++ b/test/data/clientcert/.gitignore @@ -0,0 +1,3 @@ +client.crt +client.key +client.req diff --git a/test/data/clientcert/client.cnf b/test/data/clientcert/client.cnf new file mode 100644 index 00000000..5046a944 --- /dev/null +++ b/test/data/clientcert/client.cnf @@ -0,0 +1,5 @@ +[ ssl_client ] +basicConstraints = CA:FALSE +nsCertType = client +keyUsage = digitalSignature, keyEncipherment +extendedKeyUsage = clientAuth diff --git a/test/data/clientcert/client.pem b/test/data/clientcert/client.pem new file mode 100644 index 00000000..4927bca2 --- /dev/null +++ b/test/data/clientcert/client.pem @@ -0,0 +1,42 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAzCpoRjSTfIN24kkNap/GYmP9zVWj0Gk8R5BB/PvvN0OB1Zk0 +EEYPsWCcuhEdK0ehiDZX030doF0DOncKKa6mop/d0x2o+ts42peDhZM6JNUrm6d+ +ZWQVtio33mpp77UMhR093vaA+ExDnmE26kBTVijJ1+fRAVDXG/cmQINEri91Kk/G +3YJ5e45UrohGI5seBZ4vV0xbHtmczFRhYFlGOvYsoIe4Lvz/eFS2pIrTIpYQ2VM/ +SQQl+JFy+NlQRsWG2NrxtKOzMnnDE7YN4I3z5D5eZFo1EtwZ48LNCeSwrEOdfuzP +G5q5qbs5KpE/x85H9umuRwSCIArbMwBYV8a8JwIDAQABAoIBAFE3FV/IDltbmHEP +iky93hbJm+6QgKepFReKpRVTyqb7LaygUvueQyPWQMIriKTsy675nxo8DQr7tQsO +y3YlSZgra/xNMikIB6e82c7K8DgyrDQw/rCqjZB3Xt4VCqsWJDLXnQMSn98lx0g7 +d7Lbf8soUpKWXqfdVpSDTi4fibSX6kshXyfSTpcz4AdoncEpViUfU1xkEEmZrjT8 +1GcCsDC41xdNmzCpqRuZX7DKSFRoB+0hUzsC1oiqM7FD5kixonRd4F5PbRXImIzt +6YCsT2okxTA04jX7yByis7LlOLTlkmLtKQYuc3erOFvwx89s4vW+AeFei+GGNitn +tHfSwbECgYEA7SzV+nN62hAERHlg8cEQT4TxnsWvbronYWcc/ev44eHSPDWL5tPi +GHfSbW6YAq5Wa0I9jMWfXyhOYEC3MZTC5EEeLOB71qVrTwcy/sY66rOrcgjFI76Q +5JFHQ4wy3SWU50KxE0oWJO9LIowprG+pW1vzqC3VF0T7q0FqESrY4LUCgYEA3F7Z +80ndnCUlooJAb+Hfotv7peFf1o6+m1PTRcz1lLnVt5R5lXj86kn+tXEpYZo1RiGR +2rE2N0seeznWCooakHcsBN7/qmFIhhooJNF7yW+JP2I4P2UV5+tJ+8bcs/voUkQD +1x+rGOuMn8nvHBd2+Vharft8eGL2mgooPVI2XusCgYEAlMZpO3+w8pTVeHaDP2MR +7i/AuQ3cbCLNjSX3Y7jgGCFllWspZRRIYXzYPNkA9b2SbBnTLjjRLgnEkFBIGgvs +7O2EFjaCuDRvydUEQhjq4ErwIsopj7B8h0QyZcbOKTbn3uFQ3n68wVJx2Sv/ADHT +FIHrp/WIE96r19Niy34LKXkCgYB2W59VsuOKnMz01l5DeR5C+0HSWxS9SReIl2IO +yEFSKullWyJeLIgyUaGy0990430feKI8whcrZXYumuah7IDN/KOwzhCk8vEfzWao +N7bzfqtJVrh9HA7C7DVlO+6H4JFrtcoWPZUIomJ549w/yz6EN3ckoMC+a/Ck1TW9 +ka1QFwKBgQCywG6TrZz0UmOjyLQZ+8Q4uvZklSW5NAKBkNnyuQ2kd5rzyYgMPE8C +Er8T88fdVIKvkhDyHhwcI7n58xE5Gr7wkwsrk/Hbd9/ZB2GgAPY3cATskK1v1McU +YeX38CU0fUS4aoy26hWQXkViB47IGQ3jWo3ZCtzIJl8DI9/RsBWTnw== +-----END RSA PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIICYDCCAckCAQEwDQYJKoZIhvcNAQEFBQAwKDESMBAGA1UEAxMJbWl0bXByb3h5 +MRIwEAYDVQQKEwltaXRtcHJveHkwHhcNMTMwMTIwMDEwODEzWhcNMTUxMDE3MDEw +ODEzWjBFMQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UE +ChMYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOC +AQ8AMIIBCgKCAQEAzCpoRjSTfIN24kkNap/GYmP9zVWj0Gk8R5BB/PvvN0OB1Zk0 +EEYPsWCcuhEdK0ehiDZX030doF0DOncKKa6mop/d0x2o+ts42peDhZM6JNUrm6d+ +ZWQVtio33mpp77UMhR093vaA+ExDnmE26kBTVijJ1+fRAVDXG/cmQINEri91Kk/G +3YJ5e45UrohGI5seBZ4vV0xbHtmczFRhYFlGOvYsoIe4Lvz/eFS2pIrTIpYQ2VM/ +SQQl+JFy+NlQRsWG2NrxtKOzMnnDE7YN4I3z5D5eZFo1EtwZ48LNCeSwrEOdfuzP +G5q5qbs5KpE/x85H9umuRwSCIArbMwBYV8a8JwIDAQABMA0GCSqGSIb3DQEBBQUA +A4GBAFvI+cd47B85PQ970n2dU/PlA2/Hb1ldrrXh2guR4hX6vYx/uuk5yRI/n0Rd +KOXJ3czO0bd2Fpe3ZoNpkW0pOSDej/Q+58ScuJd0gWCT/Sh1eRk6ZdC0kusOuWoY +bPOPMkG45LPgUMFOnZEsfJP6P5mZIxlbCvSMFC25nPHWlct7 +-----END CERTIFICATE----- diff --git a/test/data/clientcert/make b/test/data/clientcert/make new file mode 100755 index 00000000..d1caea81 --- /dev/null +++ b/test/data/clientcert/make @@ -0,0 +1,8 @@ +#!/bin/sh + +openssl genrsa -out client.key 2048 +openssl req -key client.key -new -out client.req +openssl x509 -req -days 365 -in client.req -signkey client.key -out client.crt -extfile client.cnf -extensions ssl_client +openssl x509 -req -days 1000 -in client.req -CA ~/.mitmproxy/mitmproxy-ca.pem -CAkey ~/.mitmproxy/mitmproxy-ca.pem -set_serial 00001 -out client.crt -extensions ssl_client +cat client.key client.crt > client.pem +openssl x509 -text -noout -in client.pem diff --git a/test/test_tcp.py b/test/test_tcp.py index d27a678a..034e43b9 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -57,6 +57,16 @@ class EchoHandler(tcp.BaseHandler): self.wfile.flush() +class CertHandler(tcp.BaseHandler): + sni = None + def handle_sni(self, connection): + self.sni = connection.get_servername() + + def handle(self): + self.wfile.write("%s\n"%self.clientcert.serial) + self.wfile.flush() + + class DisconnectHandler(tcp.BaseHandler): def handle(self): self.close() @@ -168,6 +178,18 @@ class TestSSLv3Only(ServerTestBase): tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com", method=tcp.TLSv1_METHOD) +class TestSSLClientCert(ServerTestBase): + @classmethod + def makeserver(cls): + return TServer(("127.0.0.1", 0), True, cls.q, CertHandler) + + def test_clientcert(self): + c = tcp.TCPClient("127.0.0.1", self.port) + c.connect() + c.convert_to_ssl(clientcert=tutils.test_data.path("data/clientcert/client.pem")) + assert c.rfile.readline().strip() == "1" + + class TestSNI(ServerTestBase): @classmethod def makeserver(cls): -- cgit v1.2.3 From 7248a22d5e381dd57d69c06f8e67e60fd55e55ba Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 20 Jan 2013 22:36:54 +1300 Subject: Improve error signalling for client certificates. --- netlib/tcp.py | 9 ++++++--- test/test_tcp.py | 9 +++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 4b547d1f..d0ca09f3 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -177,11 +177,14 @@ class TCPClient: clientcert: Path to a file containing both client cert and private key. """ context = SSL.Context(method) - if not options is None: + if options is not None: ctx.set_options(options) if clientcert: - context.use_privatekey_file(clientcert) - context.use_certificate_file(clientcert) + try: + context.use_privatekey_file(clientcert) + context.use_certificate_file(clientcert) + except SSL.Error, v: + raise NetLibError("SSL client certificate error: %s"%str(v)) self.connection = SSL.Connection(context, self.connection) self.ssl_established = True if sni: diff --git a/test/test_tcp.py b/test/test_tcp.py index 034e43b9..0417aa21 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -189,6 +189,15 @@ class TestSSLClientCert(ServerTestBase): c.convert_to_ssl(clientcert=tutils.test_data.path("data/clientcert/client.pem")) assert c.rfile.readline().strip() == "1" + def test_clientcert_err(self): + c = tcp.TCPClient("127.0.0.1", self.port) + c.connect() + tutils.raises( + tcp.NetLibError, + c.convert_to_ssl, + clientcert=tutils.test_data.path("data/clientcert/make") + ) + class TestSNI(ServerTestBase): @classmethod -- cgit v1.2.3 From 2eb6651e5180035cd3e17f9048b16ea38719a9ac Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 25 Jan 2013 15:54:41 +1300 Subject: Extract TCP test utilities into netlib.test --- netlib/tcp.py | 11 ++-- netlib/test.py | 67 +++++++++++++++++++++ test/test_certutils.py | 1 + test/test_tcp.py | 159 ++++++++++++++++++++++--------------------------- 4 files changed, 146 insertions(+), 92 deletions(-) create mode 100644 netlib/test.py diff --git a/netlib/tcp.py b/netlib/tcp.py index d0ca09f3..56cc0dea 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,4 +1,4 @@ -import select, socket, threading, traceback, sys, time +import select, socket, threading, sys, time, traceback from OpenSSL import SSL import certutils @@ -84,13 +84,14 @@ class _FileLike: def reset_timestamps(self): self.first_byte_timestamp = None + class Writer(_FileLike): def flush(self): - try: - if hasattr(self.o, "flush"): + if hasattr(self.o, "flush"): + try: self.o.flush() - except socket.error, v: - raise NetLibDisconnect(str(v)) + except socket.error, v: + raise NetLibDisconnect(str(v)) def write(self, v): if v: diff --git a/netlib/test.py b/netlib/test.py new file mode 100644 index 00000000..2f72f979 --- /dev/null +++ b/netlib/test.py @@ -0,0 +1,67 @@ +import threading, Queue, cStringIO +import tcp + +class ServerThread(threading.Thread): + def __init__(self, server): + self.server = server + threading.Thread.__init__(self) + + def run(self): + self.server.serve_forever() + + def shutdown(self): + self.server.shutdown() + + +class ServerTestBase: + @classmethod + def setupAll(cls): + cls.q = Queue.Queue() + s = cls.makeserver() + cls.port = s.port + cls.server = ServerThread(s) + cls.server.start() + + @classmethod + def teardownAll(cls): + cls.server.shutdown() + + + @property + def last_handler(self): + return self.server.server.last_handler + + +class TServer(tcp.TCPServer): + def __init__(self, ssl, q, handler_klass, addr=("127.0.0.1", 0)): + """ + ssl: A {cert, key, v3_only} dict. + """ + tcp.TCPServer.__init__(self, addr) + self.ssl, self.q = ssl, q + self.handler_klass = handler_klass + self.last_handler = None + + def handle_connection(self, request, client_address): + h = self.handler_klass(request, client_address, self) + self.last_handler = h + if self.ssl: + if self.ssl["v3_only"]: + method = tcp.SSLv3_METHOD + options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1 + else: + method = tcp.SSLv23_METHOD + options = None + h.convert_to_ssl( + self.ssl["cert"], + self.ssl["key"], + method = method, + options = options, + ) + h.handle() + h.finish() + + def handle_error(self, request, client_address): + s = cStringIO.StringIO() + tcp.TCPServer.handle_error(self, request, client_address, s) + self.q.put(s.getvalue()) diff --git a/test/test_certutils.py b/test/test_certutils.py index 582fb9c4..334a6be4 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -30,6 +30,7 @@ class TestCertStore: ca = os.path.join(d, "ca") assert certutils.dummy_ca(ca) c = certutils.CertStore() + assert not c.get_cert("../foo.com", []) assert not c.get_cert("foo.com", []) assert c.get_cert("foo.com", [], ca) assert c.get_cert("foo.com", [], ca) diff --git a/test/test_tcp.py b/test/test_tcp.py index 0417aa21..ce06ad66 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -1,38 +1,7 @@ import cStringIO, threading, Queue, time -from netlib import tcp, certutils +from netlib import tcp, certutils, test import tutils -class ServerThread(threading.Thread): - def __init__(self, server): - self.server = server - threading.Thread.__init__(self) - - def run(self): - self.server.serve_forever() - - def shutdown(self): - self.server.shutdown() - - -class ServerTestBase: - @classmethod - def setupAll(cls): - cls.q = Queue.Queue() - s = cls.makeserver() - cls.port = s.port - cls.server = ServerThread(s) - cls.server.start() - - @classmethod - def teardownAll(cls): - cls.server.shutdown() - - - @property - def last_handler(self): - return self.server.server.last_handler - - class SNIHandler(tcp.BaseHandler): sni = None def handle_sni(self, connection): @@ -88,43 +57,10 @@ class TimeoutHandler(tcp.BaseHandler): self.timeout = True -class TServer(tcp.TCPServer): - def __init__(self, addr, ssl, q, handler_klass, v3_only=False): - tcp.TCPServer.__init__(self, addr) - self.ssl, self.q = ssl, q - self.v3_only = v3_only - self.handler_klass = handler_klass - self.last_handler = None - - def handle_connection(self, request, client_address): - h = self.handler_klass(request, client_address, self) - self.last_handler = h - if self.ssl: - if self.v3_only: - method = tcp.SSLv3_METHOD - options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1 - else: - method = tcp.SSLv23_METHOD - options = None - h.convert_to_ssl( - tutils.test_data.path("data/server.crt"), - tutils.test_data.path("data/server.key"), - method = method, - options = options, - ) - h.handle() - h.finish() - - def handle_error(self, request, client_address): - s = cStringIO.StringIO() - tcp.TCPServer.handle_error(self, request, client_address, s) - self.q.put(s.getvalue()) - - -class TestServer(ServerTestBase): +class TestServer(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler) + return test.TServer(False, cls.q, EchoHandler) def test_echo(self): testval = "echo!\n" @@ -135,10 +71,10 @@ class TestServer(ServerTestBase): assert c.rfile.readline() == testval -class TestDisconnect(ServerTestBase): +class TestDisconnect(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler) + return test.TServer(False, cls.q, EchoHandler) def test_echo(self): testval = "echo!\n" @@ -149,10 +85,18 @@ class TestDisconnect(ServerTestBase): assert c.rfile.readline() == testval -class TestServerSSL(ServerTestBase): +class TestServerSSL(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), True, cls.q, EchoHandler) + return test.TServer( + dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = False + ), + cls.q, + EchoHandler + ) def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -167,10 +111,19 @@ class TestServerSSL(ServerTestBase): assert certutils.get_remote_cert("127.0.0.1", self.port, None).digest("sha1") -class TestSSLv3Only(ServerTestBase): +class TestSSLv3Only(test.ServerTestBase): + v3_only = True @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), True, cls.q, EchoHandler, True) + return test.TServer( + dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = True + ), + cls.q, + EchoHandler, + ) def test_failure(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -178,10 +131,18 @@ class TestSSLv3Only(ServerTestBase): tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com", method=tcp.TLSv1_METHOD) -class TestSSLClientCert(ServerTestBase): +class TestSSLClientCert(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), True, cls.q, CertHandler) + return test.TServer( + dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = False + ), + cls.q, + CertHandler + ) def test_clientcert(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -199,10 +160,18 @@ class TestSSLClientCert(ServerTestBase): ) -class TestSNI(ServerTestBase): +class TestSNI(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), True, cls.q, SNIHandler) + return test.TServer( + dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = False + ), + cls.q, + SNIHandler + ) def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -211,10 +180,18 @@ class TestSNI(ServerTestBase): assert c.rfile.readline() == "foo.com" -class TestSSLDisconnect(ServerTestBase): +class TestSSLDisconnect(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), True, cls.q, DisconnectHandler) + return test.TServer( + dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = False + ), + cls.q, + DisconnectHandler + ) def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -227,10 +204,10 @@ class TestSSLDisconnect(ServerTestBase): tutils.raises(Queue.Empty, self.q.get_nowait) -class TestDisconnect(ServerTestBase): +class TestSSLDisconnect(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), False, cls.q, DisconnectHandler) + return test.TServer(False, cls.q, DisconnectHandler) def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -242,10 +219,10 @@ class TestDisconnect(ServerTestBase): c.close() -class TestServerTimeOut(ServerTestBase): +class TestServerTimeOut(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), False, cls.q, TimeoutHandler) + return test.TServer(False, cls.q, TimeoutHandler) def test_timeout(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -254,10 +231,10 @@ class TestServerTimeOut(ServerTestBase): assert self.last_handler.timeout -class TestTimeOut(ServerTestBase): +class TestTimeOut(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), False, cls.q, HangHandler) + return test.TServer(False, cls.q, HangHandler) def test_timeout(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -266,10 +243,18 @@ class TestTimeOut(ServerTestBase): tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) -class TestSSLTimeOut(ServerTestBase): +class TestSSLTimeOut(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), True, cls.q, HangHandler) + return test.TServer( + dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = False + ), + cls.q, + HangHandler + ) def test_timeout_client(self): c = tcp.TCPClient("127.0.0.1", self.port) -- cgit v1.2.3 From cc4867064be42409fd5fb8271901b03029b787de Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 25 Jan 2013 16:03:59 +1300 Subject: Streamline netlib.test API --- netlib/test.py | 7 ++- test/test_tcp.py | 130 ++++++++++++++++--------------------------------------- 2 files changed, 44 insertions(+), 93 deletions(-) diff --git a/netlib/test.py b/netlib/test.py index 2f72f979..7d24d80e 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -14,6 +14,8 @@ class ServerThread(threading.Thread): class ServerTestBase: + ssl = None + handler = None @classmethod def setupAll(cls): cls.q = Queue.Queue() @@ -22,11 +24,14 @@ class ServerTestBase: cls.server = ServerThread(s) cls.server.start() + @classmethod + def makeserver(cls): + return TServer(cls.ssl, cls.q, cls.handler) + @classmethod def teardownAll(cls): cls.server.shutdown() - @property def last_handler(self): return self.server.server.last_handler diff --git a/test/test_tcp.py b/test/test_tcp.py index ce06ad66..ad09143d 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -58,10 +58,7 @@ class TimeoutHandler(tcp.BaseHandler): class TestServer(test.ServerTestBase): - @classmethod - def makeserver(cls): - return test.TServer(False, cls.q, EchoHandler) - + handler = EchoHandler def test_echo(self): testval = "echo!\n" c = tcp.TCPClient("127.0.0.1", self.port) @@ -72,10 +69,7 @@ class TestServer(test.ServerTestBase): class TestDisconnect(test.ServerTestBase): - @classmethod - def makeserver(cls): - return test.TServer(False, cls.q, EchoHandler) - + handler = EchoHandler def test_echo(self): testval = "echo!\n" c = tcp.TCPClient("127.0.0.1", self.port) @@ -86,18 +80,12 @@ class TestDisconnect(test.ServerTestBase): class TestServerSSL(test.ServerTestBase): - @classmethod - def makeserver(cls): - return test.TServer( - dict( + handler = EchoHandler + ssl = dict( cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), v3_only = False - ), - cls.q, - EchoHandler - ) - + ) def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) c.connect() @@ -112,19 +100,12 @@ class TestServerSSL(test.ServerTestBase): class TestSSLv3Only(test.ServerTestBase): - v3_only = True - @classmethod - def makeserver(cls): - return test.TServer( - dict( - cert = tutils.test_data.path("data/server.crt"), - key = tutils.test_data.path("data/server.key"), - v3_only = True - ), - cls.q, - EchoHandler, - ) - + handler = EchoHandler + ssl = dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = True + ) def test_failure(self): c = tcp.TCPClient("127.0.0.1", self.port) c.connect() @@ -132,18 +113,12 @@ class TestSSLv3Only(test.ServerTestBase): class TestSSLClientCert(test.ServerTestBase): - @classmethod - def makeserver(cls): - return test.TServer( - dict( - cert = tutils.test_data.path("data/server.crt"), - key = tutils.test_data.path("data/server.key"), - v3_only = False - ), - cls.q, - CertHandler - ) - + handler = CertHandler + ssl = dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = False + ) def test_clientcert(self): c = tcp.TCPClient("127.0.0.1", self.port) c.connect() @@ -161,18 +136,12 @@ class TestSSLClientCert(test.ServerTestBase): class TestSNI(test.ServerTestBase): - @classmethod - def makeserver(cls): - return test.TServer( - dict( - cert = tutils.test_data.path("data/server.crt"), - key = tutils.test_data.path("data/server.key"), - v3_only = False - ), - cls.q, - SNIHandler - ) - + handler = SNIHandler + ssl = dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = False + ) def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) c.connect() @@ -181,18 +150,12 @@ class TestSNI(test.ServerTestBase): class TestSSLDisconnect(test.ServerTestBase): - @classmethod - def makeserver(cls): - return test.TServer( - dict( - cert = tutils.test_data.path("data/server.crt"), - key = tutils.test_data.path("data/server.key"), - v3_only = False - ), - cls.q, - DisconnectHandler - ) - + handler = DisconnectHandler + ssl = dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = False + ) def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) c.connect() @@ -204,15 +167,10 @@ class TestSSLDisconnect(test.ServerTestBase): tutils.raises(Queue.Empty, self.q.get_nowait) -class TestSSLDisconnect(test.ServerTestBase): - @classmethod - def makeserver(cls): - return test.TServer(False, cls.q, DisconnectHandler) - +class TestDisconnect(test.ServerTestBase): def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) c.connect() - # Excercise SSL.ZeroReturnError c.rfile.read(10) c.wfile.write("foo") c.close() @@ -220,10 +178,7 @@ class TestSSLDisconnect(test.ServerTestBase): class TestServerTimeOut(test.ServerTestBase): - @classmethod - def makeserver(cls): - return test.TServer(False, cls.q, TimeoutHandler) - + handler = TimeoutHandler def test_timeout(self): c = tcp.TCPClient("127.0.0.1", self.port) c.connect() @@ -232,10 +187,7 @@ class TestServerTimeOut(test.ServerTestBase): class TestTimeOut(test.ServerTestBase): - @classmethod - def makeserver(cls): - return test.TServer(False, cls.q, HangHandler) - + handler = HangHandler def test_timeout(self): c = tcp.TCPClient("127.0.0.1", self.port) c.connect() @@ -244,18 +196,12 @@ class TestTimeOut(test.ServerTestBase): class TestSSLTimeOut(test.ServerTestBase): - @classmethod - def makeserver(cls): - return test.TServer( - dict( - cert = tutils.test_data.path("data/server.crt"), - key = tutils.test_data.path("data/server.key"), - v3_only = False - ), - cls.q, - HangHandler - ) - + handler = HangHandler + ssl = dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = False + ) def test_timeout_client(self): c = tcp.TCPClient("127.0.0.1", self.port) c.connect() -- cgit v1.2.3 From e5b125eec8e732112af9884cf3ab35377913303a Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 26 Jan 2013 21:19:35 +1300 Subject: Introduce the mock module to improve unit tests. There are a few socket corner-cases that are incredibly hard to reproduce in a unit test suite, so we use mock to trigger the exceptions instead. --- netlib/tcp.py | 6 ++++++ test/test_tcp.py | 19 ++++++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 56cc0dea..a79f3ac4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -87,6 +87,9 @@ class _FileLike: class Writer(_FileLike): def flush(self): + """ + May raise NetLibDisconnect + """ if hasattr(self.o, "flush"): try: self.o.flush() @@ -94,6 +97,9 @@ class Writer(_FileLike): raise NetLibDisconnect(str(v)) def write(self, v): + """ + May raise NetLibDisconnect + """ if v: try: if hasattr(self.o, "sendall"): diff --git a/test/test_tcp.py b/test/test_tcp.py index ad09143d..e7524fdc 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -1,5 +1,6 @@ -import cStringIO, threading, Queue, time +import cStringIO, threading, Queue, time, socket from netlib import tcp, certutils, test +import mock import tutils class SNIHandler(tcp.BaseHandler): @@ -275,6 +276,22 @@ class TestFileLike: s.write("x") assert s.get_log() == "xx" + def test_writer_flush_error(self): + s = cStringIO.StringIO() + s = tcp.Writer(s) + o = mock.MagicMock() + o.flush = mock.MagicMock(side_effect=socket.error) + s.o = o + tutils.raises(tcp.NetLibDisconnect, s.flush) + + def test_reader_read_error(self): + s = cStringIO.StringIO("foobar\nfoobar") + s = tcp.Reader(s) + o = mock.MagicMock() + o.read = mock.MagicMock(side_effect=socket.error) + s.o = o + tutils.raises(tcp.NetLibDisconnect, s.read, 10) + def test_reset_timestamps(self): s = cStringIO.StringIO("foobar\nfoobar") s = tcp.Reader(s) -- cgit v1.2.3 From 7433dfceae3b2ac7e709fbcedd9e298800d2ac1b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 26 Jan 2013 21:29:45 +1300 Subject: Bump unit tests, fix two serious wee buglets discovered. --- netlib/tcp.py | 4 ++-- test/test_tcp.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index a79f3ac4..40bd4bde 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -185,7 +185,7 @@ class TCPClient: """ context = SSL.Context(method) if options is not None: - ctx.set_options(options) + context.set_options(options) if clientcert: try: context.use_privatekey_file(clientcert) @@ -220,7 +220,7 @@ class TCPClient: self.connection.settimeout(n) def gettimeout(self): - self.connection.gettimeout() + return self.connection.gettimeout() def close(self): """ diff --git a/test/test_tcp.py b/test/test_tcp.py index e7524fdc..6ff42072 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -90,7 +90,7 @@ class TestServerSSL(test.ServerTestBase): def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) c.connect() - c.convert_to_ssl(sni="foo.com") + c.convert_to_ssl(sni="foo.com", options=tcp.OP_ALL) testval = "echo!\n" c.wfile.write(testval) c.wfile.flush() @@ -193,6 +193,7 @@ class TestTimeOut(test.ServerTestBase): c = tcp.TCPClient("127.0.0.1", self.port) c.connect() c.settimeout(0.1) + assert c.gettimeout() == 0.1 tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) -- cgit v1.2.3 From 7d185356655fa2f40c452c273a3cd039360d20c1 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 27 Jan 2013 19:21:18 +1300 Subject: 100% test coverage --- netlib/tcp.py | 21 +++++++-------------- test/test_tcp.py | 35 +++++++++++++++++++++++++++++++---- 2 files changed, 38 insertions(+), 18 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 40bd4bde..556f97ac 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -355,20 +355,13 @@ class TCPServer: while not self.__shutdown_request: r, w, e = select.select([self.socket], [], [], poll_interval) if self.socket in r: - try: - request, client_address = self.socket.accept() - except socket.error: - return - try: - t = threading.Thread( - target = self.request_thread, - args = (request, client_address) - ) - t.setDaemon(1) - t.start() - except: - self.handle_error(request, client_address) - request.close() + request, client_address = self.socket.accept() + t = threading.Thread( + target = self.request_thread, + args = (request, client_address) + ) + t.setDaemon(1) + t.start() finally: self.__shutdown_request = False self.__is_shut_down.set() diff --git a/test/test_tcp.py b/test/test_tcp.py index 6ff42072..f12a131b 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -20,10 +20,7 @@ class EchoHandler(tcp.BaseHandler): def handle(self): v = self.rfile.readline() - if v.startswith("echo"): - self.wfile.write(v) - elif v.startswith("error"): - raise ValueError("Testing an error.") + self.wfile.write(v) self.wfile.flush() @@ -69,6 +66,35 @@ class TestServer(test.ServerTestBase): assert c.rfile.readline() == testval + +class FinishFailHandler(tcp.BaseHandler): + def handle(self): + v = self.rfile.readline() + self.wfile.write(v) + self.wfile.flush() + o = mock.MagicMock() + self.wfile.close() + self.rfile.close() + self.close = mock.MagicMock(side_effect=socket.error) + + +class TestFinishFail(test.ServerTestBase): + """ + This tests a difficult-to-trigger exception in the .finish() method of + the handler. + """ + handler = FinishFailHandler + def test_disconnect_in_finish(self): + testval = "echo!\n" + c = tcp.TCPClient("127.0.0.1", self.port) + c.connect() + c.wfile.write("foo\n") + c.wfile.flush() + c.rfile.read(4) + h = self.last_handler + h.finish() + + class TestDisconnect(test.ServerTestBase): handler = EchoHandler def test_echo(self): @@ -317,3 +343,4 @@ class TestFileLike: expected = s.first_byte_timestamp s.readline() assert s.first_byte_timestamp == expected + -- cgit v1.2.3 From c6f9a2d74dc0b2d9185743a02e4c1410983f0c3f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 24 Feb 2013 11:08:43 +1300 Subject: More accurate description of an HTTP read error, make pyflakes happy. --- netlib/certutils.py | 2 +- netlib/http.py | 2 +- netlib/tcp.py | 4 ++-- netlib/wsgi.py | 8 ++++---- test/test_http.py | 2 +- test/test_tcp.py | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index e1407936..b3ba1dcf 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -1,4 +1,4 @@ -import os, ssl, hashlib, socket, time, datetime, tempfile, shutil +import os, ssl, time, datetime, tempfile, shutil from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error diff --git a/netlib/http.py b/netlib/http.py index 076baf87..29bcf43d 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -279,7 +279,7 @@ def read_response(rfile, method, body_size_limit): if line == "\r\n" or line == "\n": # Possible leftover from previous message line = rfile.readline() if not line: - raise HttpError(502, "Blank server response.") + raise HttpError(502, "Server disconnect.") parts = line.strip().split(" ", 2) if len(parts) == 2: # handle missing message gracefully parts.append("") diff --git a/netlib/tcp.py b/netlib/tcp.py index 556f97ac..0a15d2ac 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -139,7 +139,7 @@ class Reader(_FileLike): raise NetLibTimeout except socket.error: raise NetLibDisconnect - except SSL.SysCallError, v: + except SSL.SysCallError: raise NetLibDisconnect self.first_byte_timestamp = self.first_byte_timestamp or time.time() if not data: @@ -322,7 +322,7 @@ class BaseHandler: self.connection.shutdown() else: self.connection.shutdown(socket.SHUT_RDWR) - except (socket.error, SSL.Error), v: + except (socket.error, SSL.Error): # Socket probably already closed pass self.connection.close() diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 4fa2c537..dffc2ace 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,4 +1,4 @@ -import cStringIO, urllib, time, sys, traceback +import cStringIO, urllib, time, traceback import odict @@ -128,13 +128,13 @@ class WSGIAdaptor: write(i) if not state["headers_sent"]: write("") - except Exception, v: + except Exception: try: s = traceback.format_exc() errs.write(s) self.error_page(soc, state["headers_sent"], s) - except Exception, v: # pragma: no cover - pass # pragma: no cover + except Exception: # pragma: no cover + pass return errs.getvalue() diff --git a/test/test_http.py b/test/test_http.py index ed16fb4a..666dfdbb 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -222,7 +222,7 @@ def test_read_response(): r = cStringIO.StringIO(data) return http.read_response(r, method, limit) - tutils.raises("blank server response", tst, "", "GET", None) + tutils.raises("server disconnect", tst, "", "GET", None) tutils.raises("invalid server response", tst, "foo", "GET", None) data = """ HTTP/1.1 200 OK diff --git a/test/test_tcp.py b/test/test_tcp.py index f12a131b..5b616969 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -1,4 +1,4 @@ -import cStringIO, threading, Queue, time, socket +import cStringIO, Queue, time, socket from netlib import tcp, certutils, test import mock import tutils -- cgit v1.2.3 From 97e11a219fb2a752d5b726b203874101d7ab651c Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 24 Feb 2013 15:36:15 +1300 Subject: Housekeeping and cleanup, some minor argument name changes. --- netlib/certutils.py | 1 - netlib/http.py | 9 ++++++--- netlib/tcp.py | 10 +++++----- test/test_tcp.py | 4 ++-- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index b3ba1dcf..859c93f1 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -118,7 +118,6 @@ def dummy_cert(fp, ca, commonname, sans): fp.close() - class CertStore: """ Implements an on-disk certificate store. diff --git a/netlib/http.py b/netlib/http.py index 29bcf43d..58993686 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -9,6 +9,9 @@ class HttpError(Exception): return "HttpError(%s, %s)"%(self.code, self.msg) +class HttpErrorConnClosed(HttpError): pass + + def parse_url(url): """ Returns a (scheme, host, port, path) tuple, or None on error. @@ -73,7 +76,7 @@ def read_chunked(code, fp, limit): while 1: line = fp.readline(128) if line == "": - raise HttpError(code, "Connection closed prematurely") + raise HttpErrorConnClosed(code, "Connection closed prematurely") if line != '\r\n' and line != '\n': try: length = int(line, 16) @@ -95,7 +98,7 @@ def read_chunked(code, fp, limit): while 1: line = fp.readline() if line == "": - raise HttpError(code, "Connection closed prematurely") + raise HttpErrorConnClosed(code, "Connection closed prematurely") if line == '\r\n' or line == '\n': break return content @@ -279,7 +282,7 @@ def read_response(rfile, method, body_size_limit): if line == "\r\n" or line == "\n": # Possible leftover from previous message line = rfile.readline() if not line: - raise HttpError(502, "Server disconnect.") + raise HttpErrorConnClosed(502, "Server disconnect.") parts = line.strip().split(" ", 2) if len(parts) == 2: # handle missing message gracefully parts.append("") diff --git a/netlib/tcp.py b/netlib/tcp.py index 0a15d2ac..d909a5a4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -179,17 +179,17 @@ class TCPClient: self.cert = None self.ssl_established = False - def convert_to_ssl(self, clientcert=None, sni=None, method=TLSv1_METHOD, options=None): + def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None): """ - clientcert: Path to a file containing both client cert and private key. + cert: Path to a file containing both client cert and private key. """ context = SSL.Context(method) if options is not None: context.set_options(options) - if clientcert: + if cert: try: - context.use_privatekey_file(clientcert) - context.use_certificate_file(clientcert) + context.use_privatekey_file(cert) + context.use_certificate_file(cert) except SSL.Error, v: raise NetLibError("SSL client certificate error: %s"%str(v)) self.connection = SSL.Connection(context, self.connection) diff --git a/test/test_tcp.py b/test/test_tcp.py index 5b616969..de14ab25 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -149,7 +149,7 @@ class TestSSLClientCert(test.ServerTestBase): def test_clientcert(self): c = tcp.TCPClient("127.0.0.1", self.port) c.connect() - c.convert_to_ssl(clientcert=tutils.test_data.path("data/clientcert/client.pem")) + c.convert_to_ssl(cert=tutils.test_data.path("data/clientcert/client.pem")) assert c.rfile.readline().strip() == "1" def test_clientcert_err(self): @@ -158,7 +158,7 @@ class TestSSLClientCert(test.ServerTestBase): tutils.raises( tcp.NetLibError, c.convert_to_ssl, - clientcert=tutils.test_data.path("data/clientcert/make") + cert=tutils.test_data.path("data/clientcert/make") ) -- cgit v1.2.3 From f30df13384b1c31ee7bcd78b0caea37043434bcf Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 25 Feb 2013 21:11:09 +1300 Subject: Make sni_handler an argument to BaseHandler.convert_to_ssl --- netlib/tcp.py | 35 +++++++++++++++-------------------- netlib/test.py | 1 + 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index d909a5a4..485d821f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -254,15 +254,27 @@ class BaseHandler: self.ssl_established = False self.clientcert = None - def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None): + def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None): """ method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD + handle_sni: SNI handler, should take a connection object. Server + name can be retrieved like this: + + connection.get_servername() + + And you can specify the connection keys as follows: + + new_context = Context(TLSv1_METHOD) + new_context.use_privatekey(key) + new_context.use_certificate(cert) + connection.set_context(new_context) """ ctx = SSL.Context(method) if not options is None: ctx.set_options(options) - # SNI callback happens during do_handshake() - ctx.set_tlsext_servername_callback(self.handle_sni) + if handle_sni: + # SNI callback happens during do_handshake() + ctx.set_tlsext_servername_callback(handle_sni) ctx.use_privatekey_file(key) ctx.use_certificate_file(cert) def ver(*args): @@ -290,23 +302,6 @@ class BaseHandler: # Remote has disconnected pass - def handle_sni(self, connection): - """ - Called if the client has given a server name indication. - - Server name can be retrieved like this: - - connection.get_servername() - - And you can specify the connection keys as follows: - - new_context = Context(TLSv1_METHOD) - new_context.use_privatekey(key) - new_context.use_certificate(cert) - connection.set_context(new_context) - """ - pass - def handle(self): # pragma: no cover raise NotImplementedError diff --git a/netlib/test.py b/netlib/test.py index 7d24d80e..3378279b 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -62,6 +62,7 @@ class TServer(tcp.TCPServer): self.ssl["key"], method = method, options = options, + handle_sni = getattr(h, "handle_sni", None) ) h.handle() h.finish() -- cgit v1.2.3 From 0fa63519654db2567995f3c3ac6e464796de66a3 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 28 Feb 2013 09:28:48 +1300 Subject: ODict.keys --- netlib/odict.py | 3 +++ test/test_odict.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/netlib/odict.py b/netlib/odict.py index bddb3877..0759a5bf 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -36,6 +36,9 @@ class ODict: ret.append(i[1]) return ret + def keys(self): + return list(set([self._kconv(i[0]) for i in self.lst])) + def _filter_lst(self, k, lst): k = self._kconv(k) new = [] diff --git a/test/test_odict.py b/test/test_odict.py index d59ed67e..26bff357 100644 --- a/test/test_odict.py +++ b/test/test_odict.py @@ -63,6 +63,15 @@ class TestODict: self.od.add("foo", 1) assert [i for i in self.od] + def test_keys(self): + assert not self.od.keys() + self.od.add("foo", 1) + assert self.od.keys() == ["foo"] + self.od.add("foo", 2) + assert self.od.keys() == ["foo"] + self.od.add("bar", 2) + assert len(self.od.keys()) == 2 + def test_copy(self): self.od.add("foo", 1) self.od.add("foo", 2) @@ -122,3 +131,13 @@ class TestODictCaseless: self.od.add("bar", 3) del self.od["foo"] assert len(self.od) == 1 + + def test_keys(self): + assert not self.od.keys() + self.od.add("foo", 1) + assert self.od.keys() == ["foo"] + self.od.add("Foo", 2) + assert self.od.keys() == ["foo"] + self.od.add("bar", 2) + assert len(self.od.keys()) == 2 + -- cgit v1.2.3 From 97537417f01c17903fb4cebd59991eea57faa5e6 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 2 Mar 2013 16:57:38 +1300 Subject: Factor out http.parse_response_line --- netlib/http.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index 58993686..bc09c8a1 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -274,6 +274,20 @@ def read_http_body_response(rfile, headers, limit): return read_http_body(500, rfile, headers, all, limit) +def parse_response_line(line): + parts = line.strip().split(" ", 2) + if len(parts) == 2: # handle missing message gracefully + parts.append("") + if len(parts) != 3: + return None + proto, code, msg = parts + try: + code = int(code) + except ValueError: + return None + return (proto, code, msg) + + def read_response(rfile, method, body_size_limit): """ Return an (httpversion, code, msg, headers, content) tuple. @@ -283,19 +297,13 @@ def read_response(rfile, method, body_size_limit): line = rfile.readline() if not line: raise HttpErrorConnClosed(502, "Server disconnect.") - parts = line.strip().split(" ", 2) - if len(parts) == 2: # handle missing message gracefully - parts.append("") - if not len(parts) == 3: + parts = parse_response_line(line) + if not parts: raise HttpError(502, "Invalid server response: %s"%repr(line)) proto, code, msg = parts httpversion = parse_http_protocol(proto) if httpversion is None: raise HttpError(502, "Invalid HTTP version in line: %s"%repr(proto)) - try: - code = int(code) - except ValueError: - raise HttpError(502, "Invalid server response: %s"%repr(line)) headers = read_headers(rfile) if headers is None: raise HttpError(502, "Invalid headers.") -- cgit v1.2.3 From 0acab862a65ef4a1823a1bfb702d8be1e3d7b83d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 3 Mar 2013 10:37:28 +1300 Subject: Integrate HTTP auth, test to 100% --- .coveragerc | 3 +- netlib/contrib/__init__.py | 0 netlib/contrib/md5crypt.py | 94 +++++++++++++++++++++++++++++++++++++ netlib/http.py | 22 ++++++++- netlib/http_auth.py | 113 +++++++++++++++++++++++++++++++++++++++++++++ test/data/htpasswd | 1 + test/test_http.py | 11 ++++- test/test_http_auth.py | 81 ++++++++++++++++++++++++++++++++ 8 files changed, 322 insertions(+), 3 deletions(-) create mode 100644 netlib/contrib/__init__.py create mode 100644 netlib/contrib/md5crypt.py create mode 100644 netlib/http_auth.py create mode 100644 test/data/htpasswd create mode 100644 test/test_http_auth.py diff --git a/.coveragerc b/.coveragerc index 99f57cb0..8076aebe 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,2 +1,3 @@ [report] -include = *netlib* +omit = *contrib* +include = *netlib/netlib* diff --git a/netlib/contrib/__init__.py b/netlib/contrib/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/netlib/contrib/md5crypt.py b/netlib/contrib/md5crypt.py new file mode 100644 index 00000000..d64ea8ac --- /dev/null +++ b/netlib/contrib/md5crypt.py @@ -0,0 +1,94 @@ +# Based on FreeBSD src/lib/libcrypt/crypt.c 1.2 +# http://www.freebsd.org/cgi/cvsweb.cgi/~checkout~/src/lib/libcrypt/crypt.c?rev=1.2&content-type=text/plain + +# Original license: +# * "THE BEER-WARE LICENSE" (Revision 42): +# * wrote this file. As long as you retain this notice you +# * can do whatever you want with this stuff. If we meet some day, and you think +# * this stuff is worth it, you can buy me a beer in return. Poul-Henning Kamp + +# This port adds no further stipulations. I forfeit any copyright interest. + +import md5 + +def md5crypt(password, salt, magic='$1$'): + # /* The password first, since that is what is most unknown */ /* Then our magic string */ /* Then the raw salt */ + m = md5.new() + m.update(password + magic + salt) + + # /* Then just as many characters of the MD5(pw,salt,pw) */ + mixin = md5.md5(password + salt + password).digest() + for i in range(0, len(password)): + m.update(mixin[i % 16]) + + # /* Then something really weird... */ + # Also really broken, as far as I can tell. -m + i = len(password) + while i: + if i & 1: + m.update('\x00') + else: + m.update(password[0]) + i >>= 1 + + final = m.digest() + + # /* and now, just to make sure things don't run too fast */ + for i in range(1000): + m2 = md5.md5() + if i & 1: + m2.update(password) + else: + m2.update(final) + + if i % 3: + m2.update(salt) + + if i % 7: + m2.update(password) + + if i & 1: + m2.update(final) + else: + m2.update(password) + + final = m2.digest() + + # This is the bit that uses to64() in the original code. + + itoa64 = './0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' + + rearranged = '' + for a, b, c in ((0, 6, 12), (1, 7, 13), (2, 8, 14), (3, 9, 15), (4, 10, 5)): + v = ord(final[a]) << 16 | ord(final[b]) << 8 | ord(final[c]) + for i in range(4): + rearranged += itoa64[v & 0x3f]; v >>= 6 + + v = ord(final[11]) + for i in range(2): + rearranged += itoa64[v & 0x3f]; v >>= 6 + + return magic + salt + '$' + rearranged + +if __name__ == '__main__': + + def test(clear_password, the_hash): + magic, salt = the_hash[1:].split('$')[:2] + magic = '$' + magic + '$' + return md5crypt(clear_password, salt, magic) == the_hash + + test_cases = ( + (' ', '$1$yiiZbNIH$YiCsHZjcTkYd31wkgW8JF.'), + ('pass', '$1$YeNsbWdH$wvOF8JdqsoiLix754LTW90'), + ('____fifteen____', '$1$s9lUWACI$Kk1jtIVVdmT01p0z3b/hw1'), + ('____sixteen_____', '$1$dL3xbVZI$kkgqhCanLdxODGq14g/tW1'), + ('____seventeen____', '$1$NaH5na7J$j7y8Iss0hcRbu3kzoJs5V.'), + ('__________thirty-three___________', '$1$HO7Q6vzJ$yGwp2wbL5D7eOVzOmxpsy.'), + ('apache', '$apr1$J.w5a/..$IW9y6DR0oO/ADuhlMF5/X1') + ) + + for clearpw, hashpw in test_cases: + if test(clearpw, hashpw): + print '%s: pass' % clearpw + else: + print '%s: FAIL' % clearpw diff --git a/netlib/http.py b/netlib/http.py index bc09c8a1..10b6a402 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -1,4 +1,4 @@ -import string, urlparse +import string, urlparse, binascii import odict class HttpError(Exception): @@ -169,6 +169,26 @@ def parse_http_protocol(s): return major, minor +def parse_http_basic_auth(s): + words = s.split() + if len(words) != 2: + return None + scheme = words[0] + try: + user = binascii.a2b_base64(words[1]) + except binascii.Error: + return None + parts = user.split(':') + if len(parts) != 2: + return None + return scheme, parts[0], parts[1] + + +def assemble_http_basic_auth(scheme, username, password): + v = binascii.b2a_base64(username + ":" + password) + return scheme + " " + v + + def parse_init(line): try: method, url, protocol = string.split(line) diff --git a/netlib/http_auth.py b/netlib/http_auth.py new file mode 100644 index 00000000..d478ab10 --- /dev/null +++ b/netlib/http_auth.py @@ -0,0 +1,113 @@ +import binascii +import contrib.md5crypt as md5crypt +import http + + +class NullProxyAuth(): + """ + No proxy auth at all (returns empty challange headers) + """ + def __init__(self, password_manager): + self.password_manager = password_manager + + def clean(self, headers): + """ + Clean up authentication headers, so they're not passed upstream. + """ + pass + + def authenticate(self, headers): + """ + Tests that the user is allowed to use the proxy + """ + return True + + def auth_challenge_headers(self): + """ + Returns a dictionary containing the headers require to challenge the user + """ + return {} + + +class BasicProxyAuth(NullProxyAuth): + CHALLENGE_HEADER = 'Proxy-Authenticate' + AUTH_HEADER = 'Proxy-Authorization' + def __init__(self, password_manager, realm): + NullProxyAuth.__init__(self, password_manager) + self.realm = realm + + def clean(self, headers): + del headers[self.AUTH_HEADER] + + def authenticate(self, headers): + auth_value = headers.get(self.AUTH_HEADER, []) + if not auth_value: + return False + parts = http.parse_http_basic_auth(auth_value[0]) + if not parts: + return False + scheme, username, password = parts + if scheme.lower()!='basic': + return False + if not self.password_manager.test(username, password): + return False + self.username = username + return True + + def auth_challenge_headers(self): + return {self.CHALLENGE_HEADER:'Basic realm="%s"'%self.realm} + + +class PassMan(): + def test(self, username, password_token): + return False + + +class PassManNonAnon: + """ + Ensure the user specifies a username, accept any password. + """ + def test(self, username, password_token): + if username: + return True + return False + + +class PassManHtpasswd: + """ + Read usernames and passwords from an htpasswd file + """ + def __init__(self, fp): + """ + Raises ValueError if htpasswd file is invalid. + """ + self.usernames = {} + for l in fp: + l = l.strip().split(':') + if len(l) != 2: + raise ValueError("Invalid htpasswd file.") + parts = l[1].split('$') + if len(parts) != 4: + raise ValueError("Invalid htpasswd file.") + self.usernames[l[0]] = dict( + token = l[1], + dummy = parts[0], + magic = parts[1], + salt = parts[2], + hashed_password = parts[3] + ) + + def test(self, username, password_token): + ui = self.usernames.get(username) + if not ui: + return False + expected = md5crypt.md5crypt(password_token, ui["salt"], '$'+ui["magic"]+'$') + return expected==ui["token"] + + +class PassManSingleUser: + def __init__(self, username, password): + self.username, self.password = username, password + + def test(self, username, password_token): + return self.username==username and self.password==password_token diff --git a/test/data/htpasswd b/test/data/htpasswd new file mode 100644 index 00000000..54c95b8c --- /dev/null +++ b/test/data/htpasswd @@ -0,0 +1 @@ +test:$apr1$/LkYxy3x$WI4.YbiJlu537jLGEW2eu1 diff --git a/test/test_http.py b/test/test_http.py index 666dfdbb..1c89900c 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -1,4 +1,4 @@ -import cStringIO, textwrap +import cStringIO, textwrap, binascii from netlib import http, odict import tutils @@ -291,3 +291,12 @@ def test_parse_url(): assert not http.parse_url("https://foo:bar") assert not http.parse_url("https://foo:") + +def test_parse_http_basic_auth(): + vals = ("basic", "foo", "bar") + assert http.parse_http_basic_auth(http.assemble_http_basic_auth(*vals)) == vals + assert not http.parse_http_basic_auth("") + assert not http.parse_http_basic_auth("foo bar") + v = "basic " + binascii.b2a_base64("foo") + assert not http.parse_http_basic_auth(v) + diff --git a/test/test_http_auth.py b/test/test_http_auth.py new file mode 100644 index 00000000..cae69f5e --- /dev/null +++ b/test/test_http_auth.py @@ -0,0 +1,81 @@ +import binascii, cStringIO +from netlib import odict, http_auth, http +import tutils + +class TestPassManNonAnon: + def test_simple(self): + p = http_auth.PassManNonAnon() + assert not p.test("", "") + assert p.test("user", "") + + +class TestPassManHtpasswd: + def test_file_errors(self): + s = cStringIO.StringIO("foo") + tutils.raises("invalid htpasswd", http_auth.PassManHtpasswd, s) + s = cStringIO.StringIO("foo:bar$foo") + tutils.raises("invalid htpasswd", http_auth.PassManHtpasswd, s) + + def test_simple(self): + f = open(tutils.test_data.path("data/htpasswd")) + pm = http_auth.PassManHtpasswd(f) + + vals = ("basic", "test", "test") + p = http.assemble_http_basic_auth(*vals) + assert pm.test("test", "test") + assert not pm.test("test", "foo") + assert not pm.test("foo", "test") + assert not pm.test("test", "") + assert not pm.test("", "") + + +class TestPassManSingleUser: + def test_simple(self): + pm = http_auth.PassManSingleUser("test", "test") + assert pm.test("test", "test") + assert not pm.test("test", "foo") + assert not pm.test("foo", "test") + + +class TestNullProxyAuth: + def test_simple(self): + na = http_auth.NullProxyAuth(http_auth.PassManNonAnon()) + assert not na.auth_challenge_headers() + assert na.authenticate("foo") + na.clean({}) + + +class TestBasicProxyAuth: + def test_simple(self): + ba = http_auth.BasicProxyAuth(http_auth.PassManNonAnon(), "test") + h = odict.ODictCaseless() + assert ba.auth_challenge_headers() + assert not ba.authenticate(h) + + def test_authenticate_clean(self): + ba = http_auth.BasicProxyAuth(http_auth.PassManNonAnon(), "test") + + hdrs = odict.ODictCaseless() + vals = ("basic", "foo", "bar") + hdrs[ba.AUTH_HEADER] = [http.assemble_http_basic_auth(*vals)] + assert ba.authenticate(hdrs) + + ba.clean(hdrs) + assert not ba.AUTH_HEADER in hdrs + + + hdrs[ba.AUTH_HEADER] = [""] + assert not ba.authenticate(hdrs) + + hdrs[ba.AUTH_HEADER] = ["foo"] + assert not ba.authenticate(hdrs) + + vals = ("foo", "foo", "bar") + hdrs[ba.AUTH_HEADER] = [http.assemble_http_basic_auth(*vals)] + assert not ba.authenticate(hdrs) + + ba = http_auth.BasicProxyAuth(http_auth.PassMan(), "test") + vals = ("basic", "foo", "bar") + hdrs[ba.AUTH_HEADER] = [http.assemble_http_basic_auth(*vals)] + assert not ba.authenticate(hdrs) + -- cgit v1.2.3 From 1fe1a802adbef93b5b024a85d8dafb112ed652bb Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 3 Mar 2013 12:16:09 +1300 Subject: 100% test coverage. --- netlib/http_auth.py | 2 +- netlib/tcp.py | 2 +- test/test_http.py | 1 + test/test_imports.py | 3 +++ 4 files changed, 6 insertions(+), 2 deletions(-) create mode 100644 test/test_imports.py diff --git a/netlib/http_auth.py b/netlib/http_auth.py index d478ab10..4adae179 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -96,7 +96,7 @@ class PassManHtpasswd: salt = parts[2], hashed_password = parts[3] ) - + def test(self, username, password_token): ui = self.usernames.get(username) if not ui: diff --git a/netlib/tcp.py b/netlib/tcp.py index 485d821f..07b28cf9 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -298,7 +298,7 @@ class BaseHandler: self.close() self.wfile.close() self.rfile.close() - except socket.error: + except (socket.error, NetLibDisconnect): # Remote has disconnected pass diff --git a/test/test_http.py b/test/test_http.py index 1c89900c..05dfdb8f 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -2,6 +2,7 @@ import cStringIO, textwrap, binascii from netlib import http, odict import tutils + def test_httperror(): e = http.HttpError(404, "Not found") assert str(e) diff --git a/test/test_imports.py b/test/test_imports.py new file mode 100644 index 00000000..7b8a643b --- /dev/null +++ b/test/test_imports.py @@ -0,0 +1,3 @@ +# These are actually tests! +import netlib.http_status +import netlib.version -- cgit v1.2.3 From 2897ddfbee5ec3da72863cb8d5ee1370c9698f8a Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 3 Mar 2013 14:52:06 +1300 Subject: Stricter error checking for http.parse_url --- netlib/http.py | 13 +++++++++++++ test/test_http.py | 5 +++++ 2 files changed, 18 insertions(+) diff --git a/netlib/http.py b/netlib/http.py index 10b6a402..c864f1de 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -15,6 +15,11 @@ class HttpErrorConnClosed(HttpError): pass def parse_url(url): """ Returns a (scheme, host, port, path) tuple, or None on error. + + Checks that: + port is an integer + host is a valid IDNA-encoded hostname + path is valid ASCII """ scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) if not scheme: @@ -34,6 +39,14 @@ def parse_url(url): path = urlparse.urlunparse(('', '', path, params, query, fragment)) if not path.startswith("/"): path = "/" + path + try: + host.decode("idna") + except ValueError: + return None + try: + path.decode("ascii") + except ValueError: + return None return scheme, host, port, path diff --git a/test/test_http.py b/test/test_http.py index 05dfdb8f..2cbba936 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -292,6 +292,11 @@ def test_parse_url(): assert not http.parse_url("https://foo:bar") assert not http.parse_url("https://foo:") + # Invalid IDNA + assert not http.parse_url("http://\xfafoo") + + assert not http.parse_url("http:/\xc6/localhost:56121") + def test_parse_http_basic_auth(): vals = ("basic", "foo", "bar") -- cgit v1.2.3 From cd4ed8530fa04fcbd54009e9db6ad9ea2518a10b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 3 Mar 2013 15:03:57 +1300 Subject: Check that hosts in parse_url do not contain NULL bytes. --- netlib/http.py | 4 +++- test/test_http.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index c864f1de..1b03d330 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -18,7 +18,7 @@ def parse_url(url): Checks that: port is an integer - host is a valid IDNA-encoded hostname + host is a valid IDNA-encoded hostname with no null-bytes path is valid ASCII """ scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) @@ -43,6 +43,8 @@ def parse_url(url): host.decode("idna") except ValueError: return None + if "\0" in host: + return None try: path.decode("ascii") except ValueError: diff --git a/test/test_http.py b/test/test_http.py index 2cbba936..f41a4e2d 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -294,8 +294,9 @@ def test_parse_url(): # Invalid IDNA assert not http.parse_url("http://\xfafoo") - assert not http.parse_url("http:/\xc6/localhost:56121") + assert not http.parse_url("http://foo\0") + def test_parse_http_basic_auth(): -- cgit v1.2.3 From 7b9300743e879a8a2e35f5786b23a17261350ff9 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 3 Mar 2013 15:08:17 +1300 Subject: More parse_url solidification: check that port is in range 0-65535 --- netlib/http.py | 4 +++- test/test_http.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/netlib/http.py b/netlib/http.py index 1b03d330..5628dd4d 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -17,7 +17,7 @@ def parse_url(url): Returns a (scheme, host, port, path) tuple, or None on error. Checks that: - port is an integer + port is an integer 0-65535 host is a valid IDNA-encoded hostname with no null-bytes path is valid ASCII """ @@ -49,6 +49,8 @@ def parse_url(url): path.decode("ascii") except ValueError: return None + if not 0 <= port <= 65535: + return None return scheme, host, port, path diff --git a/test/test_http.py b/test/test_http.py index f41a4e2d..061aeb22 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -296,6 +296,7 @@ def test_parse_url(): assert not http.parse_url("http://\xfafoo") assert not http.parse_url("http:/\xc6/localhost:56121") assert not http.parse_url("http://foo\0") + assert not http.parse_url("http://foo:999999") -- cgit v1.2.3 From b21a7da142625e3b47d712cd21cbd440eb48f490 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 3 Mar 2013 15:12:58 +1300 Subject: parse_url: Handle invalid IPv6 addresses --- netlib/http.py | 5 ++++- test/test_http.py | 7 +++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index 5628dd4d..2c9e69cb 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -21,7 +21,10 @@ def parse_url(url): host is a valid IDNA-encoded hostname with no null-bytes path is valid ASCII """ - scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) + try: + scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) + except ValueError: + return None if not scheme: return None if ':' in netloc: diff --git a/test/test_http.py b/test/test_http.py index 061aeb22..f7d861fd 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -294,11 +294,14 @@ def test_parse_url(): # Invalid IDNA assert not http.parse_url("http://\xfafoo") + # Invalid PATH assert not http.parse_url("http:/\xc6/localhost:56121") + # Null byte in host assert not http.parse_url("http://foo\0") + # Port out of range assert not http.parse_url("http://foo:999999") - - + # Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt + assert not http.parse_url('http://lo[calhost') def test_parse_http_basic_auth(): vals = ("basic", "foo", "bar") -- cgit v1.2.3 From 5a050bb6b2b1a0bf05f4cd35d87e6f1d7a2608c0 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 3 Mar 2013 21:36:19 +1300 Subject: Tighten up checks on port ranges and path character sets. --- netlib/http.py | 37 ++++++++++++++++++++++++++----------- netlib/utils.py | 8 ++++++++ test/test_http.py | 5 +++-- 3 files changed, 37 insertions(+), 13 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index 2c9e69cb..0f2caa5a 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -1,5 +1,5 @@ import string, urlparse, binascii -import odict +import odict, utils class HttpError(Exception): def __init__(self, code, msg): @@ -12,6 +12,22 @@ class HttpError(Exception): class HttpErrorConnClosed(HttpError): pass +def _is_valid_port(port): + if not 0 <= port <= 65535: + return False + return True + + +def _is_valid_host(host): + try: + host.decode("idna") + except ValueError: + return False + if "\0" in host: + return None + return True + + def parse_url(url): """ Returns a (scheme, host, port, path) tuple, or None on error. @@ -42,17 +58,11 @@ def parse_url(url): path = urlparse.urlunparse(('', '', path, params, query, fragment)) if not path.startswith("/"): path = "/" + path - try: - host.decode("idna") - except ValueError: + if not _is_valid_host(host): return None - if "\0" in host: + if not utils.isascii(path): return None - try: - path.decode("ascii") - except ValueError: - return None - if not 0 <= port <= 65535: + if not _is_valid_port(port): return None return scheme, host, port, path @@ -236,6 +246,10 @@ def parse_init_connect(line): port = int(port) except ValueError: return None + if not _is_valid_port(port): + return None + if not _is_valid_host(host): + return None return host, port, httpversion @@ -260,7 +274,8 @@ def parse_init_http(line): if not v: return None method, url, httpversion = v - + if not utils.isascii(url): + return None if not (url.startswith("/") or url == "*"): return None return method, url, httpversion diff --git a/netlib/utils.py b/netlib/utils.py index 7621a1dc..61fd54ae 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,4 +1,12 @@ +def isascii(s): + try: + s.decode("ascii") + except ValueError: + return False + return True + + def cleanBin(s, fixspacing=False): """ Cleans binary data to make it safe to display. If fixspacing is True, diff --git a/test/test_http.py b/test/test_http.py index f7d861fd..e98a891f 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -136,6 +136,8 @@ def test_parse_http_protocol(): def test_parse_init_connect(): assert http.parse_init_connect("CONNECT host.com:443 HTTP/1.0") + assert not http.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") + assert not http.parse_init_connect("CONNECT host.com:444444 HTTP/1.0") assert not http.parse_init_connect("bogus") assert not http.parse_init_connect("GET host.com:443 HTTP/1.0") assert not http.parse_init_connect("CONNECT host.com443 HTTP/1.0") @@ -164,11 +166,10 @@ def test_parse_init_http(): assert m == "GET" assert u == "/test" assert httpversion == (1, 1) - assert not http.parse_init_http("invalid") assert not http.parse_init_http("GET invalid HTTP/1.1") assert not http.parse_init_http("GET /test foo/1.1") - + assert not http.parse_init_http("GET /test\xc0 HTTP/1.1") class TestReadHeaders: def _read(self, data, verbatim=False): -- cgit v1.2.3 From 5f0ad7b2a6b857419017e3e72062ab4e0e328238 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 3 Mar 2013 22:13:23 +1300 Subject: Ensure that HTTP methods are ASCII. --- netlib/http.py | 2 ++ test/test_http.py | 10 +++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/netlib/http.py b/netlib/http.py index 0f2caa5a..f1a2bfb5 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -227,6 +227,8 @@ def parse_init(line): httpversion = parse_http_protocol(protocol) if not httpversion: return None + if not utils.isascii(method): + return None return method, url, httpversion diff --git a/test/test_http.py b/test/test_http.py index e98a891f..77cc2624 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -136,6 +136,7 @@ def test_parse_http_protocol(): def test_parse_init_connect(): assert http.parse_init_connect("CONNECT host.com:443 HTTP/1.0") + assert not http.parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") assert not http.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") assert not http.parse_init_connect("CONNECT host.com:444444 HTTP/1.0") assert not http.parse_init_connect("bogus") @@ -155,6 +156,9 @@ def test_prase_init_proxy(): assert pa == "/test" assert httpversion == (1, 1) + u = "G\xfeET http://foo.com:8888/test HTTP/1.1" + assert not http.parse_init_proxy(u) + assert not http.parse_init_proxy("invalid") assert not http.parse_init_proxy("GET invalid HTTP/1.1") assert not http.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") @@ -162,10 +166,14 @@ def test_prase_init_proxy(): def test_parse_init_http(): u = "GET /test HTTP/1.1" - m, u, httpversion= http.parse_init_http(u) + m, u, httpversion = http.parse_init_http(u) assert m == "GET" assert u == "/test" assert httpversion == (1, 1) + + u = "G\xfeET /test HTTP/1.1" + assert not http.parse_init_http(u) + assert not http.parse_init_http("invalid") assert not http.parse_init_http("GET invalid HTTP/1.1") assert not http.parse_init_http("GET /test foo/1.1") -- cgit v1.2.3 From a94d17970e739cdda4e6223b3af8136b05e6e192 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 5 Mar 2013 09:09:52 +1300 Subject: Sync version number with mitmproxy. --- netlib/version.py | 2 +- test/test_wsgi.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/netlib/version.py b/netlib/version.py index 30a4c0f9..d90c000c 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,4 @@ -IVERSION = (0, 2, 2) +IVERSION = (0, 9) VERSION = ".".join(str(i) for i in IVERSION) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION diff --git a/test/test_wsgi.py b/test/test_wsgi.py index 7763b9e5..91a8ff7a 100644 --- a/test/test_wsgi.py +++ b/test/test_wsgi.py @@ -1,5 +1,4 @@ import cStringIO, sys -import libpry from netlib import wsgi, odict -- cgit v1.2.3 From 241465c368c0117a8d86c17c44b39fed3116c6e0 Mon Sep 17 00:00:00 2001 From: Tim Becker Date: Fri, 19 Apr 2013 15:37:14 +0200 Subject: extensions aren't supported in v1, set to v3 (value=2) if using them. --- netlib/certutils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/netlib/certutils.py b/netlib/certutils.py index 859c93f1..8407dcc8 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -110,6 +110,7 @@ def dummy_cert(fp, ca, commonname, sans): cert.set_subject(req.get_subject()) cert.set_serial_number(int(time.time()*10000)) if ss: + cert.set_version(2) cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) cert.set_pubkey(req.get_pubkey()) cert.sign(key, "sha1") -- cgit v1.2.3 From 9c13224353eefbb6b1824ded20846036b07c558f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 5 May 2013 13:49:20 +1200 Subject: Fix exception hierarchy. --- netlib/tcp.py | 4 ++-- test/test_http.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 07b28cf9..b67ad0bb 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -35,8 +35,8 @@ OP_TLS_ROLLBACK_BUG = SSL.OP_TLS_ROLLBACK_BUG class NetLibError(Exception): pass -class NetLibDisconnect(Exception): pass -class NetLibTimeout(Exception): pass +class NetLibDisconnect(NetLibError): pass +class NetLibTimeout(NetLibError): pass class _FileLike: diff --git a/test/test_http.py b/test/test_http.py index 77cc2624..62d0c3dc 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -146,7 +146,7 @@ def test_parse_init_connect(): assert not http.parse_init_connect("CONNECT host.com:foo HTTP/1.0") -def test_prase_init_proxy(): +def test_parse_init_proxy(): u = "GET http://foo.com:8888/test HTTP/1.1" m, s, h, po, pa, httpversion = http.parse_init_proxy(u) assert m == "GET" -- cgit v1.2.3 From 7f0aa415e1ab95ed6b27a760cc9aa8ff4ee85080 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 13 May 2013 08:48:21 +1200 Subject: Add a request_client_cert argument to server SSL conversion. By default, we now do not request the client cert. We're supposed to be able to do this with no negative effects - if the client has no cert to present, we're notified and proceed as usual. Unfortunately, Android seems to have a bug (tested on 4.2.2) - when an Android client is asked to present a certificate it does not have, it hangs up, which is frankly bogus. Some time down the track we may be able to make the proper behaviour the default again, but until then we're conservative. --- netlib/certutils.py | 3 --- netlib/tcp.py | 20 ++++++++++++++++---- netlib/test.py | 3 ++- test/test_tcp.py | 6 ++++++ 4 files changed, 24 insertions(+), 8 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 8407dcc8..f18318f6 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -5,9 +5,6 @@ from pyasn1.error import PyAsn1Error import OpenSSL import tcp -CERT_SLEEP_TIME = 1 -CERT_EXPIRY = str(365 * 3) - def create_ca(): key = OpenSSL.crypto.PKey() diff --git a/netlib/tcp.py b/netlib/tcp.py index b67ad0bb..47953724 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -240,6 +240,7 @@ class TCPClient: class BaseHandler: """ The instantiator is expected to call the handle() and finish() methods. + """ rbufsize = -1 wbufsize = -1 @@ -252,9 +253,10 @@ class BaseHandler: self.server = server self.finished = False self.ssl_established = False + self.clientcert = None - def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None): + def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False): """ method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD handle_sni: SNI handler, should take a connection object. Server @@ -268,6 +270,15 @@ class BaseHandler: new_context.use_privatekey(key) new_context.use_certificate(cert) connection.set_context(new_context) + + The request_client_cert argument requires some explanation. We're + supposed to be able to do this with no negative effects - if the + client has no cert to present, we're notified and proceed as usual. + Unfortunately, Android seems to have a bug (tested on 4.2.2) - when + an Android client is asked to present a certificate it does not + have, it hangs up, which is frankly bogus. Some time down the track + we may be able to make the proper behaviour the default again, but + until then we're conservative. """ ctx = SSL.Context(method) if not options is None: @@ -277,9 +288,10 @@ class BaseHandler: ctx.set_tlsext_servername_callback(handle_sni) ctx.use_privatekey_file(key) ctx.use_certificate_file(cert) - def ver(*args): - self.clientcert = certutils.SSLCert(args[1]) - ctx.set_verify(SSL.VERIFY_PEER, ver) + if request_client_cert: + def ver(*args): + self.clientcert = certutils.SSLCert(args[1]) + ctx.set_verify(SSL.VERIFY_PEER, ver) self.connection = SSL.Connection(ctx, self.connection) self.ssl_established = True self.connection.set_accept_state() diff --git a/netlib/test.py b/netlib/test.py index 3378279b..deaef64e 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -62,7 +62,8 @@ class TServer(tcp.TCPServer): self.ssl["key"], method = method, options = options, - handle_sni = getattr(h, "handle_sni", None) + handle_sni = getattr(h, "handle_sni", None), + request_client_cert = self.ssl["request_client_cert"] ) h.handle() h.finish() diff --git a/test/test_tcp.py b/test/test_tcp.py index de14ab25..318d2abc 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -111,6 +111,7 @@ class TestServerSSL(test.ServerTestBase): ssl = dict( cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), + request_client_cert = False, v3_only = False ) def test_echo(self): @@ -131,6 +132,7 @@ class TestSSLv3Only(test.ServerTestBase): ssl = dict( cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), + request_client_cert = False, v3_only = True ) def test_failure(self): @@ -144,6 +146,7 @@ class TestSSLClientCert(test.ServerTestBase): ssl = dict( cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), + request_client_cert = True, v3_only = False ) def test_clientcert(self): @@ -167,6 +170,7 @@ class TestSNI(test.ServerTestBase): ssl = dict( cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), + request_client_cert = False, v3_only = False ) def test_echo(self): @@ -181,6 +185,7 @@ class TestSSLDisconnect(test.ServerTestBase): ssl = dict( cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), + request_client_cert = False, v3_only = False ) def test_echo(self): @@ -228,6 +233,7 @@ class TestSSLTimeOut(test.ServerTestBase): ssl = dict( cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), + request_client_cert = False, v3_only = False ) def test_timeout_client(self): -- cgit v1.2.3 From d698ee50a74ac33730ce19cf5eeb36935bd3643d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 15 May 2013 08:36:22 +1200 Subject: Add MANIFEST.in --- MANIFEST.in | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 MANIFEST.in diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..59226fdc --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,4 @@ +include README +recursive-include test * +recursive-include netlib * +recursive-exclude test *.swo *.swp *.pyc -- cgit v1.2.3 From f02c04d9d8d5460b45810c466b42dda2842d8002 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 14 Jun 2013 20:46:14 +0200 Subject: add test case for invalid characters in cert commonnames --- test/test_certutils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_certutils.py b/test/test_certutils.py index 334a6be4..f57f8f6d 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -34,6 +34,7 @@ class TestCertStore: assert not c.get_cert("foo.com", []) assert c.get_cert("foo.com", [], ca) assert c.get_cert("foo.com", [], ca) + assert c.get_cert("*.foo.com", [], ca) c.cleanup() def test_check_domain(self): -- cgit v1.2.3 From c9ab1c60b5d43f0b4d645c751350b16e9e562b55 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 16 Jun 2013 00:28:21 +0200 Subject: always read files in binary mode --- netlib/certutils.py | 12 ++++++------ test/test_certutils.py | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index f18318f6..4c06eb8f 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -48,23 +48,23 @@ def dummy_ca(path): key, ca = create_ca() # Dump the CA plus private key - f = open(path, "w") + f = open(path, "wb") f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) f.close() # Dump the certificate in PEM format - f = open(os.path.join(dirname, basename + "-cert.pem"), "w") + f = open(os.path.join(dirname, basename + "-cert.pem"), "wb") f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) f.close() # Create a .cer file with the same contents for Android - f = open(os.path.join(dirname, basename + "-cert.cer"), "w") + f = open(os.path.join(dirname, basename + "-cert.cer"), "wb") f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) f.close() # Dump the certificate in PKCS12 format for Windows devices - f = open(os.path.join(dirname, basename + "-cert.p12"), "w") + f = open(os.path.join(dirname, basename + "-cert.p12"), "wb") p12 = OpenSSL.crypto.PKCS12() p12.set_certificate(ca) p12.set_privatekey(key) @@ -88,7 +88,7 @@ def dummy_cert(fp, ca, commonname, sans): ss.append("DNS: %s"%i) ss = ", ".join(ss) - raw = file(ca, "r").read() + raw = file(ca, "rb").read() ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) @@ -165,7 +165,7 @@ class CertStore: if os.path.exists(certpath): return certpath elif cacert: - f = open(certpath, "w") + f = open(certpath, "wb") dummy_cert(f, cacert, commonname, sans) return certpath diff --git a/test/test_certutils.py b/test/test_certutils.py index 334a6be4..89b6ff50 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -53,22 +53,22 @@ class TestDummyCert: assert certutils.dummy_ca(cacert) p = os.path.join(d, "foo") certutils.dummy_cert( - file(p, "w"), + file(p, "wb"), cacert, "foo.com", ["one.com", "two.com", "*.three.com"] ) - assert file(p).read() + assert file(p,"rb").read() class TestSSLCert: def test_simple(self): - c = certutils.SSLCert.from_pem(file(tutils.test_data.path("data/text_cert"), "r").read()) + c = certutils.SSLCert.from_pem(file(tutils.test_data.path("data/text_cert"), "rb").read()) assert c.cn == "google.com" assert len(c.altnames) == 436 - c = certutils.SSLCert.from_pem(file(tutils.test_data.path("data/text_cert_2"), "r").read()) + c = certutils.SSLCert.from_pem(file(tutils.test_data.path("data/text_cert_2"), "rb").read()) assert c.cn == "www.inode.co.nz" assert len(c.altnames) == 2 assert c.digest("sha1") @@ -82,11 +82,11 @@ class TestSSLCert: c.has_expired def test_err_broken_sans(self): - c = certutils.SSLCert.from_pem(file(tutils.test_data.path("data/text_cert_weird1"), "r").read()) + c = certutils.SSLCert.from_pem(file(tutils.test_data.path("data/text_cert_weird1"), "rb").read()) # This breaks unless we ignore a decoding error. c.altnames def test_der(self): - d = file(tutils.test_data.path("data/dercert")).read() + d = file(tutils.test_data.path("data/dercert"),"rb").read() s = certutils.SSLCert.from_der(d) assert s.cn -- cgit v1.2.3 From 73f8a1e2e0006c2a37ae6264afe70a8207ffbb54 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 16 Jun 2013 13:38:39 +1200 Subject: Bump version. --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/version.py b/netlib/version.py index d90c000c..63a9d862 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,4 @@ -IVERSION = (0, 9) +IVERSION = (0, 9, 1) VERSION = ".".join(str(i) for i in IVERSION) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION -- cgit v1.2.3 From 68e2e782b0afdc03844b107c28627391c51dd036 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 17 Jun 2013 17:03:17 +0200 Subject: attempt to fix 'half-duplex' TCP close sequence --- netlib/tcp.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 47953724..e37cb707 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -230,11 +230,15 @@ class TCPClient: if self.ssl_established: self.connection.shutdown() else: - self.connection.shutdown(socket.SHUT_RDWR) - self.connection.close() + self.connection.shutdown(socket.SHUT_WR) + #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. + #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html + while self.connection.recv(4096): + pass except (socket.error, SSL.Error): # Socket probably already closed pass + self.connection.close() class BaseHandler: @@ -328,10 +332,15 @@ class BaseHandler: if self.ssl_established: self.connection.shutdown() else: - self.connection.shutdown(socket.SHUT_RDWR) + self.connection.shutdown(socket.SHUT_WR) + #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. + #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html + while self.connection.recv(4096): + pass except (socket.error, SSL.Error): # Socket probably already closed pass + self.connection.close() -- cgit v1.2.3 From 02376b6a75fdb397a865697723f7282dbf70deca Mon Sep 17 00:00:00 2001 From: Andrey Plotnikov Date: Sun, 7 Jul 2013 13:33:56 +0800 Subject: Add socket binding support for TCPClient --- netlib/tcp.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 47953724..b5e9e2c4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -173,11 +173,12 @@ class Reader(_FileLike): class TCPClient: rbufsize = -1 wbufsize = -1 - def __init__(self, host, port): + def __init__(self, host, port, source_address=None): self.host, self.port = host, port self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.ssl_established = False + self.source_address = source_address def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None): """ @@ -209,6 +210,8 @@ class TCPClient: try: addr = socket.gethostbyname(self.host) connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if self.source_address: + connection.bind(self.source_address) connection.connect((addr, self.port)) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) -- cgit v1.2.3 From f5fdfd8a9f17e0fe213a9cf54acae84e4bc31462 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 30 Jul 2013 09:42:13 +1200 Subject: Clarify the interface for flush and close methods. --- netlib/tcp.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 69ad2da5..123c6515 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -93,7 +93,7 @@ class Writer(_FileLike): if hasattr(self.o, "flush"): try: self.o.flush() - except socket.error, v: + except (socket.error, IOError), v: raise NetLibDisconnect(str(v)) def write(self, v): @@ -215,7 +215,7 @@ class TCPClient: connection.connect((addr, self.port)) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) - except socket.error, err: + except (socket.error, IOError), err: raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) self.connection = connection @@ -238,16 +238,16 @@ class TCPClient: #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html while self.connection.recv(4096): pass - except (socket.error, SSL.Error): + self.connection.close() + except (socket.error, SSL.Error, IOError): # Socket probably already closed pass - self.connection.close() class BaseHandler: """ The instantiator is expected to call the handle() and finish() methods. - + """ rbufsize = -1 wbufsize = -1 -- cgit v1.2.3 From b9f06b473cd464e82bc53a973c5e190f93377bce Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 10 Aug 2013 23:07:09 +1200 Subject: Better handling of cert errors. --- netlib/tcp.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/netlib/tcp.py b/netlib/tcp.py index 123c6515..df1f8fea 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -37,6 +37,7 @@ OP_TLS_ROLLBACK_BUG = SSL.OP_TLS_ROLLBACK_BUG class NetLibError(Exception): pass class NetLibDisconnect(NetLibError): pass class NetLibTimeout(NetLibError): pass +class NetLibSSLError(NetLibError): pass class _FileLike: @@ -129,6 +130,8 @@ class Reader(_FileLike): data = self.o.read(rlen) except SSL.ZeroReturnError: break + except SSL.Error, v: + raise NetLibSSLError(v.message) except SSL.WantReadError: if (time.time() - start) < self.o.gettimeout(): time.sleep(0.1) -- cgit v1.2.3 From 2da57ecff0e9572e45663dbad1c5f520e57c531f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 11 Aug 2013 11:47:07 +1200 Subject: Correct order of precedence for SSL errors. --- netlib/tcp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index df1f8fea..f4a8acf9 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -130,8 +130,6 @@ class Reader(_FileLike): data = self.o.read(rlen) except SSL.ZeroReturnError: break - except SSL.Error, v: - raise NetLibSSLError(v.message) except SSL.WantReadError: if (time.time() - start) < self.o.gettimeout(): time.sleep(0.1) @@ -144,6 +142,8 @@ class Reader(_FileLike): raise NetLibDisconnect except SSL.SysCallError: raise NetLibDisconnect + except SSL.Error, v: + raise NetLibSSLError(v.message) self.first_byte_timestamp = self.first_byte_timestamp or time.time() if not data: break -- cgit v1.2.3 From 62edceee093dd54956ed5b623dfb4cb8c1309a16 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 12 Aug 2013 16:03:29 +1200 Subject: Revamp dummy cert generation. We no longer use on-disk storage - we just keep the certs in memory. --- netlib/certutils.py | 45 +++++++++++++-------------------------------- netlib/tcp.py | 3 ++- netlib/test.py | 7 +++++-- test/test_certutils.py | 14 +++----------- 4 files changed, 23 insertions(+), 46 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 4c06eb8f..7dcb5450 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -73,7 +73,7 @@ def dummy_ca(path): return True -def dummy_cert(fp, ca, commonname, sans): +def dummy_cert(ca, commonname, sans): """ Generates and writes a certificate to fp. @@ -111,27 +111,15 @@ def dummy_cert(fp, ca, commonname, sans): cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) cert.set_pubkey(req.get_pubkey()) cert.sign(key, "sha1") - - fp.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)) - fp.close() + return SSLCert(cert) class CertStore: """ - Implements an on-disk certificate store. + Implements an in-memory certificate store. """ - def __init__(self, certdir=None): - """ - certdir: The certificate store directory. If None, a temporary - directory will be created, and destroyed when the .cleanup() method - is called. - """ - if certdir: - self.remove = False - self.certdir = certdir - else: - self.remove = True - self.certdir = tempfile.mkdtemp(prefix="certstore") + def __init__(self): + self.certs = {} def check_domain(self, commonname): try: @@ -145,33 +133,26 @@ class CertStore: return False return True - def get_cert(self, commonname, sans, cacert=False): + def get_cert(self, commonname, sans, cacert): """ - Returns the path to a certificate. + Returns an SSLCert object. commonname: Common name for the generated certificate. Must be a valid, plain-ASCII, IDNA-encoded domain name. sans: A list of Subject Alternate Names. - cacert: An optional path to a CA certificate. If specified, the - cert is created if it does not exist, else return None. + cacert: The path to a CA certificate. Return None if the certificate could not be found or generated. """ if not self.check_domain(commonname): return None - certpath = os.path.join(self.certdir, commonname + ".pem") - if os.path.exists(certpath): - return certpath - elif cacert: - f = open(certpath, "wb") - dummy_cert(f, cacert, commonname, sans) - return certpath - - def cleanup(self): - if self.remove: - shutil.rmtree(self.certdir) + if commonname in self.certs: + return self.certs[commonname] + c = dummy_cert(cacert, commonname, sans) + self.certs[commonname] = c + return c class _GeneralName(univ.Choice): diff --git a/netlib/tcp.py b/netlib/tcp.py index f4a8acf9..31e9a398 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -268,6 +268,7 @@ class BaseHandler: def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False): """ + cert: A certutils.SSLCert object. method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD handle_sni: SNI handler, should take a connection object. Server name can be retrieved like this: @@ -297,7 +298,7 @@ class BaseHandler: # SNI callback happens during do_handshake() ctx.set_tlsext_servername_callback(handle_sni) ctx.use_privatekey_file(key) - ctx.use_certificate_file(cert) + ctx.use_certificate(cert.x509) if request_client_cert: def ver(*args): self.clientcert = certutils.SSLCert(args[1]) diff --git a/netlib/test.py b/netlib/test.py index deaef64e..661395c5 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -1,5 +1,5 @@ import threading, Queue, cStringIO -import tcp +import tcp, certutils class ServerThread(threading.Thread): def __init__(self, server): @@ -51,6 +51,9 @@ class TServer(tcp.TCPServer): h = self.handler_klass(request, client_address, self) self.last_handler = h if self.ssl: + cert = certutils.SSLCert.from_pem( + file(self.ssl["cert"], "r").read() + ) if self.ssl["v3_only"]: method = tcp.SSLv3_METHOD options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1 @@ -58,7 +61,7 @@ class TServer(tcp.TCPServer): method = tcp.SSLv23_METHOD options = None h.convert_to_ssl( - self.ssl["cert"], + cert, self.ssl["key"], method = method, options = options, diff --git a/test/test_certutils.py b/test/test_certutils.py index b335e946..0b4baf75 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -21,21 +21,16 @@ class TestCertStore: with tutils.tmpdir() as d: ca = os.path.join(d, "ca") assert certutils.dummy_ca(ca) - c = certutils.CertStore(d) - c.cleanup() - assert os.path.exists(d) + c = certutils.CertStore() def test_create_tmp(self): with tutils.tmpdir() as d: ca = os.path.join(d, "ca") assert certutils.dummy_ca(ca) c = certutils.CertStore() - assert not c.get_cert("../foo.com", []) - assert not c.get_cert("foo.com", []) assert c.get_cert("foo.com", [], ca) assert c.get_cert("foo.com", [], ca) assert c.get_cert("*.foo.com", [], ca) - c.cleanup() def test_check_domain(self): c = certutils.CertStore() @@ -52,15 +47,12 @@ class TestDummyCert: with tutils.tmpdir() as d: cacert = os.path.join(d, "cacert") assert certutils.dummy_ca(cacert) - p = os.path.join(d, "foo") - certutils.dummy_cert( - file(p, "wb"), + r = certutils.dummy_cert( cacert, "foo.com", ["one.com", "two.com", "*.three.com"] ) - assert file(p,"rb").read() - + assert r.cn == "foo.com" class TestSSLCert: -- cgit v1.2.3 From 0fed8dc8eb2440a35b5ce95ba7e7360441ff677c Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 17 Aug 2013 14:44:57 +0200 Subject: update gitignore to not include PyCharms --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index f53cd2e2..e66d51fe 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ MANIFEST *.swp *.swo .coverage +.idea \ No newline at end of file -- cgit v1.2.3 From c44f354fd0f9b4f1432913dd70cf1579910dfa4b Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 17 Aug 2013 16:15:37 +0200 Subject: fix windows bugs --- netlib/tcp.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/netlib/tcp.py b/netlib/tcp.py index 31e9a398..2de647ae 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -235,6 +235,7 @@ class TCPClient: try: if self.ssl_established: self.connection.shutdown() + self.connection.sock_shutdown(socket.SHUT_WR) else: self.connection.shutdown(socket.SHUT_WR) #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. @@ -302,6 +303,7 @@ class BaseHandler: if request_client_cert: def ver(*args): self.clientcert = certutils.SSLCert(args[1]) + return True ctx.set_verify(SSL.VERIFY_PEER, ver) self.connection = SSL.Connection(ctx, self.connection) self.ssl_established = True @@ -338,6 +340,7 @@ class BaseHandler: try: if self.ssl_established: self.connection.shutdown() + self.connection.sock_shutdown(socket.SHUT_WR) else: self.connection.shutdown(socket.SHUT_WR) #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. -- cgit v1.2.3 From 28a0030c1ecacb8ac5c6e6453b6a22bdf94d9f7e Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 19 Aug 2013 19:41:20 +0200 Subject: compatibility fixes for windows --- netlib/tcp.py | 3 ++- netlib/test.py | 2 +- setup.py | 2 +- test/test_http_auth.py | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 2de647ae..f4a713f9 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -303,7 +303,8 @@ class BaseHandler: if request_client_cert: def ver(*args): self.clientcert = certutils.SSLCert(args[1]) - return True + # err 20 = X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY + #return True ctx.set_verify(SSL.VERIFY_PEER, ver) self.connection = SSL.Connection(ctx, self.connection) self.ssl_established = True diff --git a/netlib/test.py b/netlib/test.py index 661395c5..87802bd5 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -52,7 +52,7 @@ class TServer(tcp.TCPServer): self.last_handler = h if self.ssl: cert = certutils.SSLCert.from_pem( - file(self.ssl["cert"], "r").read() + file(self.ssl["cert"], "rb").read() ) if self.ssl["v3_only"]: method = tcp.SSLv3_METHOD diff --git a/setup.py b/setup.py index e0dff0ff..1b2a14f9 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ def findPackages(path, dataExclude=[]): return packages, package_data -long_description = file("README").read() +long_description = file("README","rb").read() packages, package_data = findPackages("netlib") setup( name = "netlib", diff --git a/test/test_http_auth.py b/test/test_http_auth.py index cae69f5e..83de0fa1 100644 --- a/test/test_http_auth.py +++ b/test/test_http_auth.py @@ -17,7 +17,7 @@ class TestPassManHtpasswd: tutils.raises("invalid htpasswd", http_auth.PassManHtpasswd, s) def test_simple(self): - f = open(tutils.test_data.path("data/htpasswd")) + f = open(tutils.test_data.path("data/htpasswd"),"rb") pm = http_auth.PassManHtpasswd(f) vals = ("basic", "test", "test") -- cgit v1.2.3 From d5b3e397e142ae60275fb89ea765423903e99bb6 Mon Sep 17 00:00:00 2001 From: Israel Nir Date: Wed, 21 Aug 2013 13:42:30 +0300 Subject: adding cipher list selection option to BaseHandler --- netlib/tcp.py | 4 +++- netlib/test.py | 3 ++- test/test_tcp.py | 25 +++++++++++++++++++++++++ 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 31e9a398..f1496a32 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -266,7 +266,7 @@ class BaseHandler: self.clientcert = None - def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False): + def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False, cipher_list=None): """ cert: A certutils.SSLCert object. method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD @@ -294,6 +294,8 @@ class BaseHandler: ctx = SSL.Context(method) if not options is None: ctx.set_options(options) + if cipher_list: + ctx.set_cipher_list(cipher_list) if handle_sni: # SNI callback happens during do_handshake() ctx.set_tlsext_servername_callback(handle_sni) diff --git a/netlib/test.py b/netlib/test.py index 661395c5..139d95bb 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -66,7 +66,8 @@ class TServer(tcp.TCPServer): method = method, options = options, handle_sni = getattr(h, "handle_sni", None), - request_client_cert = self.ssl["request_client_cert"] + request_client_cert = self.ssl["request_client_cert"], + cipher_list = self.ssl.get("cipher_list", None) ) h.handle() h.finish() diff --git a/test/test_tcp.py b/test/test_tcp.py index 318d2abc..8fa151af 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -34,6 +34,15 @@ class CertHandler(tcp.BaseHandler): self.wfile.flush() +class ClientCipherListHandler(tcp.BaseHandler): + sni = None + + def handle(self): + print self.connection.get_cipher_list() + self.wfile.write("%s"%self.connection.get_cipher_list()) + self.wfile.flush() + + class DisconnectHandler(tcp.BaseHandler): def handle(self): self.close() @@ -180,6 +189,22 @@ class TestSNI(test.ServerTestBase): assert c.rfile.readline() == "foo.com" +class TestClientCipherList(test.ServerTestBase): + handler = ClientCipherListHandler + ssl = dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + request_client_cert = False, + v3_only = False, + cipher_list = 'RC4-SHA' + ) + def test_echo(self): + c = tcp.TCPClient("127.0.0.1", self.port) + c.connect() + c.convert_to_ssl(sni="foo.com") + assert c.rfile.readline() == "['RC4-SHA']" + + class TestSSLDisconnect(test.ServerTestBase): handler = DisconnectHandler ssl = dict( -- cgit v1.2.3 From 7428f954744725381ced7c273609ca14d767dfff Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 25 Aug 2013 10:22:09 +1200 Subject: Handle interrupted system call errors. --- netlib/tcp.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 31e9a398..bee1f75b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -376,7 +376,13 @@ class TCPServer: self.__is_shut_down.clear() try: while not self.__shutdown_request: - r, w, e = select.select([self.socket], [], [], poll_interval) + try: + r, w, e = select.select([self.socket], [], [], poll_interval) + except select.error, ex: + if ex[0] == 4: + continue + else: + raise if self.socket in r: request, client_address = self.socket.accept() t = threading.Thread( -- cgit v1.2.3 From 8a261b2c01fe49de896bf9808af8fbb66b300cfc Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 25 Aug 2013 10:30:48 +1200 Subject: Bump version. --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/version.py b/netlib/version.py index 63a9d862..32013c35 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,4 @@ -IVERSION = (0, 9, 1) +IVERSION = (0, 9, 2) VERSION = ".".join(str(i) for i in IVERSION) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION -- cgit v1.2.3 From 98f765f693fc4fa7245c3179da1d791661ed502a Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 24 Sep 2013 21:18:41 +0200 Subject: Don't create a certificate request when creating a dummy cert --- netlib/certutils.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 7dcb5450..60e41427 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -92,24 +92,16 @@ def dummy_cert(ca, commonname, sans): ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) - req = OpenSSL.crypto.X509Req() - subj = req.get_subject() - subj.CN = commonname - req.set_pubkey(ca.get_pubkey()) - req.sign(key, "sha1") - if ss: - req.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) - cert = OpenSSL.crypto.X509() cert.gmtime_adj_notBefore(-3600) cert.gmtime_adj_notAfter(60 * 60 * 24 * 30) cert.set_issuer(ca.get_subject()) - cert.set_subject(req.get_subject()) + cert.get_subject().CN = commonname cert.set_serial_number(int(time.time()*10000)) if ss: cert.set_version(2) cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) - cert.set_pubkey(req.get_pubkey()) + cert.set_pubkey(ca.get_pubkey()) cert.sign(key, "sha1") return SSLCert(cert) -- cgit v1.2.3 From 53b7c5abdd7c6dbb8ecaa1aa1000296f86eb45fa Mon Sep 17 00:00:00 2001 From: Sean Coates Date: Mon, 7 Oct 2013 16:48:30 -0400 Subject: allow specification of o, cn, expiry --- netlib/certutils.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 60e41427..a21f0188 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -5,17 +5,20 @@ from pyasn1.error import PyAsn1Error import OpenSSL import tcp +default_exp = 62208000 # =24 * 60 * 60 * 720 +default_o = "mitmproxy" +default_cn = "mitmproxy" -def create_ca(): +def create_ca(o=default_o, cn=default_cn, exp=default_exp): key = OpenSSL.crypto.PKey() key.generate_key(OpenSSL.crypto.TYPE_RSA, 1024) ca = OpenSSL.crypto.X509() ca.set_serial_number(int(time.time()*10000)) ca.set_version(2) - ca.get_subject().CN = "mitmproxy" - ca.get_subject().O = "mitmproxy" + ca.get_subject().CN = cn + ca.get_subject().O = o ca.gmtime_adj_notBefore(0) - ca.gmtime_adj_notAfter(24 * 60 * 60 * 720) + ca.gmtime_adj_notAfter(exp) ca.set_issuer(ca.get_subject()) ca.set_pubkey(key) ca.add_extensions([ @@ -35,7 +38,7 @@ def create_ca(): return key, ca -def dummy_ca(path): +def dummy_ca(path, o=default_o, cn=default_cn, exp=default_exp): dirname = os.path.dirname(path) if not os.path.exists(dirname): os.makedirs(dirname) @@ -45,7 +48,7 @@ def dummy_ca(path): else: basename = os.path.basename(path) - key, ca = create_ca() + key, ca = create_ca(o=o, cn=cn, exp=exp) # Dump the CA plus private key f = open(path, "wb") -- cgit v1.2.3 From 642b3f002ed7020ee359d23d46802b0bb02c1018 Mon Sep 17 00:00:00 2001 From: Sean Coates Date: Mon, 7 Oct 2013 16:55:35 -0400 Subject: remove tempfile and shutil imports because they're not actually used --- netlib/certutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 60e41427..dab7e318 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -1,4 +1,4 @@ -import os, ssl, time, datetime, tempfile, shutil +import os, ssl, time, datetime from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error -- cgit v1.2.3 From 5e4ccbd7edc6eebf9eee25fd4d6ca64994ed6522 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 19 Nov 2013 04:11:24 +0100 Subject: attempt to fix #24 --- netlib/http.py | 17 ++++------------- test/test_http.py | 22 +++++----------------- 2 files changed, 9 insertions(+), 30 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index f1a2bfb5..7060b688 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -283,32 +283,23 @@ def parse_init_http(line): return method, url, httpversion -def request_connection_close(httpversion, headers): +def connection_close(httpversion, headers): """ - Checks the request to see if the client connection should be closed. + Checks the message to see if the client connection should be closed according to RFC 2616 Section 8.1 """ + # At first, check if we have an explicit Connection header. if "connection" in headers: toks = get_header_tokens(headers, "connection") if "close" in toks: return True elif "keep-alive" in toks: return False - # HTTP 1.1 connections are assumed to be persistent + # If we don't have a Connection header, HTTP 1.1 connections are assumed to be persistent if httpversion == (1, 1): return False return True -def response_connection_close(httpversion, headers): - """ - Checks the response to see if the client connection should be closed. - """ - if request_connection_close(httpversion, headers): - return True - elif (not has_chunked_encoding(headers)) and "content-length" in headers: - return False - return True - def read_http_body_request(rfile, wfile, headers, httpversion, limit): """ diff --git a/test/test_http.py b/test/test_http.py index 62d0c3dc..4d89bf24 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -38,28 +38,16 @@ def test_read_chunked(): tutils.raises("too large", http.read_chunked, 500, s, 2) -def test_request_connection_close(): +def test_connection_close(): h = odict.ODictCaseless() - assert http.request_connection_close((1, 0), h) - assert not http.request_connection_close((1, 1), h) + assert http.connection_close((1, 0), h) + assert not http.connection_close((1, 1), h) h["connection"] = ["keep-alive"] - assert not http.request_connection_close((1, 1), h) + assert not http.connection_close((1, 1), h) h["connection"] = ["close"] - assert http.request_connection_close((1, 1), h) - - -def test_response_connection_close(): - h = odict.ODictCaseless() - assert http.response_connection_close((1, 1), h) - - h["content-length"] = [10] - assert not http.response_connection_close((1, 1), h) - - h["connection"] = ["close"] - assert http.response_connection_close((1, 1), h) - + assert http.connection_close((1, 1), h) def test_read_http_body_response(): h = odict.ODictCaseless() -- cgit v1.2.3 From e402e3b862312ca4f7bd7dd633db3654143c3380 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 21 Nov 2013 01:07:56 +0100 Subject: add custom argparse actions to seamlessly integrate ProxyAuth classes --- netlib/http_auth.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 4adae179..6c91c7c5 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -1,6 +1,7 @@ import binascii import contrib.md5crypt as md5crypt import http +from argparse import Action, ArgumentTypeError class NullProxyAuth(): @@ -111,3 +112,46 @@ class PassManSingleUser: def test(self, username, password_token): return self.username==username and self.password==password_token + + +class AuthAction(Action): + """ + Helper class to allow seamless integration int argparse. Example usage: + parser.add_argument( + "--nonanonymous", + action=NonanonymousAuthAction, nargs=0, + help="Allow access to any user long as a credentials are specified." + ) + """ + def __call__(self, parser, namespace, values, option_string=None): + passman = self.getPasswordManager(values) + if passman: + authenticator = BasicProxyAuth(passman, "mitmproxy") + else: + authenticator = NullProxyAuth(None) + setattr(namespace, "authenticator", authenticator) + + def getPasswordManager(self, s): + """ + returns the password manager + """ + raise NotImplementedError() + + +class SingleuserAuthAction(AuthAction): + def getPasswordManager(self, s): + if len(s.split(':')) != 2: + raise ArgumentTypeError("Invalid single-user specification. Please use the format username:password") + username, password = s.split(':') + return PassManSingleUser(username, password) + + +class NonanonymousAuthAction(AuthAction): + def getPasswordManager(self, s): + return PassManNonAnon() + + +class HtpasswdAuthAction(AuthAction): + def getPasswordManager(self, s): + with open(s, "r") as f: + return PassManHtpasswd(f) \ No newline at end of file -- cgit v1.2.3 From 5aad09ab816b2343ca686d45e6c5d2b8ba07b10b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 8 Dec 2013 10:15:19 +1300 Subject: Fix client certificate request feature. --- netlib/tcp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index f4a713f9..23458742 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -303,8 +303,8 @@ class BaseHandler: if request_client_cert: def ver(*args): self.clientcert = certutils.SSLCert(args[1]) - # err 20 = X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY - #return True + # Return true to prevent cert verification error + return True ctx.set_verify(SSL.VERIFY_PEER, ver) self.connection = SSL.Connection(ctx, self.connection) self.ssl_established = True -- cgit v1.2.3 From 75745cb0af9a9b13d075355524947e70209d484b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 8 Dec 2013 13:04:27 +1300 Subject: Zap stray print in tests. --- test/test_tcp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_tcp.py b/test/test_tcp.py index 8fa151af..f45acb00 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -38,7 +38,6 @@ class ClientCipherListHandler(tcp.BaseHandler): sni = None def handle(self): - print self.connection.get_cipher_list() self.wfile.write("%s"%self.connection.get_cipher_list()) self.wfile.flush() -- cgit v1.2.3 From d05c20d8fab3345e19c06ac0de00a2c8f30c44ef Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 8 Dec 2013 13:15:08 +1300 Subject: Domain checks for persistent cert store is now irrelevant. We no longer store these on disk, so we don't care about path components. --- netlib/certutils.py | 14 -------------- netlib/tcp.py | 5 +++-- test/test_certutils.py | 9 --------- 3 files changed, 3 insertions(+), 25 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 22b5c35c..d9b8ce57 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -116,18 +116,6 @@ class CertStore: def __init__(self): self.certs = {} - def check_domain(self, commonname): - try: - commonname.decode("idna") - commonname.decode("ascii") - except: - return False - if ".." in commonname: - return False - if "/" in commonname: - return False - return True - def get_cert(self, commonname, sans, cacert): """ Returns an SSLCert object. @@ -141,8 +129,6 @@ class CertStore: Return None if the certificate could not be found or generated. """ - if not self.check_domain(commonname): - return None if commonname in self.certs: return self.certs[commonname] c = dummy_cert(cacert, commonname, sans) diff --git a/netlib/tcp.py b/netlib/tcp.py index 8fe04d2e..b3be43d6 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -346,8 +346,9 @@ class BaseHandler: self.connection.sock_shutdown(socket.SHUT_WR) else: self.connection.shutdown(socket.SHUT_WR) - #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. - #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html + # Section 4.2.2.13 of RFC 1122 tells us that a close() with any + # pending readable data could lead to an immediate RST being sent. + # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html while self.connection.recv(4096): pass except (socket.error, SSL.Error): diff --git a/test/test_certutils.py b/test/test_certutils.py index 0b4baf75..7a00caca 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -32,15 +32,6 @@ class TestCertStore: assert c.get_cert("foo.com", [], ca) assert c.get_cert("*.foo.com", [], ca) - def test_check_domain(self): - c = certutils.CertStore() - assert c.check_domain("foo") - assert c.check_domain("\x01foo") - assert not c.check_domain("\xfefoo") - assert not c.check_domain("xn--\0") - assert not c.check_domain("foo..foo") - assert not c.check_domain("foo/foo") - class TestDummyCert: def test_with_ca(self): -- cgit v1.2.3 From 7213f86d49960a625643fb6179e6a3731b16d462 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 8 Dec 2013 13:35:42 +1300 Subject: Unit test auth actions. --- netlib/http_auth.py | 17 +++++++---------- test/test_http_auth.py | 23 +++++++++++++++++++++++ 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 6c91c7c5..71f120d6 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -125,23 +125,19 @@ class AuthAction(Action): """ def __call__(self, parser, namespace, values, option_string=None): passman = self.getPasswordManager(values) - if passman: - authenticator = BasicProxyAuth(passman, "mitmproxy") - else: - authenticator = NullProxyAuth(None) + authenticator = BasicProxyAuth(passman, "mitmproxy") setattr(namespace, "authenticator", authenticator) - def getPasswordManager(self, s): - """ - returns the password manager - """ + def getPasswordManager(self, s): # pragma: nocover raise NotImplementedError() class SingleuserAuthAction(AuthAction): def getPasswordManager(self, s): if len(s.split(':')) != 2: - raise ArgumentTypeError("Invalid single-user specification. Please use the format username:password") + raise ArgumentTypeError( + "Invalid single-user specification. Please use the format username:password" + ) username, password = s.split(':') return PassManSingleUser(username, password) @@ -154,4 +150,5 @@ class NonanonymousAuthAction(AuthAction): class HtpasswdAuthAction(AuthAction): def getPasswordManager(self, s): with open(s, "r") as f: - return PassManHtpasswd(f) \ No newline at end of file + return PassManHtpasswd(f) + diff --git a/test/test_http_auth.py b/test/test_http_auth.py index 83de0fa1..8238d4ca 100644 --- a/test/test_http_auth.py +++ b/test/test_http_auth.py @@ -1,5 +1,6 @@ import binascii, cStringIO from netlib import odict, http_auth, http +import mock import tutils class TestPassManNonAnon: @@ -79,3 +80,25 @@ class TestBasicProxyAuth: hdrs[ba.AUTH_HEADER] = [http.assemble_http_basic_auth(*vals)] assert not ba.authenticate(hdrs) + +class Bunch: pass + +class TestAuthAction: + def test_nonanonymous(self): + m = Bunch() + aa = http_auth.NonanonymousAuthAction(None, None) + aa(None, m, None, None) + assert m.authenticator + + def test_singleuser(self): + m = Bunch() + aa = http_auth.SingleuserAuthAction(None, None) + aa(None, m, "foo:bar", None) + assert m.authenticator + tutils.raises("invalid", aa, None, m, "foo", None) + + def test_httppasswd(self): + m = Bunch() + aa = http_auth.HtpasswdAuthAction(None, None) + aa(None, m, tutils.test_data.path("data/htpasswd"), None) + assert m.authenticator -- cgit v1.2.3 From 390f2a46c920ee332d758d6c46999b5147e0b30b Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 8 Dec 2013 01:37:45 +0100 Subject: make AuthAction generic --- netlib/http_auth.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 6c91c7c5..948d503a 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -33,6 +33,7 @@ class NullProxyAuth(): class BasicProxyAuth(NullProxyAuth): CHALLENGE_HEADER = 'Proxy-Authenticate' AUTH_HEADER = 'Proxy-Authorization' + def __init__(self, password_manager, realm): NullProxyAuth.__init__(self, password_manager) self.realm = realm @@ -125,11 +126,10 @@ class AuthAction(Action): """ def __call__(self, parser, namespace, values, option_string=None): passman = self.getPasswordManager(values) - if passman: - authenticator = BasicProxyAuth(passman, "mitmproxy") - else: - authenticator = NullProxyAuth(None) - setattr(namespace, "authenticator", authenticator) + if not passman: + raise ArgumentTypeError("Error creating password manager for proxy authentication.") + authenticator = BasicProxyAuth(passman, "mitmproxy") + setattr(namespace, self.dest, authenticator) def getPasswordManager(self, s): """ -- cgit v1.2.3 From bae2b6ea36f8438103f8e9dcc20eb6b9183fb527 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 8 Dec 2013 02:24:00 +0100 Subject: fix AuthAction tests failures from last merge --- test/test_http_auth.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/test_http_auth.py b/test/test_http_auth.py index 8238d4ca..dd0273fe 100644 --- a/test/test_http_auth.py +++ b/test/test_http_auth.py @@ -83,22 +83,23 @@ class TestBasicProxyAuth: class Bunch: pass + class TestAuthAction: def test_nonanonymous(self): m = Bunch() - aa = http_auth.NonanonymousAuthAction(None, None) + aa = http_auth.NonanonymousAuthAction(None, "authenticator") aa(None, m, None, None) assert m.authenticator def test_singleuser(self): m = Bunch() - aa = http_auth.SingleuserAuthAction(None, None) + aa = http_auth.SingleuserAuthAction(None, "authenticator") aa(None, m, "foo:bar", None) assert m.authenticator tutils.raises("invalid", aa, None, m, "foo", None) def test_httppasswd(self): m = Bunch() - aa = http_auth.HtpasswdAuthAction(None, None) + aa = http_auth.HtpasswdAuthAction(None, "authenticator") aa(None, m, tutils.test_data.path("data/htpasswd"), None) assert m.authenticator -- cgit v1.2.3 From 4840c6b3bf5c9e992895f9c3117ceddca4c0cc33 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 8 Dec 2013 15:26:30 +1300 Subject: Fix race condition in test suite. --- netlib/tcp.py | 1 - test/test_tcp.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index b3be43d6..5a07c013 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -354,7 +354,6 @@ class BaseHandler: except (socket.error, SSL.Error): # Socket probably already closed pass - self.connection.close() diff --git a/test/test_tcp.py b/test/test_tcp.py index f45acb00..220ece15 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -80,7 +80,6 @@ class FinishFailHandler(tcp.BaseHandler): v = self.rfile.readline() self.wfile.write(v) self.wfile.flush() - o = mock.MagicMock() self.wfile.close() self.rfile.close() self.close = mock.MagicMock(side_effect=socket.error) @@ -99,8 +98,6 @@ class TestFinishFail(test.ServerTestBase): c.wfile.write("foo\n") c.wfile.flush() c.rfile.read(4) - h = self.last_handler - h.finish() class TestDisconnect(test.ServerTestBase): -- cgit v1.2.3 From d66fd5ba1b11ad57b7825b7feb67392f45e88c24 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 10 Dec 2013 22:20:12 +1300 Subject: Bump version --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/version.py b/netlib/version.py index 32013c35..9b2e037e 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,4 @@ -IVERSION = (0, 9, 2) +IVERSION = (0, 10) VERSION = ".".join(str(i) for i in IVERSION) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION -- cgit v1.2.3 From a7ac97eb823f599ca04f588f6cbe4da28e00a194 Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Thu, 12 Dec 2013 07:00:58 +0100 Subject: support ipv6 --- netlib/tcp.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 5a07c013..ee5fe618 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -176,12 +176,13 @@ class Reader(_FileLike): class TCPClient: rbufsize = -1 wbufsize = -1 - def __init__(self, host, port, source_address=None): + def __init__(self, host, port, source_address=None, use_ipv6=False): self.host, self.port = host, port self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.ssl_established = False self.source_address = source_address + self.use_ipv6 = use_ipv6 def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None): """ @@ -211,11 +212,10 @@ class TCPClient: def connect(self): try: - addr = socket.gethostbyname(self.host) - connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + connection = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM) if self.source_address: connection.bind(self.source_address) - connection.connect((addr, self.port)) + connection.connect((self.host, self.port)) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError), err: -- cgit v1.2.3 From 6f26cec83e77f8998b50988c54196f9dfae5b7dd Mon Sep 17 00:00:00 2001 From: Matthias Urlichs Date: Thu, 12 Dec 2013 07:11:13 +0100 Subject: tab fix --- netlib/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index ee5fe618..aa9ca027 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -182,7 +182,7 @@ class TCPClient: self.cert = None self.ssl_established = False self.source_address = source_address - self.use_ipv6 = use_ipv6 + self.use_ipv6 = use_ipv6 def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None): """ -- cgit v1.2.3 From 22aae5fb6654e685e5a1f42ad0f0ea5864f0e2c8 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 13 Dec 2013 06:15:32 +0100 Subject: add travis CI file --- .travis.yml | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 .travis.yml diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 00000000..7e4209c0 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,14 @@ +language: python +python: + - "2.7" +# command to install dependencies, e.g. pip install -r requirements.txt --use-mirrors +install: + - "pip install coveralls --use-mirrors" + - "pip install nose-cov --use-mirrors" + - "pip install -r requirements.txt --use-mirrors" + - "pip install --upgrade git+https://github.com/mitmproxy/pathod.git" +# command to run tests, e.g. python setup.py test +script: + - "nosetests --with-cov --cov-report term-missing" +after_success: + - coveralls \ No newline at end of file -- cgit v1.2.3 From 969595cca70edc4d02d5f676221267edf01e4252 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 13 Dec 2013 06:24:08 +0100 Subject: add requirements.txt, small changes --- netlib/http.py | 4 ++++ netlib/http_auth.py | 2 -- requirements.txt | 4 ++++ 3 files changed, 8 insertions(+), 2 deletions(-) create mode 100644 requirements.txt diff --git a/netlib/http.py b/netlib/http.py index 7060b688..e160bd79 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -233,6 +233,10 @@ def parse_init(line): def parse_init_connect(line): + """ + Returns (host, port, httpversion) if line is a valid CONNECT line. + http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 + """ v = parse_init(line) if not v: return None diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 69bee5c1..8f062826 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -126,8 +126,6 @@ class AuthAction(Action): """ def __call__(self, parser, namespace, values, option_string=None): passman = self.getPasswordManager(values) - if not passman: - raise ArgumentTypeError("Error creating password manager for proxy authentication.") authenticator = BasicProxyAuth(passman, "mitmproxy") setattr(namespace, self.dest, authenticator) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..ede8bf4a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +pyasn1>=0.1.7 +pyOpenSSL>=0.13 +nose>=1.3.0 +pathod>=0.9.2 \ No newline at end of file -- cgit v1.2.3 From 9ea4646262d855b1564cbf78e7bf9ab0be332dfd Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 13 Dec 2013 15:09:42 +0100 Subject: use markdown for readme --- MANIFEST.in | 2 +- README | 8 -------- README.mkd | 8 ++++++++ setup.py | 2 +- 4 files changed, 10 insertions(+), 10 deletions(-) delete mode 100644 README create mode 100644 README.mkd diff --git a/MANIFEST.in b/MANIFEST.in index 59226fdc..2c1bf265 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,4 @@ -include README +include README.mkd recursive-include test * recursive-include netlib * recursive-exclude test *.swo *.swp *.pyc diff --git a/README b/README deleted file mode 100644 index f3516faf..00000000 --- a/README +++ /dev/null @@ -1,8 +0,0 @@ -[![Build Status](https://travis-ci.org/mitmproxy/netlib.png)](https://travis-ci.org/mitmproxy/netlib) [![Coverage Status](https://coveralls.io/repos/mitmproxy/netlib/badge.png)](https://coveralls.io/r/mitmproxy/netlib) - -Netlib is a collection of network utility classes, used by the pathod and -mitmproxy projects. It differs from other projects in some fundamental -respects, because both pathod and mitmproxy often need to violate standards. -This means that protocols are implemented as small, well-contained and flexible -functions, and are designed to allow misbehaviour when needed. - diff --git a/README.mkd b/README.mkd new file mode 100644 index 00000000..f3516faf --- /dev/null +++ b/README.mkd @@ -0,0 +1,8 @@ +[![Build Status](https://travis-ci.org/mitmproxy/netlib.png)](https://travis-ci.org/mitmproxy/netlib) [![Coverage Status](https://coveralls.io/repos/mitmproxy/netlib/badge.png)](https://coveralls.io/r/mitmproxy/netlib) + +Netlib is a collection of network utility classes, used by the pathod and +mitmproxy projects. It differs from other projects in some fundamental +respects, because both pathod and mitmproxy often need to violate standards. +This means that protocols are implemented as small, well-contained and flexible +functions, and are designed to allow misbehaviour when needed. + diff --git a/setup.py b/setup.py index 1b2a14f9..db7adc0b 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ def findPackages(path, dataExclude=[]): return packages, package_data -long_description = file("README","rb").read() +long_description = file("README.mkd", "rb").read() packages, package_data = findPackages("netlib") setup( name = "netlib", -- cgit v1.2.3 From 0187d92ec0fb3924a66b6b607f3fc50a5b311259 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 14 Dec 2013 00:19:24 +0100 Subject: test tcpclient.source_address, increase coverage --- test/test_tcp.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/test/test_tcp.py b/test/test_tcp.py index 75dcad13..a4e66516 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -1,4 +1,4 @@ -import cStringIO, Queue, time, socket +import cStringIO, Queue, time, socket, random from netlib import tcp, certutils, test import mock import tutils @@ -24,6 +24,12 @@ class EchoHandler(tcp.BaseHandler): self.wfile.flush() +class ClientPeernameHandler(tcp.BaseHandler): + def handle(self): + self.wfile.write(str(self.connection.getpeername())) + self.wfile.flush() + + class CertHandler(tcp.BaseHandler): sni = None def handle_sni(self, connection): @@ -74,6 +80,22 @@ class TestServer(test.ServerTestBase): assert c.rfile.readline() == testval +class TestServerBind(test.ServerTestBase): + handler = ClientPeernameHandler + + def test_bind(self): + """ Test to bind to a given random port. Try again if the random port turned out to be blocked. """ + for i in range(20): + random_port = random.randrange(1024, 65535) + try: + c = tcp.TCPClient("127.0.0.1", self.port, source_address=("127.0.0.1", random_port)) + c.connect() + assert c.rfile.readline() == str(("127.0.0.1", random_port)) + return + except tcp.NetLibError: # port probably already in use + pass + + class TestServerIPv6(test.ServerTestBase): handler = EchoHandler use_ipv6 = True -- cgit v1.2.3 From cebec67e08bcb9a4dc353ca18aedc53d0230ea42 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 15 Dec 2013 06:43:54 +0100 Subject: refactor read_http_body --- netlib/http.py | 95 +++++++++++++++++++++++-------------------------------- netlib/test.py | 2 +- test/test_http.py | 88 +++++++++++++++++++++++++++++---------------------- test/test_tcp.py | 1 - 4 files changed, 91 insertions(+), 95 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index e160bd79..454edb3a 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -95,14 +95,17 @@ def read_headers(fp): return odict.ODictCaseless(ret) -def read_chunked(code, fp, limit): +def read_chunked(fp, headers, limit, is_request): """ Read a chunked HTTP body. May raise HttpError. """ + # FIXME: Should check if chunked is the final encoding in the headers + # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 3.3 2. content = "" total = 0 + code = 400 if is_request else 502 while 1: line = fp.readline(128) if line == "": @@ -151,35 +154,6 @@ def has_chunked_encoding(headers): return "chunked" in [i.lower() for i in get_header_tokens(headers, "transfer-encoding")] -def read_http_body(code, rfile, headers, all, limit): - """ - Read an HTTP body: - - code: The HTTP error code to be used when raising HttpError - rfile: A file descriptor to read from - headers: An ODictCaseless object - all: Should we read all data? - limit: Size limit. - """ - if has_chunked_encoding(headers): - content = read_chunked(code, rfile, limit) - elif "content-length" in headers: - try: - l = int(headers["content-length"][0]) - except ValueError: - # FIXME: Not strictly correct - this could be from the server, in which - # case we should send a 502. - raise HttpError(code, "Invalid content-length header: %s"%headers["content-length"]) - if limit is not None and l > limit: - raise HttpError(code, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) - content = rfile.read(l) - elif all: - content = rfile.read(limit if limit else -1) - else: - content = "" - return content - - def parse_http_protocol(s): """ Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or @@ -304,28 +278,6 @@ def connection_close(httpversion, headers): return True - -def read_http_body_request(rfile, wfile, headers, httpversion, limit): - """ - Read the HTTP body from a client request. - """ - if "expect" in headers: - # FIXME: Should be forwarded upstream - if "100-continue" in headers['expect'] and httpversion >= (1, 1): - wfile.write('HTTP/1.1 100 Continue\r\n') - wfile.write('\r\n') - del headers['expect'] - return read_http_body(400, rfile, headers, False, limit) - - -def read_http_body_response(rfile, headers, limit): - """ - Read the HTTP body from a server response. - """ - all = "close" in get_header_tokens(headers, "connection") - return read_http_body(500, rfile, headers, all, limit) - - def parse_response_line(line): parts = line.strip().split(" ", 2) if len(parts) == 2: # handle missing message gracefully @@ -359,10 +311,41 @@ def read_response(rfile, method, body_size_limit): headers = read_headers(rfile) if headers is None: raise HttpError(502, "Invalid headers.") - if code >= 100 and code <= 199: - return read_response(rfile, method, body_size_limit) - if method == "HEAD" or code == 204 or code == 304: + + # Parse response body according to http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 + if method == "HEAD" or (code in [204, 304]) or 100 <= code <= 199: content = "" else: - content = read_http_body_response(rfile, headers, body_size_limit) + content = read_http_body(rfile, headers, body_size_limit, False) return httpversion, code, msg, headers, content + + +def read_http_body(rfile, headers, limit, is_request): + """ + Read an HTTP message body: + + rfile: A file descriptor to read from + headers: An ODictCaseless object + limit: Size limit. + is_request: True if the body to read belongs to a request, False otherwise + """ + if has_chunked_encoding(headers): + content = read_chunked(rfile, headers, limit, is_request) + elif "content-length" in headers: + try: + l = int(headers["content-length"][0]) + if l < 0: + raise ValueError() + except ValueError: + raise HttpError(400 if is_request else 502, "Invalid content-length header: %s"%headers["content-length"]) + if limit is not None and l > limit: + raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) + content = rfile.read(l) + elif is_request: + content = "" + else: + content = rfile.read(limit if limit else -1) + not_done = rfile.read(1) + if not_done: + raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s," % limit) + return content \ No newline at end of file diff --git a/netlib/test.py b/netlib/test.py index cd1a3847..85a56739 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -18,7 +18,7 @@ class ServerTestBase: handler = None addr = ("localhost", 0) use_ipv6 = False - + @classmethod def setupAll(cls): cls.q = Queue.Queue() diff --git a/test/test_http.py b/test/test_http.py index 4d89bf24..a0386115 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -1,5 +1,5 @@ import cStringIO, textwrap, binascii -from netlib import http, odict +from netlib import http, odict, tcp, test import tutils @@ -17,25 +17,25 @@ def test_has_chunked_encoding(): def test_read_chunked(): s = cStringIO.StringIO("1\r\na\r\n0\r\n") - tutils.raises("closed prematurely", http.read_chunked, 500, s, None) + tutils.raises("closed prematurely", http.read_chunked, s, None, None, True) s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") - assert http.read_chunked(500, s, None) == "a" + assert http.read_chunked(s, None, None, True) == "a" s = cStringIO.StringIO("\r\n\r\n1\r\na\r\n0\r\n\r\n") - assert http.read_chunked(500, s, None) == "a" + assert http.read_chunked(s, None, None, True) == "a" s = cStringIO.StringIO("\r\n") - tutils.raises("closed prematurely", http.read_chunked, 500, s, None) + tutils.raises("closed prematurely", http.read_chunked, s, None, None, True) s = cStringIO.StringIO("1\r\nfoo") - tutils.raises("malformed chunked body", http.read_chunked, 500, s, None) + tutils.raises("malformed chunked body", http.read_chunked, s, None, None, True) s = cStringIO.StringIO("foo\r\nfoo") - tutils.raises(http.HttpError, http.read_chunked, 500, s, None) + tutils.raises(http.HttpError, http.read_chunked, s, None, None, True) s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") - tutils.raises("too large", http.read_chunked, 500, s, 2) + tutils.raises("too large", http.read_chunked, s, None, 2, True) def test_connection_close(): @@ -49,23 +49,6 @@ def test_connection_close(): h["connection"] = ["close"] assert http.connection_close((1, 1), h) -def test_read_http_body_response(): - h = odict.ODictCaseless() - h["content-length"] = [7] - s = cStringIO.StringIO("testing") - assert http.read_http_body_response(s, h, None) == "testing" - - - h = odict.ODictCaseless() - s = cStringIO.StringIO("testing") - assert not http.read_http_body_response(s, h, None) - - h = odict.ODictCaseless() - h["connection"] = ["close"] - s = cStringIO.StringIO("testing") - assert http.read_http_body_response(s, h, None) == "testing" - - def test_get_header_tokens(): h = odict.ODictCaseless() assert http.get_header_tokens(h, "foo") == [] @@ -79,38 +62,54 @@ def test_get_header_tokens(): def test_read_http_body_request(): h = odict.ODictCaseless() - h["expect"] = ["100-continue"] r = cStringIO.StringIO("testing") - w = cStringIO.StringIO() - assert http.read_http_body_request(r, w, h, (1, 1), None) == "" - assert "100 Continue" in w.getvalue() + assert http.read_http_body(r, h, None, True) == "" +def test_read_http_body_response(): + h = odict.ODictCaseless() + s = cStringIO.StringIO("testing") + assert http.read_http_body(s, h, None, False) == "testing" def test_read_http_body(): + # test default case h = odict.ODictCaseless() + h["content-length"] = [7] s = cStringIO.StringIO("testing") - assert http.read_http_body(500, s, h, False, None) == "" + assert http.read_http_body(s, h, None, False) == "testing" + # test content length: invalid header h["content-length"] = ["foo"] s = cStringIO.StringIO("testing") - tutils.raises(http.HttpError, http.read_http_body, 500, s, h, False, None) + tutils.raises(http.HttpError, http.read_http_body, s, h, None, False) + # test content length: invalid header #2 + h["content-length"] = [-1] + s = cStringIO.StringIO("testing") + tutils.raises(http.HttpError, http.read_http_body, s, h, None, False) + + # test content length: content length > actual content h["content-length"] = [5] s = cStringIO.StringIO("testing") - assert len(http.read_http_body(500, s, h, False, None)) == 5 + tutils.raises(http.HttpError, http.read_http_body, s, h, 4, False) + + # test content length: content length < actual content s = cStringIO.StringIO("testing") - tutils.raises(http.HttpError, http.read_http_body, 500, s, h, False, 4) + assert len(http.read_http_body(s, h, None, False)) == 5 + # test no content length: limit > actual content h = odict.ODictCaseless() s = cStringIO.StringIO("testing") - assert len(http.read_http_body(500, s, h, True, 4)) == 4 + assert len(http.read_http_body(s, h, 100, False)) == 7 + + # test no content length: limit < actual content s = cStringIO.StringIO("testing") - assert len(http.read_http_body(500, s, h, True, 100)) == 7 + tutils.raises(http.HttpError, http.read_http_body, s, h, 4, False) + # test chunked h = odict.ODictCaseless() h["transfer-encoding"] = ["chunked"] s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") - assert http.read_http_body(500, s, h, True, 100) == "aaaaa" + assert http.read_http_body(s, h, 100, False) == "aaaaa" def test_parse_http_protocol(): @@ -214,6 +213,21 @@ class TestReadHeaders: assert self._read(data) is None +class NoContentLengthHTTPHandler(tcp.BaseHandler): + def handle(self): + self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n") + self.wfile.flush() + + +class TestReadResponseNoContentLength(test.ServerTestBase): + handler = NoContentLengthHTTPHandler + + def test_no_content_length(self): + c = tcp.TCPClient("127.0.0.1", self.port) + c.connect() + httpversion, code, msg, headers, content = http.read_response(c.rfile, "GET", None) + assert content == "bar\r\n\r\n" + def test_read_response(): def tst(data, method, limit): data = textwrap.dedent(data) @@ -244,7 +258,7 @@ def test_read_response(): HTTP/1.1 200 OK """ - assert tst(data, "GET", None) == ((1, 1), 200, 'OK', odict.ODictCaseless(), '') + assert tst(data, "GET", None) == ((1, 1), 100, 'CONTINUE', odict.ODictCaseless(), '') data = """ HTTP/1.1 200 OK diff --git a/test/test_tcp.py b/test/test_tcp.py index a4e66516..7f2c21c4 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -133,7 +133,6 @@ class TestFinishFail(test.ServerTestBase): c.wfile.flush() c.rfile.read(4) - class TestDisconnect(test.ServerTestBase): handler = EchoHandler def test_echo(self): -- cgit v1.2.3 From c7606ffdf9ef1f94a5e065c1cce2241f60dcb81e Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 29 Dec 2013 10:52:37 +0100 Subject: list mock as requirement (via @droope) --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ede8bf4a..3b530817 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ pyasn1>=0.1.7 pyOpenSSL>=0.13 nose>=1.3.0 -pathod>=0.9.2 \ No newline at end of file +mock>=1.0.1 +pathod>=0.9.2 -- cgit v1.2.3 From 5717e7300c1cc4a17f0fb0659dcf591fbd0a6e40 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 5 Jan 2014 10:57:50 +1300 Subject: Make it possible to pass custom environment variables into wsgi apps. --- netlib/wsgi.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index dffc2ace..647cb899 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -33,7 +33,7 @@ class WSGIAdaptor: def __init__(self, app, domain, port, sversion): self.app, self.domain, self.port, self.sversion = app, domain, port, sversion - def make_environ(self, request, errsoc): + def make_environ(self, request, errsoc, **extra): if '?' in request.path: path_info, query = request.path.split('?', 1) else: @@ -59,6 +59,7 @@ class WSGIAdaptor: # FIXME: We need to pick up the protocol read from the request. 'SERVER_PROTOCOL': "HTTP/1.1", } + environ.update(extra) if request.client_conn.address: environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.client_conn.address @@ -86,7 +87,7 @@ class WSGIAdaptor: soc.write("\r\n") soc.write(c) - def serve(self, request, soc): + def serve(self, request, soc, **env): state = dict( response_started = False, headers_sent = False, @@ -123,7 +124,7 @@ class WSGIAdaptor: errs = cStringIO.StringIO() try: - dataiter = self.app(self.make_environ(request, errs), start_response) + dataiter = self.app(self.make_environ(request, errs, **env), start_response) for i in dataiter: write(i) if not state["headers_sent"]: -- cgit v1.2.3 From ac1a700fa16e2ae2146425844823bff70cc86f4b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 8 Jan 2014 14:46:55 +1300 Subject: Make certificate not-before time 48 hours. Fixes #200 --- netlib/certutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index d9b8ce57..0349bec7 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -96,7 +96,7 @@ def dummy_cert(ca, commonname, sans): key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) cert = OpenSSL.crypto.X509() - cert.gmtime_adj_notBefore(-3600) + cert.gmtime_adj_notBefore(-3600*48) cert.gmtime_adj_notAfter(60 * 60 * 24 * 30) cert.set_issuer(ca.get_subject()) cert.get_subject().CN = commonname -- cgit v1.2.3 From 951f2d517fa2e464d654a54bebacbd983f944c62 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 9 Jan 2014 01:57:37 +0100 Subject: change parameter names to reflect changes --- netlib/tcp.py | 29 +++++++++++++---------------- netlib/test.py | 2 +- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 33f7ef3a..d35818bf 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -138,8 +138,8 @@ class Reader(_FileLike): raise NetLibTimeout except socket.timeout: raise NetLibTimeout - except socket.error: - raise NetLibDisconnect + except socket.error, v: + raise NetLibDisconnect(v[1]) except SSL.SysCallError: raise NetLibDisconnect except SSL.Error, v: @@ -255,16 +255,13 @@ class BaseHandler: """ rbufsize = -1 wbufsize = -1 - def __init__(self, connection, client_address, server): + def __init__(self, connection): self.connection = connection self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) - self.client_address = client_address - self.server = server self.finished = False self.ssl_established = False - self.clientcert = None def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False, cipher_list=None): @@ -371,13 +368,13 @@ class TCPServer: self.port = self.server_address[1] self.socket.listen(self.request_queue_size) - def request_thread(self, request, client_address): + def connection_thread(self, connection, client_address): try: - self.handle_connection(request, client_address) - request.close() + self.handle_client_connection(connection, client_address) except: - self.handle_error(request, client_address) - request.close() + self.handle_error(connection, client_address) + finally: + connection.close() def serve_forever(self, poll_interval=0.1): self.__is_shut_down.clear() @@ -391,10 +388,10 @@ class TCPServer: else: raise if self.socket in r: - request, client_address = self.socket.accept() + connection, client_address = self.socket.accept() t = threading.Thread( - target = self.request_thread, - args = (request, client_address) + target = self.connection_thread, + args = (connection, client_address) ) t.setDaemon(1) t.start() @@ -410,7 +407,7 @@ class TCPServer: def handle_error(self, request, client_address, fp=sys.stderr): """ - Called when handle_connection raises an exception. + Called when handle_client_connection raises an exception. """ # If a thread has persisted after interpreter exit, the module might be # none. @@ -421,7 +418,7 @@ class TCPServer: print >> fp, exc print >> fp, '-'*40 - def handle_connection(self, request, client_address): # pragma: no cover + def handle_client_connection(self, conn, client_address): # pragma: no cover """ Called after client connection. """ diff --git a/netlib/test.py b/netlib/test.py index cd1a3847..0c36da6a 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -50,7 +50,7 @@ class TServer(tcp.TCPServer): self.handler_klass = handler_klass self.last_handler = None - def handle_connection(self, request, client_address): + def handle_client_connection(self, request, client_address): h = self.handler_klass(request, client_address, self) self.last_handler = h if self.ssl: -- cgit v1.2.3 From d0a6d2e2545089893d3789e3c787e269645df852 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 9 Jan 2014 05:33:21 +0100 Subject: fix tests, remove duplicate code --- netlib/tcp.py | 91 ++++++++++++++++++++++++---------------------------------- netlib/test.py | 2 +- 2 files changed, 38 insertions(+), 55 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index d35818bf..e48f4f6b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -138,8 +138,8 @@ class Reader(_FileLike): raise NetLibTimeout except socket.timeout: raise NetLibTimeout - except socket.error, v: - raise NetLibDisconnect(v[1]) + except socket.error: + raise NetLibDisconnect except SSL.SysCallError: raise NetLibDisconnect except SSL.Error, v: @@ -173,7 +173,40 @@ class Reader(_FileLike): return result -class TCPClient: +class SocketCloseMixin: + def finish(self): + self.finished = True + try: + if not getattr(self.wfile, "closed", False): + self.wfile.flush() + self.close() + self.wfile.close() + self.rfile.close() + except (socket.error, NetLibDisconnect): + # Remote has disconnected + pass + + def close(self): + """ + Does a hard close of the socket, i.e. a shutdown, followed by a close. + """ + try: + if self.ssl_established: + self.connection.shutdown() + self.connection.sock_shutdown(socket.SHUT_WR) + else: + self.connection.shutdown(socket.SHUT_WR) + #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. + #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html + while self.connection.recv(4096): + pass + self.connection.close() + except (socket.error, SSL.Error, IOError): + # Socket probably already closed + pass + + +class TCPClient(SocketCloseMixin): rbufsize = -1 wbufsize = -1 def __init__(self, host, port, source_address=None, use_ipv6=False): @@ -228,27 +261,8 @@ class TCPClient: def gettimeout(self): return self.connection.gettimeout() - def close(self): - """ - Does a hard close of the socket, i.e. a shutdown, followed by a close. - """ - try: - if self.ssl_established: - self.connection.shutdown() - self.connection.sock_shutdown(socket.SHUT_WR) - else: - self.connection.shutdown(socket.SHUT_WR) - #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. - #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html - while self.connection.recv(4096): - pass - self.connection.close() - except (socket.error, SSL.Error, IOError): - # Socket probably already closed - pass - -class BaseHandler: +class BaseHandler(SocketCloseMixin): """ The instantiator is expected to call the handle() and finish() methods. @@ -315,43 +329,12 @@ class BaseHandler: self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) - def finish(self): - self.finished = True - try: - if not getattr(self.wfile, "closed", False): - self.wfile.flush() - self.close() - self.wfile.close() - self.rfile.close() - except (socket.error, NetLibDisconnect): - # Remote has disconnected - pass - def handle(self): # pragma: no cover raise NotImplementedError def settimeout(self, n): self.connection.settimeout(n) - def close(self): - """ - Does a hard close of the socket, i.e. a shutdown, followed by a close. - """ - try: - if self.ssl_established: - self.connection.shutdown() - self.connection.sock_shutdown(socket.SHUT_WR) - else: - self.connection.shutdown(socket.SHUT_WR) - # Section 4.2.2.13 of RFC 1122 tells us that a close() with any - # pending readable data could lead to an immediate RST being sent. - # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html - while self.connection.recv(4096): - pass - except (socket.error, SSL.Error): - # Socket probably already closed - pass - self.connection.close() class TCPServer: diff --git a/netlib/test.py b/netlib/test.py index 2209ebc3..f5599082 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -51,7 +51,7 @@ class TServer(tcp.TCPServer): self.last_handler = None def handle_client_connection(self, request, client_address): - h = self.handler_klass(request, client_address, self) + h = self.handler_klass(request) self.last_handler = h if self.ssl: cert = certutils.SSLCert.from_pem( -- cgit v1.2.3 From 85e09278209af88d081e2cbc8002bd6defb624f4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 9 Jan 2014 17:38:39 +0100 Subject: display build status from master branch --- README.mkd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.mkd b/README.mkd index f3516faf..f0a26dd5 100644 --- a/README.mkd +++ b/README.mkd @@ -1,4 +1,4 @@ -[![Build Status](https://travis-ci.org/mitmproxy/netlib.png)](https://travis-ci.org/mitmproxy/netlib) [![Coverage Status](https://coveralls.io/repos/mitmproxy/netlib/badge.png)](https://coveralls.io/r/mitmproxy/netlib) +[![Build Status](https://travis-ci.org/mitmproxy/netlib.png?branch=master)](https://travis-ci.org/mitmproxy/netlib) [![Coverage Status](https://coveralls.io/repos/mitmproxy/netlib/badge.png?branch=master)](https://coveralls.io/r/mitmproxy/netlib) Netlib is a collection of network utility classes, used by the pathod and mitmproxy projects. It differs from other projects in some fundamental -- cgit v1.2.3 From 0f22039bcadd26c2745f609085bcfdbba35b4945 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 18 Jan 2014 22:55:40 +0100 Subject: add CONNECT request to list of request types that don't have a response body --- netlib/http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/http.py b/netlib/http.py index 454edb3a..51f85627 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -313,7 +313,7 @@ def read_response(rfile, method, body_size_limit): raise HttpError(502, "Invalid headers.") # Parse response body according to http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 - if method == "HEAD" or (code in [204, 304]) or 100 <= code <= 199: + if method in ["HEAD", "CONNECT"] or (code in [204, 304]) or 100 <= code <= 199: content = "" else: content = read_http_body(rfile, headers, body_size_limit, False) -- cgit v1.2.3 From 8266699acdfcb786ba2c87007a17632ff1893fe5 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 19 Jan 2014 18:17:06 +1300 Subject: Silence pyflakes, adjust requirements.txt --- netlib/http_auth.py | 1 - requirements.txt | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 8f062826..be99fb3d 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -1,4 +1,3 @@ -import binascii import contrib.md5crypt as md5crypt import http from argparse import Action, ArgumentTypeError diff --git a/requirements.txt b/requirements.txt index 3b530817..de289584 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,4 @@ pyasn1>=0.1.7 pyOpenSSL>=0.13 nose>=1.3.0 mock>=1.0.1 -pathod>=0.9.2 +pathod>=0.10 -- cgit v1.2.3 From 2aadea0b7c55489b5171bd9eae35eb21a58cbd0d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 28 Jan 2014 14:09:45 +1300 Subject: Fix homepage URL --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index db7adc0b..2937487c 100644 --- a/setup.py +++ b/setup.py @@ -74,7 +74,7 @@ setup( long_description = long_description, author = "Aldo Cortesi", author_email = "aldo@corte.si", - url = "http://cortesi.github.com/netlib", + url = "http://github.com/mitmproxy/netlib", packages = packages, package_data = package_data, classifiers = [ -- cgit v1.2.3 From 5ba36622f0496086164d32f9260f4a42ea422dfc Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 28 Jan 2014 14:22:01 +1300 Subject: travis: force install of pathod from git. --- .travis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 7e4209c0..9f606f0e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,10 +5,10 @@ python: install: - "pip install coveralls --use-mirrors" - "pip install nose-cov --use-mirrors" - - "pip install -r requirements.txt --use-mirrors" - "pip install --upgrade git+https://github.com/mitmproxy/pathod.git" + - "pip install -r requirements.txt --use-mirrors" # command to run tests, e.g. python setup.py test script: - "nosetests --with-cov --cov-report term-missing" after_success: - - coveralls \ No newline at end of file + - coveralls -- cgit v1.2.3 From 11f729a3a328143814e3dbe155535b91a9767f6b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 28 Jan 2014 14:32:30 +1300 Subject: Try harder to un-break travis --- .travis.yml | 2 +- requirements.txt | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 9f606f0e..2702735a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,8 +5,8 @@ python: install: - "pip install coveralls --use-mirrors" - "pip install nose-cov --use-mirrors" - - "pip install --upgrade git+https://github.com/mitmproxy/pathod.git" - "pip install -r requirements.txt --use-mirrors" + - "pip install --upgrade git+https://github.com/mitmproxy/pathod.git" # command to run tests, e.g. python setup.py test script: - "nosetests --with-cov --cov-report term-missing" diff --git a/requirements.txt b/requirements.txt index de289584..bcf8e48f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,3 @@ pyasn1>=0.1.7 pyOpenSSL>=0.13 nose>=1.3.0 mock>=1.0.1 -pathod>=0.10 -- cgit v1.2.3 From 732932e8bb53ad0b57de2b2a6d124562b02584fa Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 28 Jan 2014 14:42:39 +1300 Subject: Try even harderer to fix travis. --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index 2702735a..fe06a3eb 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,6 +6,7 @@ install: - "pip install coveralls --use-mirrors" - "pip install nose-cov --use-mirrors" - "pip install -r requirements.txt --use-mirrors" + - "pip install ." - "pip install --upgrade git+https://github.com/mitmproxy/pathod.git" # command to run tests, e.g. python setup.py test script: -- cgit v1.2.3 From 9759ec7c29093eb278a0a2eda818811d2b8f2e74 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 28 Jan 2014 02:57:46 +0100 Subject: move test requirements out of main requirements.txt --- .travis.yml | 3 +-- requirements.txt | 4 +--- test/requirements.txt | 5 +++++ 3 files changed, 7 insertions(+), 5 deletions(-) create mode 100644 test/requirements.txt diff --git a/.travis.yml b/.travis.yml index fe06a3eb..31bb399f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,11 +3,10 @@ python: - "2.7" # command to install dependencies, e.g. pip install -r requirements.txt --use-mirrors install: - - "pip install coveralls --use-mirrors" - - "pip install nose-cov --use-mirrors" - "pip install -r requirements.txt --use-mirrors" - "pip install ." - "pip install --upgrade git+https://github.com/mitmproxy/pathod.git" + - "pip install -r test/requirements.txt --use-mirrors" # command to run tests, e.g. python setup.py test script: - "nosetests --with-cov --cov-report term-missing" diff --git a/requirements.txt b/requirements.txt index bcf8e48f..460a60e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,2 @@ pyasn1>=0.1.7 -pyOpenSSL>=0.13 -nose>=1.3.0 -mock>=1.0.1 +pyOpenSSL>=0.13 \ No newline at end of file diff --git a/test/requirements.txt b/test/requirements.txt new file mode 100644 index 00000000..89e4aa0a --- /dev/null +++ b/test/requirements.txt @@ -0,0 +1,5 @@ +mock>=1.0.1 +nose>=1.3.0 +nose-cov>=1.6 +coveralls>=0.4.1 +pathod>=0.10 \ No newline at end of file -- cgit v1.2.3 From 9c9e4a5295dd9c190d9aadd024d135386d8cf7c0 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 28 Jan 2014 15:13:31 +1300 Subject: travis irc notifications --- .travis.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.travis.yml b/.travis.yml index fe06a3eb..84f97c9b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,3 +13,10 @@ script: - "nosetests --with-cov --cov-report term-missing" after_success: - coveralls +notifications: + irc: + channels: + - "irc.oftc.net#mitmproxy" + on_success: change + on_failure: always + -- cgit v1.2.3 From 763cb90b66b23cd94b6e37df3d4c7b8e7f89492a Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 28 Jan 2014 17:26:35 +0100 Subject: add tcp.Address to unify ipv4/ipv6 address handling --- netlib/certutils.py | 2 +- netlib/tcp.py | 56 +++++++++++++++++++++++++++++++++++++++-------------- netlib/test.py | 11 +++++------ test/test_http.py | 2 +- test/test_tcp.py | 36 +++++++++++++++++----------------- 5 files changed, 67 insertions(+), 40 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 0349bec7..94294f6e 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -237,7 +237,7 @@ class SSLCert: def get_remote_cert(host, port, sni): - c = tcp.TCPClient(host, port) + c = tcp.TCPClient((host, port)) c.connect() c.convert_to_ssl(sni=sni) return c.cert diff --git a/netlib/tcp.py b/netlib/tcp.py index e48f4f6b..bad166d0 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -173,6 +173,35 @@ class Reader(_FileLike): return result +class Address(tuple): + """ + This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. + """ + def __new__(cls, address, use_ipv6=False): + a = super(Address, cls).__new__(cls, tuple(address)) + a.family = socket.AF_INET6 if use_ipv6 else socket.AF_INET + return a + + @classmethod + def wrap(cls, t): + if isinstance(t, cls): + return t + else: + return cls(t) + + @property + def host(self): + return self[0] + + @property + def port(self): + return self[1] + + @property + def is_ipv6(self): + return self.family == socket.AF_INET6 + + class SocketCloseMixin: def finish(self): self.finished = True @@ -209,10 +238,9 @@ class SocketCloseMixin: class TCPClient(SocketCloseMixin): rbufsize = -1 wbufsize = -1 - def __init__(self, host, port, source_address=None, use_ipv6=False): - self.host, self.port = host, port + def __init__(self, address, source_address=None): + self.address = Address.wrap(address) self.source_address = source_address - self.use_ipv6 = use_ipv6 self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.ssl_established = False @@ -245,14 +273,14 @@ class TCPClient(SocketCloseMixin): def connect(self): try: - connection = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM) + connection = socket.socket(self.address.family, socket.SOCK_STREAM) if self.source_address: connection.bind(self.source_address) - connection.connect((self.host, self.port)) + connection.connect(self.address) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError), err: - raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) + raise NetLibError('Error connecting to "%s": %s' % (self.address[0], err)) self.connection = connection def settimeout(self, n): @@ -269,8 +297,9 @@ class BaseHandler(SocketCloseMixin): """ rbufsize = -1 wbufsize = -1 - def __init__(self, connection): + def __init__(self, connection, address): self.connection = connection + self.address = Address.wrap(address) self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) @@ -339,19 +368,18 @@ class BaseHandler(SocketCloseMixin): class TCPServer: request_queue_size = 20 - def __init__(self, server_address, use_ipv6=False): - self.server_address = server_address - self.use_ipv6 = use_ipv6 + def __init__(self, address): + self.address = Address.wrap(address) self.__is_shut_down = threading.Event() self.__shutdown_request = False - self.socket = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM) + self.socket = socket.socket(self.address.family, socket.SOCK_STREAM) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.bind(self.server_address) - self.server_address = self.socket.getsockname() - self.port = self.server_address[1] + self.socket.bind(self.address) + self.address = Address.wrap(self.socket.getsockname()) self.socket.listen(self.request_queue_size) def connection_thread(self, connection, client_address): + client_address = Address(client_address) try: self.handle_client_connection(connection, client_address) except: diff --git a/netlib/test.py b/netlib/test.py index f5599082..565b97cd 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -17,19 +17,18 @@ class ServerTestBase: ssl = None handler = None addr = ("localhost", 0) - use_ipv6 = False @classmethod def setupAll(cls): cls.q = Queue.Queue() s = cls.makeserver() - cls.port = s.port + cls.port = s.address.port cls.server = ServerThread(s) cls.server.start() @classmethod def makeserver(cls): - return TServer(cls.ssl, cls.q, cls.handler, cls.addr, cls.use_ipv6) + return TServer(cls.ssl, cls.q, cls.handler, cls.addr) @classmethod def teardownAll(cls): @@ -41,17 +40,17 @@ class ServerTestBase: class TServer(tcp.TCPServer): - def __init__(self, ssl, q, handler_klass, addr, use_ipv6): + def __init__(self, ssl, q, handler_klass, addr): """ ssl: A {cert, key, v3_only} dict. """ - tcp.TCPServer.__init__(self, addr, use_ipv6=use_ipv6) + tcp.TCPServer.__init__(self, addr) self.ssl, self.q = ssl, q self.handler_klass = handler_klass self.last_handler = None def handle_client_connection(self, request, client_address): - h = self.handler_klass(request) + h = self.handler_klass(request, client_address) self.last_handler = h if self.ssl: cert = certutils.SSLCert.from_pem( diff --git a/test/test_http.py b/test/test_http.py index a0386115..e80e4b8f 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -223,7 +223,7 @@ class TestReadResponseNoContentLength(test.ServerTestBase): handler = NoContentLengthHTTPHandler def test_no_content_length(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() httpversion, code, msg, headers, content = http.read_response(c.rfile, "GET", None) assert content == "bar\r\n\r\n" diff --git a/test/test_tcp.py b/test/test_tcp.py index 7f2c21c4..49e20635 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -73,7 +73,7 @@ class TestServer(test.ServerTestBase): handler = EchoHandler def test_echo(self): testval = "echo!\n" - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.wfile.write(testval) c.wfile.flush() @@ -88,7 +88,7 @@ class TestServerBind(test.ServerTestBase): for i in range(20): random_port = random.randrange(1024, 65535) try: - c = tcp.TCPClient("127.0.0.1", self.port, source_address=("127.0.0.1", random_port)) + c = tcp.TCPClient(("127.0.0.1", self.port), source_address=("127.0.0.1", random_port)) c.connect() assert c.rfile.readline() == str(("127.0.0.1", random_port)) return @@ -98,11 +98,11 @@ class TestServerBind(test.ServerTestBase): class TestServerIPv6(test.ServerTestBase): handler = EchoHandler - use_ipv6 = True + addr = tcp.Address(("localhost", 0), use_ipv6=True) def test_echo(self): testval = "echo!\n" - c = tcp.TCPClient("::1", self.port, use_ipv6=True) + c = tcp.TCPClient(tcp.Address(("::1", self.port), use_ipv6=True)) c.connect() c.wfile.write(testval) c.wfile.flush() @@ -127,7 +127,7 @@ class TestFinishFail(test.ServerTestBase): handler = FinishFailHandler def test_disconnect_in_finish(self): testval = "echo!\n" - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.wfile.write("foo\n") c.wfile.flush() @@ -137,7 +137,7 @@ class TestDisconnect(test.ServerTestBase): handler = EchoHandler def test_echo(self): testval = "echo!\n" - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.wfile.write(testval) c.wfile.flush() @@ -153,7 +153,7 @@ class TestServerSSL(test.ServerTestBase): v3_only = False ) def test_echo(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl(sni="foo.com", options=tcp.OP_ALL) testval = "echo!\n" @@ -174,7 +174,7 @@ class TestSSLv3Only(test.ServerTestBase): v3_only = True ) def test_failure(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com", method=tcp.TLSv1_METHOD) @@ -188,13 +188,13 @@ class TestSSLClientCert(test.ServerTestBase): v3_only = False ) def test_clientcert(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl(cert=tutils.test_data.path("data/clientcert/client.pem")) assert c.rfile.readline().strip() == "1" def test_clientcert_err(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() tutils.raises( tcp.NetLibError, @@ -212,7 +212,7 @@ class TestSNI(test.ServerTestBase): v3_only = False ) def test_echo(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl(sni="foo.com") assert c.rfile.readline() == "foo.com" @@ -228,7 +228,7 @@ class TestClientCipherList(test.ServerTestBase): cipher_list = 'RC4-SHA' ) def test_echo(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl(sni="foo.com") assert c.rfile.readline() == "['RC4-SHA']" @@ -243,7 +243,7 @@ class TestSSLDisconnect(test.ServerTestBase): v3_only = False ) def test_echo(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() # Excercise SSL.ZeroReturnError @@ -255,7 +255,7 @@ class TestSSLDisconnect(test.ServerTestBase): class TestDisconnect(test.ServerTestBase): def test_echo(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.rfile.read(10) c.wfile.write("foo") @@ -266,7 +266,7 @@ class TestDisconnect(test.ServerTestBase): class TestServerTimeOut(test.ServerTestBase): handler = TimeoutHandler def test_timeout(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() time.sleep(0.3) assert self.last_handler.timeout @@ -275,7 +275,7 @@ class TestServerTimeOut(test.ServerTestBase): class TestTimeOut(test.ServerTestBase): handler = HangHandler def test_timeout(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.settimeout(0.1) assert c.gettimeout() == 0.1 @@ -291,7 +291,7 @@ class TestSSLTimeOut(test.ServerTestBase): v3_only = False ) def test_timeout_client(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() c.settimeout(0.1) @@ -300,7 +300,7 @@ class TestSSLTimeOut(test.ServerTestBase): class TestTCPClient: def test_conerr(self): - c = tcp.TCPClient("127.0.0.1", 0) + c = tcp.TCPClient(("127.0.0.1", 0)) tutils.raises(tcp.NetLibError, c.connect) -- cgit v1.2.3 From e18ac4b672e8645388dc8057801092ce417f1511 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 28 Jan 2014 20:30:16 +0100 Subject: re-add server attribute to BaseHandler --- netlib/tcp.py | 4 +++- netlib/test.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index bad166d0..729e513e 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -297,9 +297,11 @@ class BaseHandler(SocketCloseMixin): """ rbufsize = -1 wbufsize = -1 - def __init__(self, connection, address): + + def __init__(self, connection, address, server): self.connection = connection self.address = Address.wrap(address) + self.server = server self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) diff --git a/netlib/test.py b/netlib/test.py index 565b97cd..2f6a7107 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -50,7 +50,7 @@ class TServer(tcp.TCPServer): self.last_handler = None def handle_client_connection(self, request, client_address): - h = self.handler_klass(request, client_address) + h = self.handler_klass(request, client_address, self) self.last_handler = h if self.ssl: cert = certutils.SSLCert.from_pem( -- cgit v1.2.3 From ff9656be80192ac837cf98997f9fe6c00c9c5a32 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 30 Jan 2014 20:07:30 +0100 Subject: remove subclassing of tuple in tcp.Address, move StateObject into netlib --- netlib/certutils.py | 12 +++++++- netlib/odict.py | 7 ++++- netlib/stateobject.py | 80 +++++++++++++++++++++++++++++++++++++++++++++++++++ netlib/tcp.py | 45 ++++++++++++++++++++--------- 4 files changed, 128 insertions(+), 16 deletions(-) create mode 100644 netlib/stateobject.py diff --git a/netlib/certutils.py b/netlib/certutils.py index 94294f6e..139203b9 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -3,6 +3,7 @@ from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL +from netlib.stateobject import StateObject import tcp default_exp = 62208000 # =24 * 60 * 60 * 720 @@ -152,13 +153,22 @@ class _GeneralNames(univ.SequenceOf): sizeSpec = univ.SequenceOf.sizeSpec + constraint.ValueSizeConstraint(1, 1024) -class SSLCert: +class SSLCert(StateObject): def __init__(self, cert): """ Returns a (common name, [subject alternative names]) tuple. """ self.x509 = cert + def _get_state(self): + return self.to_pem() + + def _load_state(self, state): + self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state) + + def _from_state(cls, state): + return cls.from_pem(state) + @classmethod def from_pem(klass, txt): x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) diff --git a/netlib/odict.py b/netlib/odict.py index 0759a5bf..8e195afc 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,4 +1,6 @@ import re, copy +from netlib.stateobject import StateObject + def safe_subn(pattern, repl, target, *args, **kwargs): """ @@ -9,7 +11,7 @@ def safe_subn(pattern, repl, target, *args, **kwargs): return re.subn(str(pattern), str(repl), target, *args, **kwargs) -class ODict: +class ODict(StateObject): """ A dictionary-like object for managing ordered (key, value) data. """ @@ -98,6 +100,9 @@ class ODict: def _get_state(self): return [tuple(i) for i in self.lst] + def _load_state(self, state): + self.list = [list(i) for i in state] + @classmethod def _from_state(klass, state): return klass([list(i) for i in state]) diff --git a/netlib/stateobject.py b/netlib/stateobject.py new file mode 100644 index 00000000..c2ef2cd4 --- /dev/null +++ b/netlib/stateobject.py @@ -0,0 +1,80 @@ +from types import ClassType + + +class StateObject: + def _get_state(self): + raise NotImplementedError + + def _load_state(self, state): + raise NotImplementedError + + @classmethod + def _from_state(cls, state): + raise NotImplementedError + + def __eq__(self, other): + try: + return self._get_state() == other._get_state() + except AttributeError: # we may compare with something that's not a StateObject + return False + + +class SimpleStateObject(StateObject): + """ + A StateObject with opionated conventions that tries to keep everything DRY. + + Simply put, you agree on a list of attributes and their type. + Attributes can either be primitive types(str, tuple, bool, ...) or StateObject instances themselves. + SimpleStateObject uses this information for the default _get_state(), _from_state(s) and _load_state(s) methods. + Overriding _get_state or _load_state to add custom adjustments is always possible. + """ + + _stateobject_attributes = None # none by default to raise an exception if definition was forgotten + """ + An attribute-name -> class-or-type dict containing all attributes that should be serialized + If the attribute is a class, this class must be a subclass of StateObject. + """ + + def _get_state(self): + return {attr: self.__get_state_attr(attr, cls) + for attr, cls in self._stateobject_attributes.iteritems()} + + def __get_state_attr(self, attr, cls): + """ + helper for _get_state. + returns the value of the given attribute + """ + if getattr(self, attr) is None: + return None + if isinstance(cls, ClassType): + return getattr(self, attr)._get_state() + else: + return getattr(self, attr) + + def _load_state(self, state): + for attr, cls in self._stateobject_attributes.iteritems(): + self.__load_state_attr(attr, cls, state) + + def __load_state_attr(self, attr, cls, state): + """ + helper for _load_state. + loads the given attribute from the state. + """ + if state[attr] is not None: # First, catch None as value. + if isinstance(cls, ClassType): # Is the attribute a StateObject itself? + assert issubclass(cls, StateObject) + curr = getattr(self, attr) + if curr: # if the attribute is already present, delegate to the objects ._load_state method. + curr._load_state(state[attr]) + else: # otherwise, create a new object. + setattr(self, attr, cls._from_state(state[attr])) + else: + setattr(self, attr, cls(state[attr])) + else: + setattr(self, attr, None) + + @classmethod + def _from_state(cls, state): + f = cls() # the default implementation assumes an empty constructor. Override accordingly. + f._load_state(state) + return f \ No newline at end of file diff --git a/netlib/tcp.py b/netlib/tcp.py index 729e513e..c26d1191 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,6 +1,7 @@ import select, socket, threading, sys, time, traceback from OpenSSL import SSL import certutils +from netlib.stateobject import StateObject SSLv2_METHOD = SSL.SSLv2_METHOD SSLv3_METHOD = SSL.SSLv3_METHOD @@ -173,14 +174,13 @@ class Reader(_FileLike): return result -class Address(tuple): +class Address(StateObject): """ This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. """ - def __new__(cls, address, use_ipv6=False): - a = super(Address, cls).__new__(cls, tuple(address)) - a.family = socket.AF_INET6 if use_ipv6 else socket.AF_INET - return a + def __init__(self, address, use_ipv6=False): + self.address = address + self.family = socket.AF_INET6 if use_ipv6 else socket.AF_INET @classmethod def wrap(cls, t): @@ -189,18 +189,35 @@ class Address(tuple): else: return cls(t) + def __call__(self): + return self.address + @property def host(self): - return self[0] + return self.address[0] @property def port(self): - return self[1] + return self.address[1] @property - def is_ipv6(self): + def use_ipv6(self): return self.family == socket.AF_INET6 + def _load_state(self, state): + self.address = state["address"] + self.family = socket.AF_INET6 if state["use_ipv6"] else socket.AF_INET + + def _get_state(self): + return dict( + address=self.address, + use_ipv6=self.use_ipv6 + ) + + @classmethod + def _from_state(cls, state): + return cls(**state) + class SocketCloseMixin: def finish(self): @@ -240,7 +257,7 @@ class TCPClient(SocketCloseMixin): wbufsize = -1 def __init__(self, address, source_address=None): self.address = Address.wrap(address) - self.source_address = source_address + self.source_address = Address.wrap(source_address) if source_address else None self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.ssl_established = False @@ -275,12 +292,12 @@ class TCPClient(SocketCloseMixin): try: connection = socket.socket(self.address.family, socket.SOCK_STREAM) if self.source_address: - connection.bind(self.source_address) - connection.connect(self.address) + connection.bind(self.source_address()) + connection.connect(self.address()) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError), err: - raise NetLibError('Error connecting to "%s": %s' % (self.address[0], err)) + raise NetLibError('Error connecting to "%s": %s' % (self.address.host, err)) self.connection = connection def settimeout(self, n): @@ -376,7 +393,7 @@ class TCPServer: self.__shutdown_request = False self.socket = socket.socket(self.address.family, socket.SOCK_STREAM) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.bind(self.address) + self.socket.bind(self.address()) self.address = Address.wrap(self.socket.getsockname()) self.socket.listen(self.request_queue_size) @@ -427,7 +444,7 @@ class TCPServer: if traceback: exc = traceback.format_exc() print >> fp, '-'*40 - print >> fp, "Error in processing of request from %s:%s"%client_address + print >> fp, "Error in processing of request from %s:%s" % (client_address.host, client_address.port) print >> fp, exc print >> fp, '-'*40 -- cgit v1.2.3 From dc45b4bf19bff5edc0b72ccb68fad04d479aff83 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 31 Jan 2014 01:06:53 +0100 Subject: move StateObject back into libmproxy --- netlib/certutils.py | 12 +------- netlib/odict.py | 3 +- netlib/stateobject.py | 80 --------------------------------------------------- netlib/tcp.py | 21 ++++---------- 4 files changed, 7 insertions(+), 109 deletions(-) delete mode 100644 netlib/stateobject.py diff --git a/netlib/certutils.py b/netlib/certutils.py index 139203b9..94294f6e 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -3,7 +3,6 @@ from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL -from netlib.stateobject import StateObject import tcp default_exp = 62208000 # =24 * 60 * 60 * 720 @@ -153,22 +152,13 @@ class _GeneralNames(univ.SequenceOf): sizeSpec = univ.SequenceOf.sizeSpec + constraint.ValueSizeConstraint(1, 1024) -class SSLCert(StateObject): +class SSLCert: def __init__(self, cert): """ Returns a (common name, [subject alternative names]) tuple. """ self.x509 = cert - def _get_state(self): - return self.to_pem() - - def _load_state(self, state): - self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state) - - def _from_state(cls, state): - return cls.from_pem(state) - @classmethod def from_pem(klass, txt): x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) diff --git a/netlib/odict.py b/netlib/odict.py index 8e195afc..46b74e8e 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,5 +1,4 @@ import re, copy -from netlib.stateobject import StateObject def safe_subn(pattern, repl, target, *args, **kwargs): @@ -11,7 +10,7 @@ def safe_subn(pattern, repl, target, *args, **kwargs): return re.subn(str(pattern), str(repl), target, *args, **kwargs) -class ODict(StateObject): +class ODict: """ A dictionary-like object for managing ordered (key, value) data. """ diff --git a/netlib/stateobject.py b/netlib/stateobject.py deleted file mode 100644 index c2ef2cd4..00000000 --- a/netlib/stateobject.py +++ /dev/null @@ -1,80 +0,0 @@ -from types import ClassType - - -class StateObject: - def _get_state(self): - raise NotImplementedError - - def _load_state(self, state): - raise NotImplementedError - - @classmethod - def _from_state(cls, state): - raise NotImplementedError - - def __eq__(self, other): - try: - return self._get_state() == other._get_state() - except AttributeError: # we may compare with something that's not a StateObject - return False - - -class SimpleStateObject(StateObject): - """ - A StateObject with opionated conventions that tries to keep everything DRY. - - Simply put, you agree on a list of attributes and their type. - Attributes can either be primitive types(str, tuple, bool, ...) or StateObject instances themselves. - SimpleStateObject uses this information for the default _get_state(), _from_state(s) and _load_state(s) methods. - Overriding _get_state or _load_state to add custom adjustments is always possible. - """ - - _stateobject_attributes = None # none by default to raise an exception if definition was forgotten - """ - An attribute-name -> class-or-type dict containing all attributes that should be serialized - If the attribute is a class, this class must be a subclass of StateObject. - """ - - def _get_state(self): - return {attr: self.__get_state_attr(attr, cls) - for attr, cls in self._stateobject_attributes.iteritems()} - - def __get_state_attr(self, attr, cls): - """ - helper for _get_state. - returns the value of the given attribute - """ - if getattr(self, attr) is None: - return None - if isinstance(cls, ClassType): - return getattr(self, attr)._get_state() - else: - return getattr(self, attr) - - def _load_state(self, state): - for attr, cls in self._stateobject_attributes.iteritems(): - self.__load_state_attr(attr, cls, state) - - def __load_state_attr(self, attr, cls, state): - """ - helper for _load_state. - loads the given attribute from the state. - """ - if state[attr] is not None: # First, catch None as value. - if isinstance(cls, ClassType): # Is the attribute a StateObject itself? - assert issubclass(cls, StateObject) - curr = getattr(self, attr) - if curr: # if the attribute is already present, delegate to the objects ._load_state method. - curr._load_state(state[attr]) - else: # otherwise, create a new object. - setattr(self, attr, cls._from_state(state[attr])) - else: - setattr(self, attr, cls(state[attr])) - else: - setattr(self, attr, None) - - @classmethod - def _from_state(cls, state): - f = cls() # the default implementation assumes an empty constructor. Override accordingly. - f._load_state(state) - return f \ No newline at end of file diff --git a/netlib/tcp.py b/netlib/tcp.py index c26d1191..346bc053 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,7 +1,6 @@ import select, socket, threading, sys, time, traceback from OpenSSL import SSL import certutils -from netlib.stateobject import StateObject SSLv2_METHOD = SSL.SSLv2_METHOD SSLv3_METHOD = SSL.SSLv3_METHOD @@ -174,13 +173,13 @@ class Reader(_FileLike): return result -class Address(StateObject): +class Address(object): """ This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. """ def __init__(self, address, use_ipv6=False): self.address = address - self.family = socket.AF_INET6 if use_ipv6 else socket.AF_INET + self.use_ipv6 = use_ipv6 @classmethod def wrap(cls, t): @@ -204,19 +203,9 @@ class Address(StateObject): def use_ipv6(self): return self.family == socket.AF_INET6 - def _load_state(self, state): - self.address = state["address"] - self.family = socket.AF_INET6 if state["use_ipv6"] else socket.AF_INET - - def _get_state(self): - return dict( - address=self.address, - use_ipv6=self.use_ipv6 - ) - - @classmethod - def _from_state(cls, state): - return cls(**state) + @use_ipv6.setter + def use_ipv6(self, b): + self.family = socket.AF_INET6 if b else socket.AF_INET class SocketCloseMixin: -- cgit v1.2.3 From 0bbc40dc33dd7bd3729e639874882dd6dd7ea818 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 4 Feb 2014 04:51:41 +0100 Subject: store used sni in TCPClient, add equality check for tcp.Address --- netlib/tcp.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 346bc053..94ea8806 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -207,8 +207,12 @@ class Address(object): def use_ipv6(self, b): self.family = socket.AF_INET6 if b else socket.AF_INET + def __eq__(self, other): + other = Address.wrap(other) + return (self.address, self.family) == (other.address, other.family) -class SocketCloseMixin: + +class SocketCloseMixin(object): def finish(self): self.finished = True try: @@ -250,6 +254,7 @@ class TCPClient(SocketCloseMixin): self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.ssl_established = False + self.sni = None def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None): """ @@ -267,6 +272,7 @@ class TCPClient(SocketCloseMixin): self.connection = SSL.Connection(context, self.connection) self.ssl_established = True if sni: + self.sni = sni self.connection.set_tlsext_host_name(sni) self.connection.set_connect_state() try: -- cgit v1.2.3 From 7fc544bc7ff8fd610ba9db92c0d3b59a0b040b5b Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 5 Feb 2014 21:34:14 +0100 Subject: adjust netlib.wsgi to reflect changes in mitmproxys flow format --- netlib/tcp.py | 2 +- netlib/wsgi.py | 15 ++++++++++----- test/test_tcp.py | 1 + 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 94ea8806..34e47999 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -178,7 +178,7 @@ class Address(object): This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. """ def __init__(self, address, use_ipv6=False): - self.address = address + self.address = tuple(address) self.use_ipv6 = use_ipv6 @classmethod diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 647cb899..b576bdff 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,17 +1,22 @@ import cStringIO, urllib, time, traceback -import odict +import odict, tcp class ClientConn: def __init__(self, address): - self.address = address + self.address = tcp.Address.wrap(address) + + +class Flow: + def __init__(self, client_conn): + self.client_conn = client_conn class Request: def __init__(self, client_conn, scheme, method, path, headers, content): self.scheme, self.method, self.path = scheme, method, path self.headers, self.content = headers, content - self.client_conn = client_conn + self.flow = Flow(client_conn) def date_time_string(): @@ -60,8 +65,8 @@ class WSGIAdaptor: 'SERVER_PROTOCOL': "HTTP/1.1", } environ.update(extra) - if request.client_conn.address: - environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.client_conn.address + if request.flow.client_conn.address: + environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.flow.client_conn.address() for key, value in request.headers.items(): key = 'HTTP_' + key.upper().replace('-', '_') diff --git a/test/test_tcp.py b/test/test_tcp.py index 49e20635..525961d5 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -215,6 +215,7 @@ class TestSNI(test.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl(sni="foo.com") + assert c.sni == "foo.com" assert c.rfile.readline() == "foo.com" -- cgit v1.2.3 From a72ae4d85c08b5716cd88715081be0f1ecaeb9d4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 11 Feb 2014 12:09:58 +0100 Subject: Bump version Do it now already so that mitmproxy will warn the user if netlib is not from master. --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/version.py b/netlib/version.py index 9b2e037e..1d3250e1 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,4 @@ -IVERSION = (0, 10) +IVERSION = (0, 11) VERSION = ".".join(str(i) for i in IVERSION) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION -- cgit v1.2.3 From c276b4294cac97c1281ce9bb4934e49d0ba970a2 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 15 Feb 2014 23:16:28 +0100 Subject: allow super() on TCPServer, add thread names for better debugging --- netlib/http_auth.py | 2 +- netlib/tcp.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/netlib/http_auth.py b/netlib/http_auth.py index be99fb3d..b0451e3b 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -1,4 +1,4 @@ -import contrib.md5crypt as md5crypt +from .contrib import md5crypt import http from argparse import Action, ArgumentTypeError diff --git a/netlib/tcp.py b/netlib/tcp.py index 34e47999..5c351bae 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -380,7 +380,7 @@ class BaseHandler(SocketCloseMixin): -class TCPServer: +class TCPServer(object): request_queue_size = 20 def __init__(self, address): self.address = Address.wrap(address) @@ -416,7 +416,10 @@ class TCPServer: connection, client_address = self.socket.accept() t = threading.Thread( target = self.connection_thread, - args = (connection, client_address) + args = (connection, client_address), + name = "ConnectionThread (%s:%s -> %s:%s)" % + (client_address[0], client_address[1], + self.address.host, self.address.port) ) t.setDaemon(1) t.start() @@ -443,7 +446,7 @@ class TCPServer: print >> fp, exc print >> fp, '-'*40 - def handle_client_connection(self, conn, client_address): # pragma: no cover + def handle_client_connection(self, conn, client_address): # pragma: no cover """ Called after client connection. """ -- cgit v1.2.3 From 49f29ce8eff8251e79d76df72ccc85c66969ba44 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 26 Feb 2014 10:09:36 +1300 Subject: Add an explicit license file. Fixes #30 --- LICENSE | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 LICENSE diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..c08a0186 --- /dev/null +++ b/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2013, Aldo Cortesi. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. -- cgit v1.2.3 From 3443bae94e090b0bf12005ef4f0ca474bd903fb1 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 27 Feb 2014 18:35:16 +1300 Subject: Cipher suite selection for client connections, improved error handling --- netlib/tcp.py | 19 ++++++++++++++++--- test/test_tcp.py | 32 +++++++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 5c351bae..23449baf 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -235,7 +235,8 @@ class SocketCloseMixin(object): self.connection.sock_shutdown(socket.SHUT_WR) else: self.connection.shutdown(socket.SHUT_WR) - #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. + #Section 4.2.2.13 of RFC 1122 tells us that a close() with any + # pending readable data could lead to an immediate RST being sent. #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html while self.connection.recv(4096): pass @@ -256,11 +257,16 @@ class TCPClient(SocketCloseMixin): self.ssl_established = False self.sni = None - def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None): + def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None, cipher_list=None): """ cert: Path to a file containing both client cert and private key. """ context = SSL.Context(method) + if cipher_list: + try: + context.set_cipher_list(cipher_list) + except SSL.Error, v: + raise NetLibError("SSL cipher specification error: %s"%str(v)) if options is not None: context.set_options(options) if cert: @@ -350,7 +356,10 @@ class BaseHandler(SocketCloseMixin): if not options is None: ctx.set_options(options) if cipher_list: - ctx.set_cipher_list(cipher_list) + try: + ctx.set_cipher_list(cipher_list) + except SSL.Error, v: + raise NetLibError("SSL cipher specification error: %s"%str(v)) if handle_sni: # SNI callback happens during do_handshake() ctx.set_tlsext_servername_callback(handle_sni) @@ -399,6 +408,10 @@ class TCPServer(object): except: self.handle_error(connection, client_address) finally: + try: + connection.shutdown(socket.SHUT_RDWR) + except: + pass connection.close() def serve_forever(self, poll_interval=0.1): diff --git a/test/test_tcp.py b/test/test_tcp.py index 525961d5..9c15e2eb 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -219,7 +219,7 @@ class TestSNI(test.ServerTestBase): assert c.rfile.readline() == "foo.com" -class TestClientCipherList(test.ServerTestBase): +class TestServerCipherList(test.ServerTestBase): handler = ClientCipherListHandler ssl = dict( cert = tutils.test_data.path("data/server.crt"), @@ -235,6 +235,36 @@ class TestClientCipherList(test.ServerTestBase): assert c.rfile.readline() == "['RC4-SHA']" +class TestServerCipherListError(test.ServerTestBase): + handler = ClientCipherListHandler + ssl = dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + request_client_cert = False, + v3_only = False, + cipher_list = 'bogus' + ) + def test_echo(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + tutils.raises("handshake error", c.convert_to_ssl, sni="foo.com") + + +class TestClientCipherListError(test.ServerTestBase): + handler = ClientCipherListHandler + ssl = dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + request_client_cert = False, + v3_only = False, + cipher_list = 'RC4-SHA' + ) + def test_echo(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + tutils.raises("cipher specification", c.convert_to_ssl, sni="foo.com", cipher_list="bogus") + + class TestSSLDisconnect(test.ServerTestBase): handler = DisconnectHandler ssl = dict( -- cgit v1.2.3 From 7788391903ef67ed1e779560936d60402159f8f5 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 2 Mar 2014 13:50:19 +1300 Subject: Minor improvement to CertStore interface --- netlib/certutils.py | 9 ++++----- test/test_certutils.py | 10 +++++----- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 94294f6e..0b29d52f 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -113,10 +113,11 @@ class CertStore: """ Implements an in-memory certificate store. """ - def __init__(self): + def __init__(self, cacert): self.certs = {} + self.cacert = cacert - def get_cert(self, commonname, sans, cacert): + def get_cert(self, commonname, sans): """ Returns an SSLCert object. @@ -125,13 +126,11 @@ class CertStore: sans: A list of Subject Alternate Names. - cacert: The path to a CA certificate. - Return None if the certificate could not be found or generated. """ if commonname in self.certs: return self.certs[commonname] - c = dummy_cert(cacert, commonname, sans) + c = dummy_cert(self.cacert, commonname, sans) self.certs[commonname] = c return c diff --git a/test/test_certutils.py b/test/test_certutils.py index 7a00caca..4fab69e6 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -21,16 +21,16 @@ class TestCertStore: with tutils.tmpdir() as d: ca = os.path.join(d, "ca") assert certutils.dummy_ca(ca) - c = certutils.CertStore() + c = certutils.CertStore(ca) def test_create_tmp(self): with tutils.tmpdir() as d: ca = os.path.join(d, "ca") assert certutils.dummy_ca(ca) - c = certutils.CertStore() - assert c.get_cert("foo.com", [], ca) - assert c.get_cert("foo.com", [], ca) - assert c.get_cert("*.foo.com", [], ca) + c = certutils.CertStore(ca) + assert c.get_cert("foo.com", []) + assert c.get_cert("foo.com", []) + assert c.get_cert("*.foo.com", []) class TestDummyCert: -- cgit v1.2.3 From e381c0366863ae412547e16d67860137a6b89a32 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 2 Mar 2014 16:47:10 +1300 Subject: Cleanups, tests, and no-cover directives for code sections we can't test. --- netlib/odict.py | 10 ---------- netlib/tcp.py | 8 +++++--- test/test_odict.py | 8 -------- test/test_tcp.py | 20 ++++++++++++++++++++ 4 files changed, 25 insertions(+), 21 deletions(-) diff --git a/netlib/odict.py b/netlib/odict.py index 46b74e8e..7c743f4e 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -96,16 +96,6 @@ class ODict: def items(self): return self.lst[:] - def _get_state(self): - return [tuple(i) for i in self.lst] - - def _load_state(self, state): - self.list = [list(i) for i in state] - - @classmethod - def _from_state(klass, state): - return klass([list(i) for i in state]) - def copy(self): """ Returns a copy of this object. diff --git a/netlib/tcp.py b/netlib/tcp.py index 23449baf..8f2ebdf0 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -2,6 +2,8 @@ import select, socket, threading, sys, time, traceback from OpenSSL import SSL import certutils +EINTR = 4 + SSLv2_METHOD = SSL.SSLv2_METHOD SSLv3_METHOD = SSL.SSLv3_METHOD SSLv23_METHOD = SSL.SSLv23_METHOD @@ -238,7 +240,7 @@ class SocketCloseMixin(object): #Section 4.2.2.13 of RFC 1122 tells us that a close() with any # pending readable data could lead to an immediate RST being sent. #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html - while self.connection.recv(4096): + while self.connection.recv(4096): # pragma: no cover pass self.connection.close() except (socket.error, SSL.Error, IOError): @@ -420,8 +422,8 @@ class TCPServer(object): while not self.__shutdown_request: try: r, w, e = select.select([self.socket], [], [], poll_interval) - except select.error, ex: - if ex[0] == 4: + except select.error, ex: # pragma: no cover + if ex[0] == EINTR: continue else: raise diff --git a/test/test_odict.py b/test/test_odict.py index 26bff357..cdbb4f39 100644 --- a/test/test_odict.py +++ b/test/test_odict.py @@ -41,14 +41,6 @@ class TestODict: assert h.match_re("two: due") assert not h.match_re("nonono") - def test_getset_state(self): - self.od.add("foo", 1) - self.od.add("foo", 2) - self.od.add("bar", 3) - state = self.od._get_state() - nd = odict.ODict._from_state(state) - assert nd == self.od - def test_in_any(self): self.od["one"] = ["atwoa", "athreea"] assert self.od.in_any("one", "two") diff --git a/test/test_tcp.py b/test/test_tcp.py index 9c15e2eb..4e27a632 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -2,6 +2,7 @@ import cStringIO, Queue, time, socket, random from netlib import tcp, certutils, test import mock import tutils +from OpenSSL import SSL class SNIHandler(tcp.BaseHandler): sni = None @@ -435,3 +436,22 @@ class TestFileLike: s.readline() assert s.first_byte_timestamp == expected + def test_read_ssl_error(self): + s = cStringIO.StringIO("foobar\nfoobar") + s = mock.MagicMock() + s.read = mock.MagicMock(side_effect=SSL.Error()) + s = tcp.Reader(s) + tutils.raises(tcp.NetLibSSLError, s.read, 1) + + + +class TestAddress: + def test_simple(self): + a = tcp.Address("localhost", True) + assert a.use_ipv6 + b = tcp.Address("foo.com", True) + assert not a == b + c = tcp.Address("localhost", True) + assert a == c + + -- cgit v1.2.3 From 1acaf1c880ba7054e4eb1cc1ed4ea5d0cf852e61 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 2 Mar 2014 16:54:21 +1300 Subject: Re-add state operations to ODict. --- netlib/odict.py | 10 ++++++++++ test/test_odict.py | 9 +++++++++ 2 files changed, 19 insertions(+) diff --git a/netlib/odict.py b/netlib/odict.py index 7c743f4e..46b74e8e 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -96,6 +96,16 @@ class ODict: def items(self): return self.lst[:] + def _get_state(self): + return [tuple(i) for i in self.lst] + + def _load_state(self, state): + self.list = [list(i) for i in state] + + @classmethod + def _from_state(klass, state): + return klass([list(i) for i in state]) + def copy(self): """ Returns a copy of this object. diff --git a/test/test_odict.py b/test/test_odict.py index cdbb4f39..794956be 100644 --- a/test/test_odict.py +++ b/test/test_odict.py @@ -24,6 +24,15 @@ class TestODict: for i in expected: assert out.find(i) >= 0 + def test_getset_state(self): + self.od.add("foo", 1) + self.od.add("foo", 2) + self.od.add("bar", 3) + state = self.od._get_state() + nd = odict.ODict._from_state(state) + assert nd == self.od + nd._load_state(state) + def test_dictToHeader2(self): self.od["one"] = ["uno"] expected1 = "one: uno\r\n" -- cgit v1.2.3 From cfaa3da25cee39c5395a6ff27dfc47ff07dbeef6 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 2 Mar 2014 21:37:28 +1300 Subject: Use PyOpenSSL's underlying ffi interface to get current cipher for connections. --- netlib/tcp.py | 16 +++++++++++++--- test/test_tcp.py | 35 ++++++++++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 8f2ebdf0..0dff807b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -2,6 +2,7 @@ import select, socket, threading, sys, time, traceback from OpenSSL import SSL import certutils + EINTR = 4 SSLv2_METHOD = SSL.SSLv2_METHOD @@ -214,7 +215,16 @@ class Address(object): return (self.address, self.family) == (other.address, other.family) -class SocketCloseMixin(object): +class _Connection(object): + def get_current_cipher(self): + if not self.ssl_established: + return None + c = SSL._lib.SSL_get_current_cipher(self.connection._ssl) + name = SSL._native(SSL._ffi.string(SSL._lib.SSL_CIPHER_get_name(c))) + bits = SSL._lib.SSL_CIPHER_get_bits(c, SSL._ffi.NULL) + version = SSL._native(SSL._ffi.string(SSL._lib.SSL_CIPHER_get_version(c))) + return name, bits, version + def finish(self): self.finished = True try: @@ -248,7 +258,7 @@ class SocketCloseMixin(object): pass -class TCPClient(SocketCloseMixin): +class TCPClient(_Connection): rbufsize = -1 wbufsize = -1 def __init__(self, address, source_address=None): @@ -310,7 +320,7 @@ class TCPClient(SocketCloseMixin): return self.connection.gettimeout() -class BaseHandler(SocketCloseMixin): +class BaseHandler(_Connection): """ The instantiator is expected to call the handle() and finish() methods. diff --git a/test/test_tcp.py b/test/test_tcp.py index 4e27a632..387e3f33 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -49,6 +49,13 @@ class ClientCipherListHandler(tcp.BaseHandler): self.wfile.flush() +class CurrentCipherHandler(tcp.BaseHandler): + sni = None + def handle(self): + self.wfile.write("%s"%str(self.get_current_cipher())) + self.wfile.flush() + + class DisconnectHandler(tcp.BaseHandler): def handle(self): self.close() @@ -151,7 +158,8 @@ class TestServerSSL(test.ServerTestBase): cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), request_client_cert = False, - v3_only = False + v3_only = False, + cipher_list = "AES256-SHA" ) def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) @@ -165,6 +173,15 @@ class TestServerSSL(test.ServerTestBase): def test_get_remote_cert(self): assert certutils.get_remote_cert("127.0.0.1", self.port, None).digest("sha1") + def test_get_current_cipher(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + assert not c.get_current_cipher() + c.convert_to_ssl(sni="foo.com") + ret = c.get_current_cipher() + assert ret + assert "AES" in ret[0] + class TestSSLv3Only(test.ServerTestBase): handler = EchoHandler @@ -236,6 +253,22 @@ class TestServerCipherList(test.ServerTestBase): assert c.rfile.readline() == "['RC4-SHA']" +class TestServerCurrentCipher(test.ServerTestBase): + handler = CurrentCipherHandler + ssl = dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + request_client_cert = False, + v3_only = False, + cipher_list = 'RC4-SHA' + ) + def test_echo(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl(sni="foo.com") + assert "RC4-SHA" in c.rfile.readline() + + class TestServerCipherListError(test.ServerTestBase): handler = ClientCipherListHandler ssl = dict( -- cgit v1.2.3 From d56f7fba806e6c2008c40df9b0b290b81189cb92 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 2 Mar 2014 22:14:33 +1300 Subject: We now require PyOpenSSL >= 0.14 --- setup.py | 2 +- test/test_tcp.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 2937487c..5ba9f824 100644 --- a/setup.py +++ b/setup.py @@ -88,5 +88,5 @@ setup( "Topic :: Software Development :: Testing :: Traffic Generation", "Topic :: Internet :: WWW/HTTP", ], - install_requires=["pyasn1>0.1.2", "pyopenssl>=0.12"], + install_requires=["pyasn1>0.1.2", "pyopenssl>=0.14"], ) diff --git a/test/test_tcp.py b/test/test_tcp.py index 387e3f33..d5d11294 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -477,7 +477,6 @@ class TestFileLike: tutils.raises(tcp.NetLibSSLError, s.read, 1) - class TestAddress: def test_simple(self): a = tcp.Address("localhost", True) @@ -486,5 +485,3 @@ class TestAddress: assert not a == b c = tcp.Address("localhost", True) assert a == c - - -- cgit v1.2.3 From 7c82418e0baca311487230074655f5f106bcdd2b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 4 Mar 2014 14:12:58 +1300 Subject: Beef up CertStore, add DH params. --- netlib/certutils.py | 157 ++++++++++++++++++++++++++----------------------- test/test_certutils.py | 39 +++++------- 2 files changed, 99 insertions(+), 97 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 0b29d52f..b9c291d0 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -5,23 +5,27 @@ from pyasn1.error import PyAsn1Error import OpenSSL import tcp -default_exp = 62208000 # =24 * 60 * 60 * 720 -default_o = "mitmproxy" -default_cn = "mitmproxy" - -def create_ca(o=default_o, cn=default_cn, exp=default_exp): +DEFAULT_EXP = 62208000 # =24 * 60 * 60 * 720 +# Generated with "openssl dhparam". It's too slow to generate this on startup. +DEFAULT_DHPARAM = """-----BEGIN DH PARAMETERS----- +MIGHAoGBAOdPzMbYgoYfO3YBYauCLRlE8X1XypTiAjoeCFD0qWRx8YUsZ6Sj20W5 +zsfQxlZfKovo3f2MftjkDkbI/C/tDgxoe0ZPbjy5CjdOhkzxn0oTbKTs16Rw8DyK +1LjTR65sQJkJEdgsX8TSi/cicCftJZl9CaZEaObF2bdgSgGK+PezAgEC +-----END DH PARAMETERS-----""" + +def create_ca(o, cn, exp): key = OpenSSL.crypto.PKey() key.generate_key(OpenSSL.crypto.TYPE_RSA, 1024) - ca = OpenSSL.crypto.X509() - ca.set_serial_number(int(time.time()*10000)) - ca.set_version(2) - ca.get_subject().CN = cn - ca.get_subject().O = o - ca.gmtime_adj_notBefore(0) - ca.gmtime_adj_notAfter(exp) - ca.set_issuer(ca.get_subject()) - ca.set_pubkey(key) - ca.add_extensions([ + cert = OpenSSL.crypto.X509() + cert.set_serial_number(int(time.time()*10000)) + cert.set_version(2) + cert.get_subject().CN = cn + cert.get_subject().O = o + cert.gmtime_adj_notBefore(0) + cert.gmtime_adj_notAfter(exp) + cert.set_issuer(cert.get_subject()) + cert.set_pubkey(key) + cert.add_extensions([ OpenSSL.crypto.X509Extension("basicConstraints", True, "CA:TRUE"), OpenSSL.crypto.X509Extension("nsCertType", True, @@ -32,80 +36,39 @@ def create_ca(o=default_o, cn=default_cn, exp=default_exp): OpenSSL.crypto.X509Extension("keyUsage", False, "keyCertSign, cRLSign"), OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash", - subject=ca), + subject=cert), ]) - ca.sign(key, "sha1") - return key, ca - - -def dummy_ca(path, o=default_o, cn=default_cn, exp=default_exp): - dirname = os.path.dirname(path) - if not os.path.exists(dirname): - os.makedirs(dirname) - if path.endswith(".pem"): - basename, _ = os.path.splitext(path) - basename = os.path.basename(basename) - else: - basename = os.path.basename(path) - - key, ca = create_ca(o=o, cn=cn, exp=exp) - - # Dump the CA plus private key - f = open(path, "wb") - f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() - - # Dump the certificate in PEM format - f = open(os.path.join(dirname, basename + "-cert.pem"), "wb") - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() - - # Create a .cer file with the same contents for Android - f = open(os.path.join(dirname, basename + "-cert.cer"), "wb") - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() - - # Dump the certificate in PKCS12 format for Windows devices - f = open(os.path.join(dirname, basename + "-cert.p12"), "wb") - p12 = OpenSSL.crypto.PKCS12() - p12.set_certificate(ca) - p12.set_privatekey(key) - f.write(p12.export()) - f.close() - return True - - -def dummy_cert(ca, commonname, sans): + cert.sign(key, "sha1") + return key, cert + + +def dummy_cert(pkey, cacert, commonname, sans): """ - Generates and writes a certificate to fp. + Generates a dummy certificate. - ca: Path to the certificate authority file, or None. + pkey: CA private key + cacert: CA certificate commonname: Common name for the generated certificate. sans: A list of Subject Alternate Names. - Returns cert path if operation succeeded, None if not. + Returns cert if operation succeeded, None if not. """ ss = [] for i in sans: ss.append("DNS: %s"%i) ss = ", ".join(ss) - raw = file(ca, "rb").read() - ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) - key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) - cert = OpenSSL.crypto.X509() cert.gmtime_adj_notBefore(-3600*48) cert.gmtime_adj_notAfter(60 * 60 * 24 * 30) - cert.set_issuer(ca.get_subject()) + cert.set_issuer(cacert.get_subject()) cert.get_subject().CN = commonname cert.set_serial_number(int(time.time()*10000)) if ss: cert.set_version(2) cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) - cert.set_pubkey(ca.get_pubkey()) - cert.sign(key, "sha1") + cert.set_pubkey(cacert.get_pubkey()) + cert.sign(pkey, "sha1") return SSLCert(cert) @@ -113,9 +76,59 @@ class CertStore: """ Implements an in-memory certificate store. """ - def __init__(self, cacert): + def __init__(self, pkey, cert): + self.pkey, self.cert = pkey, cert self.certs = {} - self.cacert = cacert + + @classmethod + def from_store(klass, path, basename): + p = os.path.join(path, basename + "-ca.pem") + if not os.path.exists(p): + key, ca = klass.create_store(path, basename) + else: + p = os.path.join(path, basename + "-ca.pem") + raw = file(p, "rb").read() + ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) + key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) + return klass(key, ca) + + @classmethod + def create_store(klass, path, basename, o=None, cn=None, expiry=DEFAULT_EXP): + if not os.path.exists(path): + os.makedirs(path) + + o = o or basename + cn = cn or basename + + key, ca = create_ca(o=o, cn=cn, exp=expiry) + # Dump the CA plus private key + f = open(os.path.join(path, basename + "-ca.pem"), "wb") + f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) + f.close() + + # Dump the certificate in PEM format + f = open(os.path.join(path, basename + "-cert.pem"), "wb") + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) + f.close() + + # Create a .cer file with the same contents for Android + f = open(os.path.join(path, basename + "-cert.cer"), "wb") + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) + f.close() + + # Dump the certificate in PKCS12 format for Windows devices + f = open(os.path.join(path, basename + "-cert.p12"), "wb") + p12 = OpenSSL.crypto.PKCS12() + p12.set_certificate(ca) + p12.set_privatekey(key) + f.write(p12.export()) + f.close() + + f = open(os.path.join(path, basename + "-dhparam.pem"), "wb") + f.write(DEFAULT_DHPARAM) + f.close() + return key, ca def get_cert(self, commonname, sans): """ @@ -130,7 +143,7 @@ class CertStore: """ if commonname in self.certs: return self.certs[commonname] - c = dummy_cert(self.cacert, commonname, sans) + c = dummy_cert(self.pkey, self.cert, commonname, sans) self.certs[commonname] = c return c diff --git a/test/test_certutils.py b/test/test_certutils.py index 4fab69e6..f741bdec 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -3,43 +3,32 @@ from netlib import certutils import tutils -def test_dummy_ca(): - with tutils.tmpdir() as d: - path = os.path.join(d, "foo/cert.cnf") - assert certutils.dummy_ca(path) - assert os.path.exists(path) - - path = os.path.join(d, "foo/cert2.pem") - assert certutils.dummy_ca(path) - assert os.path.exists(path) - assert os.path.exists(os.path.join(d, "foo/cert2-cert.pem")) - assert os.path.exists(os.path.join(d, "foo/cert2-cert.p12")) - - class TestCertStore: def test_create_explicit(self): with tutils.tmpdir() as d: - ca = os.path.join(d, "ca") - assert certutils.dummy_ca(ca) - c = certutils.CertStore(ca) + ca = certutils.CertStore.from_store(d, "test") + assert ca.get_cert("foo", []) + + ca2 = certutils.CertStore.from_store(d, "test") + assert ca2.get_cert("foo", []) + + assert ca.cert.get_serial_number() == ca2.cert.get_serial_number() def test_create_tmp(self): with tutils.tmpdir() as d: - ca = os.path.join(d, "ca") - assert certutils.dummy_ca(ca) - c = certutils.CertStore(ca) - assert c.get_cert("foo.com", []) - assert c.get_cert("foo.com", []) - assert c.get_cert("*.foo.com", []) + ca = certutils.CertStore.from_store(d, "test") + assert ca.get_cert("foo.com", []) + assert ca.get_cert("foo.com", []) + assert ca.get_cert("*.foo.com", []) class TestDummyCert: def test_with_ca(self): with tutils.tmpdir() as d: - cacert = os.path.join(d, "cacert") - assert certutils.dummy_ca(cacert) + ca = certutils.CertStore.from_store(d, "test") r = certutils.dummy_cert( - cacert, + ca.pkey, + ca.cert, "foo.com", ["one.com", "two.com", "*.three.com"] ) -- cgit v1.2.3 From 0c3bc1cff2a8b1c4c425be5c1ca11c4b850bcc68 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 5 Mar 2014 13:19:16 +1300 Subject: Much more sophisticated certificate store - Handle wildcard lookup - Handle lookup of SANs - Provide hooks for registering override certs and keys for specific domains (including wildcard specifications) --- netlib/certutils.py | 87 +++++++++++++++++++++++++++++++++++++++++++------- test/test_certutils.py | 68 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 140 insertions(+), 15 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index b9c291d0..fafcb5fd 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -4,6 +4,7 @@ from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL import tcp +import UserDict DEFAULT_EXP = 62208000 # =24 * 60 * 60 * 720 # Generated with "openssl dhparam". It's too slow to generate this on startup. @@ -42,11 +43,11 @@ def create_ca(o, cn, exp): return key, cert -def dummy_cert(pkey, cacert, commonname, sans): +def dummy_cert(privkey, cacert, commonname, sans): """ Generates a dummy certificate. - pkey: CA private key + privkey: CA private key cacert: CA certificate commonname: Common name for the generated certificate. sans: A list of Subject Alternate Names. @@ -68,17 +69,55 @@ def dummy_cert(pkey, cacert, commonname, sans): cert.set_version(2) cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) cert.set_pubkey(cacert.get_pubkey()) - cert.sign(pkey, "sha1") + cert.sign(privkey, "sha1") return SSLCert(cert) +class _Node(UserDict.UserDict): + def __init__(self): + UserDict.UserDict.__init__(self) + self.value = None + + +class DNTree: + """ + Domain store that knows about wildcards. DNS wildcards are very + restricted - the only valid variety is an asterisk on the left-most + domain component, i.e.: + + *.foo.com + """ + def __init__(self): + self.d = _Node() + + def add(self, dn, cert): + parts = dn.split(".") + parts.reverse() + current = self.d + for i in parts: + current = current.setdefault(i, _Node()) + current.value = cert + + def get(self, dn): + parts = dn.split(".") + current = self.d + for i in reversed(parts): + if i in current: + current = current[i] + elif "*" in current: + return current["*"].value + else: + return None + return current.value + + class CertStore: """ Implements an in-memory certificate store. """ - def __init__(self, pkey, cert): - self.pkey, self.cert = pkey, cert - self.certs = {} + def __init__(self, privkey, cacert): + self.privkey, self.cacert = privkey, cacert + self.certs = DNTree() @classmethod def from_store(klass, path, basename): @@ -130,9 +169,29 @@ class CertStore: f.close() return key, ca + def add_cert_file(self, commonname, path): + raw = file(path, "rb").read() + cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) + try: + privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) + except Exception: + privkey = None + self.add_cert(SSLCert(cert), privkey, commonname) + + def add_cert(self, cert, privkey, *names): + """ + Adds a cert to the certstore. We register the CN in the cert plus + any SANs, and also the list of names provided as an argument. + """ + self.certs.add(cert.cn, (cert, privkey)) + for i in cert.altnames: + self.certs.add(i, (cert, privkey)) + for i in names: + self.certs.add(i, (cert, privkey)) + def get_cert(self, commonname, sans): """ - Returns an SSLCert object. + Returns an (cert, privkey) tuple. commonname: Common name for the generated certificate. Must be a valid, plain-ASCII, IDNA-encoded domain name. @@ -141,11 +200,12 @@ class CertStore: Return None if the certificate could not be found or generated. """ - if commonname in self.certs: - return self.certs[commonname] - c = dummy_cert(self.pkey, self.cert, commonname, sans) - self.certs[commonname] = c - return c + c = self.certs.get(commonname) + if not c: + c = dummy_cert(self.privkey, self.cacert, commonname, sans) + self.add_cert(c, None) + c = (c, None) + return (c[0], c[1] or self.privkey) class _GeneralName(univ.Choice): @@ -171,6 +231,9 @@ class SSLCert: """ self.x509 = cert + def __eq__(self, other): + return self.digest("sha1") == other.digest("sha1") + @classmethod def from_pem(klass, txt): x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) diff --git a/test/test_certutils.py b/test/test_certutils.py index f741bdec..7f320e7e 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -1,7 +1,37 @@ import os from netlib import certutils +import OpenSSL import tutils +class TestDNTree: + def test_simple(self): + d = certutils.DNTree() + d.add("foo.com", "foo") + d.add("bar.com", "bar") + assert d.get("foo.com") == "foo" + assert d.get("bar.com") == "bar" + assert not d.get("oink.com") + assert not d.get("oink") + assert not d.get("") + assert not d.get("oink.oink") + + d.add("*.match.org", "match") + assert not d.get("match.org") + assert d.get("foo.match.org") == "match" + assert d.get("foo.foo.match.org") == "match" + + def test_wildcard(self): + d = certutils.DNTree() + d.add("foo.com", "foo") + assert not d.get("*.foo.com") + d.add("*.foo.com", "wild") + + d = certutils.DNTree() + d.add("*", "foo") + assert d.get("foo.com") == "foo" + assert d.get("*.foo.com") == "foo" + assert d.get("com") == "foo" + class TestCertStore: def test_create_explicit(self): @@ -12,7 +42,7 @@ class TestCertStore: ca2 = certutils.CertStore.from_store(d, "test") assert ca2.get_cert("foo", []) - assert ca.cert.get_serial_number() == ca2.cert.get_serial_number() + assert ca.cacert.get_serial_number() == ca2.cacert.get_serial_number() def test_create_tmp(self): with tutils.tmpdir() as d: @@ -21,14 +51,46 @@ class TestCertStore: assert ca.get_cert("foo.com", []) assert ca.get_cert("*.foo.com", []) + r = ca.get_cert("*.foo.com", []) + assert r[1] == ca.privkey + + def test_add_cert(self): + with tutils.tmpdir() as d: + ca = certutils.CertStore.from_store(d, "test") + + def test_sans(self): + with tutils.tmpdir() as d: + ca = certutils.CertStore.from_store(d, "test") + c1 = ca.get_cert("foo.com", ["*.bar.com"]) + c2 = ca.get_cert("foo.bar.com", []) + assert c1 == c2 + c3 = ca.get_cert("bar.com", []) + assert not c1 == c3 + + def test_overrides(self): + with tutils.tmpdir() as d: + ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test") + ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test") + assert not ca1.cacert.get_serial_number() == ca2.cacert.get_serial_number() + + dc = ca2.get_cert("foo.com", []) + dcp = os.path.join(d, "dc") + f = open(dcp, "wb") + f.write(dc[0].to_pem()) + f.close() + ca1.add_cert_file("foo.com", dcp) + + ret = ca1.get_cert("foo.com", []) + assert ret[0].serial == dc[0].serial + class TestDummyCert: def test_with_ca(self): with tutils.tmpdir() as d: ca = certutils.CertStore.from_store(d, "test") r = certutils.dummy_cert( - ca.pkey, - ca.cert, + ca.privkey, + ca.cacert, "foo.com", ["one.com", "two.com", "*.three.com"] ) -- cgit v1.2.3 From 86730a9a4c3a14b510590aa97a8ae8989cb6ec5e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 5 Mar 2014 13:43:52 +1300 Subject: Handler convert_to_ssl now takes a key object, not a path. --- netlib/tcp.py | 2 +- netlib/test.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 0dff807b..83059bc2 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -375,7 +375,7 @@ class BaseHandler(_Connection): if handle_sni: # SNI callback happens during do_handshake() ctx.set_tlsext_servername_callback(handle_sni) - ctx.use_privatekey_file(key) + ctx.use_privatekey(key) ctx.use_certificate(cert.x509) if request_client_cert: def ver(*args): diff --git a/netlib/test.py b/netlib/test.py index 2f6a7107..b88b3586 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -1,5 +1,6 @@ import threading, Queue, cStringIO import tcp, certutils +import OpenSSL class ServerThread(threading.Thread): def __init__(self, server): @@ -49,6 +50,8 @@ class TServer(tcp.TCPServer): self.handler_klass = handler_klass self.last_handler = None + + def handle_client_connection(self, request, client_address): h = self.handler_klass(request, client_address, self) self.last_handler = h @@ -56,6 +59,8 @@ class TServer(tcp.TCPServer): cert = certutils.SSLCert.from_pem( file(self.ssl["cert"], "rb").read() ) + raw = file(self.ssl["key"], "rb").read() + key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) if self.ssl["v3_only"]: method = tcp.SSLv3_METHOD options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1 @@ -63,8 +68,7 @@ class TServer(tcp.TCPServer): method = tcp.SSLv23_METHOD options = None h.convert_to_ssl( - cert, - self.ssl["key"], + cert, key, method = method, options = options, handle_sni = getattr(h, "handle_sni", None), -- cgit v1.2.3 From 52b14aa1d1bbeb3e2b8c62ee9939b9575ee1840f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 5 Mar 2014 17:29:14 +1300 Subject: CertStore: cope with certs that have no common name --- netlib/certutils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index fafcb5fd..d544cfa6 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -169,21 +169,22 @@ class CertStore: f.close() return key, ca - def add_cert_file(self, commonname, path): + def add_cert_file(self, spec, path): raw = file(path, "rb").read() cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) try: privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) except Exception: privkey = None - self.add_cert(SSLCert(cert), privkey, commonname) + self.add_cert(SSLCert(cert), privkey, spec) def add_cert(self, cert, privkey, *names): """ Adds a cert to the certstore. We register the CN in the cert plus any SANs, and also the list of names provided as an argument. """ - self.certs.add(cert.cn, (cert, privkey)) + if cert.cn: + self.certs.add(cert.cn, (cert, privkey)) for i in cert.altnames: self.certs.add(i, (cert, privkey)) for i in names: -- cgit v1.2.3 From 2a12aa3c47d57cc2d3a36f6726a5f081ca493457 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 7 Mar 2014 16:38:50 +1300 Subject: Support Ephemeral Diffie-Hellman --- netlib/certutils.py | 24 +++++++++++++++++++----- netlib/tcp.py | 7 ++++++- netlib/test.py | 11 ++++++----- test/data/dhparam.pem | 5 +++++ test/test_tcp.py | 20 ++++++++++++++++++++ 5 files changed, 56 insertions(+), 11 deletions(-) create mode 100644 test/data/dhparam.pem diff --git a/netlib/certutils.py b/netlib/certutils.py index d544cfa6..19148382 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -115,10 +115,22 @@ class CertStore: """ Implements an in-memory certificate store. """ - def __init__(self, privkey, cacert): + def __init__(self, privkey, cacert, dhparams=None): self.privkey, self.cacert = privkey, cacert + self.dhparams = dhparams self.certs = DNTree() + @classmethod + def load_dhparam(klass, path): + bio = OpenSSL.SSL._lib.BIO_new_file(path, b"r") + if bio != OpenSSL.SSL._ffi.NULL: + bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free) + dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams( + bio, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL + ) + dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free) + return dh + @classmethod def from_store(klass, path, basename): p = os.path.join(path, basename + "-ca.pem") @@ -129,7 +141,9 @@ class CertStore: raw = file(p, "rb").read() ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) - return klass(key, ca) + dhp = os.path.join(path, basename + "-dhparam.pem") + dh = klass.load_dhparam(dhp) + return klass(key, ca, dh) @classmethod def create_store(klass, path, basename, o=None, cn=None, expiry=DEFAULT_EXP): @@ -147,17 +161,17 @@ class CertStore: f.close() # Dump the certificate in PEM format - f = open(os.path.join(path, basename + "-cert.pem"), "wb") + f = open(os.path.join(path, basename + "-ca-cert.pem"), "wb") f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) f.close() # Create a .cer file with the same contents for Android - f = open(os.path.join(path, basename + "-cert.cer"), "wb") + f = open(os.path.join(path, basename + "-ca-cert.cer"), "wb") f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) f.close() # Dump the certificate in PKCS12 format for Windows devices - f = open(os.path.join(path, basename + "-cert.p12"), "wb") + f = open(os.path.join(path, basename + "-ca-cert.p12"), "wb") p12 = OpenSSL.crypto.PKCS12() p12.set_certificate(ca) p12.set_privatekey(key) diff --git a/netlib/tcp.py b/netlib/tcp.py index 83059bc2..078ac497 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -339,7 +339,10 @@ class BaseHandler(_Connection): self.ssl_established = False self.clientcert = None - def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False, cipher_list=None): + def convert_to_ssl(self, cert, key, + method=SSLv23_METHOD, options=None, handle_sni=None, + request_client_cert=False, cipher_list=None, dhparams=None + ): """ cert: A certutils.SSLCert object. method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD @@ -377,6 +380,8 @@ class BaseHandler(_Connection): ctx.set_tlsext_servername_callback(handle_sni) ctx.use_privatekey(key) ctx.use_certificate(cert.x509) + if dhparams: + SSL._lib.SSL_CTX_set_tmp_dh(ctx._context, dhparams) if request_client_cert: def ver(*args): self.clientcert = certutils.SSLCert(args[1]) diff --git a/netlib/test.py b/netlib/test.py index b88b3586..bb0012ad 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -18,7 +18,6 @@ class ServerTestBase: ssl = None handler = None addr = ("localhost", 0) - @classmethod def setupAll(cls): cls.q = Queue.Queue() @@ -43,15 +42,16 @@ class ServerTestBase: class TServer(tcp.TCPServer): def __init__(self, ssl, q, handler_klass, addr): """ - ssl: A {cert, key, v3_only} dict. + ssl: A dictionary of SSL parameters: + + cert, key, request_client_cert, cipher_list, + dhparams, v3_only """ tcp.TCPServer.__init__(self, addr) self.ssl, self.q = ssl, q self.handler_klass = handler_klass self.last_handler = None - - def handle_client_connection(self, request, client_address): h = self.handler_klass(request, client_address, self) self.last_handler = h @@ -73,7 +73,8 @@ class TServer(tcp.TCPServer): options = options, handle_sni = getattr(h, "handle_sni", None), request_client_cert = self.ssl["request_client_cert"], - cipher_list = self.ssl.get("cipher_list", None) + cipher_list = self.ssl.get("cipher_list", None), + dhparams = self.ssl.get("dhparams", None) ) h.handle() h.finish() diff --git a/test/data/dhparam.pem b/test/data/dhparam.pem new file mode 100644 index 00000000..6f2526e1 --- /dev/null +++ b/test/data/dhparam.pem @@ -0,0 +1,5 @@ +-----BEGIN DH PARAMETERS----- +MIGHAoGBAOdPzMbYgoYfO3YBYauCLRlE8X1XypTiAjoeCFD0qWRx8YUsZ6Sj20W5 +zsfQxlZfKovo3f2MftjkDkbI/C/tDgxoe0ZPbjy5CjdOhkzxn0oTbKTs16Rw8DyK +1LjTR65sQJkJEdgsX8TSi/cicCftJZl9CaZEaObF2bdgSgGK+PezAgEC +-----END DH PARAMETERS----- diff --git a/test/test_tcp.py b/test/test_tcp.py index d5d11294..814754cd 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -363,6 +363,26 @@ class TestSSLTimeOut(test.ServerTestBase): tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) +class TestDHParams(test.ServerTestBase): + handler = HangHandler + ssl = dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + request_client_cert = False, + v3_only = False, + dhparams = certutils.CertStore.load_dhparam( + tutils.test_data.path("data/dhparam.pem"), + ), + cipher_list = "DHE-RSA-AES256-SHA" + ) + def test_dhparams(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + ret = c.get_current_cipher() + assert ret[0] == "DHE-RSA-AES256-SHA" + + class TestTCPClient: def test_conerr(self): c = tcp.TCPClient(("127.0.0.1", 0)) -- cgit v1.2.3 From f5cc63d653b27210d9c3d7646c01c3a9d540d9c7 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 10 Mar 2014 17:29:27 +1300 Subject: Certificate flags --- .gitignore | 3 +- netlib/certffi.py | 36 ++++++++++++++ netlib/certutils.py | 7 +++ test/test_certutils.py | 14 +++++- test/test_tcp.py | 127 ++++++++++++++++++++++++++++--------------------- 5 files changed, 130 insertions(+), 57 deletions(-) create mode 100644 netlib/certffi.py diff --git a/.gitignore b/.gitignore index e66d51fe..26c449d1 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ MANIFEST *.swp *.swo .coverage -.idea \ No newline at end of file +.idea +__pycache__ diff --git a/netlib/certffi.py b/netlib/certffi.py new file mode 100644 index 00000000..c5d7c95e --- /dev/null +++ b/netlib/certffi.py @@ -0,0 +1,36 @@ +import cffi +import OpenSSL +xffi = cffi.FFI() +xffi.cdef (""" + struct rsa_meth_st { + int flags; + ...; + }; + struct rsa_st { + int pad; + long version; + struct rsa_meth_st *meth; + ...; + }; +""") +xffi.verify( + """#include """, + extra_compile_args=['-w'] +) + +def handle(privkey): + new = xffi.new("struct rsa_st*") + newbuf = xffi.buffer(new) + rsa = OpenSSL.SSL._lib.EVP_PKEY_get1_RSA(privkey._pkey) + oldbuf = OpenSSL.SSL._ffi.buffer(rsa) + newbuf[:] = oldbuf[:] + return new + +def set_flags(privkey, val): + hdl = handle(privkey) + hdl.meth.flags = val + return privkey + +def get_flags(privkey): + hdl = handle(privkey) + return hdl.meth.flags diff --git a/netlib/certutils.py b/netlib/certutils.py index 19148382..92b219ee 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -111,6 +111,7 @@ class DNTree: return current.value + class CertStore: """ Implements an in-memory certificate store. @@ -222,6 +223,11 @@ class CertStore: c = (c, None) return (c[0], c[1] or self.privkey) + def gen_pkey(self, cert): + import certffi + certffi.set_flags(self.privkey, 1) + return self.privkey + class _GeneralName(univ.Choice): # We are only interested in dNSNames. We use a default handler to ignore @@ -326,6 +332,7 @@ class SSLCert: return altnames + def get_remote_cert(host, port, sni): c = tcp.TCPClient((host, port)) c.connect() diff --git a/test/test_certutils.py b/test/test_certutils.py index 7f320e7e..176575ea 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -1,5 +1,5 @@ import os -from netlib import certutils +from netlib import certutils, certffi import OpenSSL import tutils @@ -83,6 +83,16 @@ class TestCertStore: ret = ca1.get_cert("foo.com", []) assert ret[0].serial == dc[0].serial + def test_gen_pkey(self): + try: + with tutils.tmpdir() as d: + ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test") + ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test") + cert = ca1.get_cert("foo.com", []) + assert certffi.get_flags(ca2.gen_pkey(cert[0])) == 1 + finally: + certffi.set_flags(ca2.privkey, 0) + class TestDummyCert: def test_with_ca(self): @@ -125,3 +135,5 @@ class TestSSLCert: d = file(tutils.test_data.path("data/dercert"),"rb").read() s = certutils.SSLCert.from_der(d) assert s.cn + + diff --git a/test/test_tcp.py b/test/test_tcp.py index 814754cd..ec995702 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -4,16 +4,6 @@ import mock import tutils from OpenSSL import SSL -class SNIHandler(tcp.BaseHandler): - sni = None - def handle_sni(self, connection): - self.sni = connection.get_servername() - - def handle(self): - self.wfile.write(self.sni) - self.wfile.flush() - - class EchoHandler(tcp.BaseHandler): sni = None def handle_sni(self, connection): @@ -25,58 +15,19 @@ class EchoHandler(tcp.BaseHandler): self.wfile.flush() -class ClientPeernameHandler(tcp.BaseHandler): - def handle(self): - self.wfile.write(str(self.connection.getpeername())) - self.wfile.flush() - - -class CertHandler(tcp.BaseHandler): - sni = None - def handle_sni(self, connection): - self.sni = connection.get_servername() - - def handle(self): - self.wfile.write("%s\n"%self.clientcert.serial) - self.wfile.flush() - - class ClientCipherListHandler(tcp.BaseHandler): sni = None - def handle(self): self.wfile.write("%s"%self.connection.get_cipher_list()) self.wfile.flush() -class CurrentCipherHandler(tcp.BaseHandler): - sni = None - def handle(self): - self.wfile.write("%s"%str(self.get_current_cipher())) - self.wfile.flush() - - -class DisconnectHandler(tcp.BaseHandler): - def handle(self): - self.close() - - class HangHandler(tcp.BaseHandler): def handle(self): while 1: time.sleep(1) -class TimeoutHandler(tcp.BaseHandler): - def handle(self): - self.timeout = False - self.settimeout(0.01) - try: - self.rfile.read(10) - except tcp.NetLibTimeout: - self.timeout = True - - class TestServer(test.ServerTestBase): handler = EchoHandler def test_echo(self): @@ -89,7 +40,10 @@ class TestServer(test.ServerTestBase): class TestServerBind(test.ServerTestBase): - handler = ClientPeernameHandler + class handler(tcp.BaseHandler): + def handle(self): + self.wfile.write(str(self.connection.getpeername())) + self.wfile.flush() def test_bind(self): """ Test to bind to a given random port. Try again if the random port turned out to be blocked. """ @@ -198,7 +152,14 @@ class TestSSLv3Only(test.ServerTestBase): class TestSSLClientCert(test.ServerTestBase): - handler = CertHandler + class handler(tcp.BaseHandler): + sni = None + def handle_sni(self, connection): + self.sni = connection.get_servername() + + def handle(self): + self.wfile.write("%s\n"%self.clientcert.serial) + self.wfile.flush() ssl = dict( cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), @@ -222,7 +183,15 @@ class TestSSLClientCert(test.ServerTestBase): class TestSNI(test.ServerTestBase): - handler = SNIHandler + class handler(tcp.BaseHandler): + sni = None + def handle_sni(self, connection): + self.sni = connection.get_servername() + + def handle(self): + self.wfile.write(self.sni) + self.wfile.flush() + ssl = dict( cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), @@ -254,7 +223,11 @@ class TestServerCipherList(test.ServerTestBase): class TestServerCurrentCipher(test.ServerTestBase): - handler = CurrentCipherHandler + class handler(tcp.BaseHandler): + sni = None + def handle(self): + self.wfile.write("%s"%str(self.get_current_cipher())) + self.wfile.flush() ssl = dict( cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), @@ -300,7 +273,9 @@ class TestClientCipherListError(test.ServerTestBase): class TestSSLDisconnect(test.ServerTestBase): - handler = DisconnectHandler + class handler(tcp.BaseHandler): + def handle(self): + self.close() ssl = dict( cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), @@ -329,7 +304,15 @@ class TestDisconnect(test.ServerTestBase): class TestServerTimeOut(test.ServerTestBase): - handler = TimeoutHandler + class handler(tcp.BaseHandler): + def handle(self): + self.timeout = False + self.settimeout(0.01) + try: + self.rfile.read(10) + except tcp.NetLibTimeout: + self.timeout = True + def test_timeout(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() @@ -383,6 +366,40 @@ class TestDHParams(test.ServerTestBase): assert ret[0] == "DHE-RSA-AES256-SHA" + +class TestPrivkeyGen(test.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + with tutils.tmpdir() as d: + ca1 = certutils.CertStore.from_store(d, "test2") + ca2 = certutils.CertStore.from_store(d, "test3") + cert, _ = ca1.get_cert("foo.com", []) + key = ca2.gen_pkey(cert) + self.convert_to_ssl(cert, key) + + def test_privkey(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + tutils.raises("bad record mac", c.convert_to_ssl) + + +class TestPrivkeyGenNoFlags(test.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + with tutils.tmpdir() as d: + ca1 = certutils.CertStore.from_store(d, "test2") + ca2 = certutils.CertStore.from_store(d, "test3") + cert, _ = ca1.get_cert("foo.com", []) + certffi.set_flags(ca2.privkey, 0) + self.convert_to_ssl(cert, ca2.privkey) + + def test_privkey(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + tutils.raises("unexpected eof", c.convert_to_ssl) + + + class TestTCPClient: def test_conerr(self): c = tcp.TCPClient(("127.0.0.1", 0)) -- cgit v1.2.3 From 4bd15a28b73f521fc08ea77512198faffeaaa247 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 10 Mar 2014 17:43:39 +0100 Subject: fix #28 --- netlib/tcp.py | 4 +++- requirements.txt | 2 +- test/test_tcp.py | 23 +++++++++++++++++++++++ 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 078ac497..c5f97f94 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -143,7 +143,9 @@ class Reader(_FileLike): raise NetLibTimeout except socket.error: raise NetLibDisconnect - except SSL.SysCallError: + except SSL.SysCallError as e: + if e.args == (-1, 'Unexpected EOF'): + break raise NetLibDisconnect except SSL.Error, v: raise NetLibSSLError(v.message) diff --git a/requirements.txt b/requirements.txt index 460a60e4..7b45f7c3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ pyasn1>=0.1.7 -pyOpenSSL>=0.13 \ No newline at end of file +pyOpenSSL>=0.14 \ No newline at end of file diff --git a/test/test_tcp.py b/test/test_tcp.py index ec995702..77146829 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -106,6 +106,11 @@ class TestDisconnect(test.ServerTestBase): assert c.rfile.readline() == testval +class HardDisconnectHandler(tcp.BaseHandler): + def handle(self): + self.connection.close() + + class TestServerSSL(test.ServerTestBase): handler = EchoHandler ssl = dict( @@ -293,6 +298,24 @@ class TestSSLDisconnect(test.ServerTestBase): tutils.raises(Queue.Empty, self.q.get_nowait) +class TestSSLHardDisconnect(test.ServerTestBase): + handler = HardDisconnectHandler + ssl = dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + request_client_cert = False, + v3_only = False + ) + def test_echo(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + # Exercise SSL.SysCallError + c.rfile.read(10) + c.close() + tutils.raises(tcp.NetLibDisconnect, c.wfile.write, "foo") + + class TestDisconnect(test.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) -- cgit v1.2.3 From 34e469eb558cae999b13510b029714a31d9dd1f3 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 11 Mar 2014 20:23:27 +0100 Subject: create dhparam file if it doesn't exist, fix mitmproxy/mitmproxy#235 --- netlib/certutils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/netlib/certutils.py b/netlib/certutils.py index 92b219ee..ebe643e4 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -123,6 +123,13 @@ class CertStore: @classmethod def load_dhparam(klass, path): + + # netlib<=0.10 doesn't generate a dhparam file. + # Create it now if neccessary. + if not os.path.exists(path): + with open(path, "wb") as f: + f.write(DEFAULT_DHPARAM) + bio = OpenSSL.SSL._lib.BIO_new_file(path, b"r") if bio != OpenSSL.SSL._ffi.NULL: bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free) -- cgit v1.2.3 From d8f54c7c038872fb6f05952214654843c9103da1 Mon Sep 17 00:00:00 2001 From: Bradley Baetz Date: Thu, 20 Mar 2014 11:12:11 +1100 Subject: Change the criticality of a number of X509 extentions, to match the RFCs and real-world CAs/certs. This improve compatability with older browsers/clients. --- netlib/certutils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index ebe643e4..4c50b984 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -29,12 +29,12 @@ def create_ca(o, cn, exp): cert.add_extensions([ OpenSSL.crypto.X509Extension("basicConstraints", True, "CA:TRUE"), - OpenSSL.crypto.X509Extension("nsCertType", True, + OpenSSL.crypto.X509Extension("nsCertType", False, "sslCA"), - OpenSSL.crypto.X509Extension("extendedKeyUsage", True, + OpenSSL.crypto.X509Extension("extendedKeyUsage", False, "serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" ), - OpenSSL.crypto.X509Extension("keyUsage", False, + OpenSSL.crypto.X509Extension("keyUsage", True, "keyCertSign, cRLSign"), OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash", subject=cert), @@ -67,7 +67,7 @@ def dummy_cert(privkey, cacert, commonname, sans): cert.set_serial_number(int(time.time()*10000)) if ss: cert.set_version(2) - cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) + cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", False, ss)]) cert.set_pubkey(cacert.get_pubkey()) cert.sign(privkey, "sha1") return SSLCert(cert) -- cgit v1.2.3 From e7c3e4c5acdf9a229e13502e14a39caac332fe6c Mon Sep 17 00:00:00 2001 From: Pedro Worcel Date: Sun, 30 Mar 2014 20:58:47 +1300 Subject: Change error into awesome user-friendlyness Hi there, I was getting a very weird error "ODict valuelist should be lists", when attempting to add a header. My code was as followed: ``` msg.headers["API-Key"] = new_headers["API-Key"] 42 msg.headers["API-Sign"] = new_headers["API-Sign"] ``` In the end, that was because there could be multiple equal headers. In order to cater to that, it you guys might enjoy the patch I attach, for it converts strings automatically into lists of multiple headers. I think it should work, but I haven't tested it :$ It'd allow me to have the above code, instead of this one below: ``` msg.headers["API-Key"] = [new_headers["API-Key"]] 42 msg.headers["API-Sign"] = [new_headers["API-Sign"]] ``` --- netlib/odict.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/netlib/odict.py b/netlib/odict.py index 46b74e8e..d0ff5cf6 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -60,7 +60,9 @@ class ODict: key, they are cleared. """ if isinstance(valuelist, basestring): - raise ValueError("ODict valuelist should be lists.") + # convert the string into a single element list. + valuelist = [valuelist] + new = self._filter_lst(k, self.lst) for i in valuelist: new.append([k, i]) -- cgit v1.2.3 From bb10dfc5055b6877f35a362ee7705c612aece418 Mon Sep 17 00:00:00 2001 From: Pedro Worcel Date: Mon, 31 Mar 2014 20:19:23 +1300 Subject: Instead of removing the error, for consistency, leaving the error as-was and replaced the message with something that may or may not be more understandable :P --- netlib/odict.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/netlib/odict.py b/netlib/odict.py index d0ff5cf6..0640c25d 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -60,9 +60,8 @@ class ODict: key, they are cleared. """ if isinstance(valuelist, basestring): - # convert the string into a single element list. - valuelist = [valuelist] - + raise ValueError("Expected list instead of string. E.g. odict['elem'] = ['string1', 'string2']") + new = self._filter_lst(k, self.lst) for i in valuelist: new.append([k, i]) -- cgit v1.2.3 From c2c952b3ccf0a1803bd64d4a77998c754298e31a Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 31 Mar 2014 12:44:20 +0200 Subject: make error message example less abstract. --- netlib/odict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/odict.py b/netlib/odict.py index 0640c25d..ea95a586 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -60,7 +60,7 @@ class ODict: key, they are cleared. """ if isinstance(valuelist, basestring): - raise ValueError("Expected list instead of string. E.g. odict['elem'] = ['string1', 'string2']") + raise ValueError("Expected list of values instead of string. Example: odict['Host'] = ['www.example.com']") new = self._filter_lst(k, self.lst) for i in valuelist: -- cgit v1.2.3 From 92081eee04ebbdae6443d24b74404c76fd4f17d4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 25 Apr 2014 19:40:37 +0200 Subject: Update certutils.py refs mitmproxy/mitmproxy#200 --- netlib/certutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index ebe643e4..187abfae 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -22,7 +22,7 @@ def create_ca(o, cn, exp): cert.set_version(2) cert.get_subject().CN = cn cert.get_subject().O = o - cert.gmtime_adj_notBefore(0) + cert.gmtime_adj_notBefore(-3600*48) cert.gmtime_adj_notAfter(exp) cert.set_issuer(cert.get_subject()) cert.set_pubkey(key) -- cgit v1.2.3 From a8345af282692a7faf859b37f2748705091004fe Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 15 May 2014 13:51:59 +0200 Subject: extract cert creation to be accessible in handle_sni callbacks --- netlib/tcp.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index c5f97f94..7b05222f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -341,10 +341,9 @@ class BaseHandler(_Connection): self.ssl_established = False self.clientcert = None - def convert_to_ssl(self, cert, key, - method=SSLv23_METHOD, options=None, handle_sni=None, - request_client_cert=False, cipher_list=None, dhparams=None - ): + def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=None, + handle_sni=None, request_client_cert=None, cipher_list=None, + dhparams=None ): """ cert: A certutils.SSLCert object. method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD @@ -390,6 +389,14 @@ class BaseHandler(_Connection): # Return true to prevent cert verification error return True ctx.set_verify(SSL.VERIFY_PEER, ver) + return ctx + + def convert_to_ssl(self, **kwargs): + """ + Convert connection to SSL. + For a list of parameters, see BaseHandler._create_ssl_context(...) + """ + ctx = self._create_ssl_context(**kwargs) self.connection = SSL.Connection(ctx, self.connection) self.ssl_established = True self.connection.set_accept_state() -- cgit v1.2.3 From 71834aeab144d8bf083785f668989ad3fb21554e Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 15 May 2014 14:15:33 +0200 Subject: make cert and key mandatory --- netlib/tcp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 7b05222f..e72d5e48 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -391,12 +391,12 @@ class BaseHandler(_Connection): ctx.set_verify(SSL.VERIFY_PEER, ver) return ctx - def convert_to_ssl(self, **kwargs): + def convert_to_ssl(self, cert, key, **sslctx_kwargs): """ Convert connection to SSL. For a list of parameters, see BaseHandler._create_ssl_context(...) """ - ctx = self._create_ssl_context(**kwargs) + ctx = self._create_ssl_context(cert, key, **sslctx_kwargs) self.connection = SSL.Connection(ctx, self.connection) self.ssl_established = True self.connection.set_accept_state() -- cgit v1.2.3 From 52c6ba8880363ba5d82b5e767559afbc72371272 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 15 May 2014 18:15:29 +0200 Subject: properly subclass Exception in HTTPError --- netlib/http.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index 51f85627..f5b8118a 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -1,15 +1,15 @@ import string, urlparse, binascii import odict, utils -class HttpError(Exception): - def __init__(self, code, msg): - self.code, self.msg = code, msg - def __str__(self): - return "HttpError(%s, %s)"%(self.code, self.msg) +class HttpError(Exception): + def __init__(self, code, message): + super(HttpError, self).__init__(message) + self.code = code -class HttpErrorConnClosed(HttpError): pass +class HttpErrorConnClosed(HttpError): + pass def _is_valid_port(port): -- cgit v1.2.3 From 66ac56509f754d1239f81c92b6f7cfb65509dc47 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 21 May 2014 01:14:55 +0200 Subject: add support for ctx.load_verify_locations, refs mitmproxy/mitmproxy#174 --- netlib/tcp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index e72d5e48..c5bb7c4b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -343,7 +343,7 @@ class BaseHandler(_Connection): def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=None, cipher_list=None, - dhparams=None ): + dhparams=None, ca_file=None): """ cert: A certutils.SSLCert object. method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD @@ -371,6 +371,8 @@ class BaseHandler(_Connection): ctx = SSL.Context(method) if not options is None: ctx.set_options(options) + if ca_file: + ctx.load_verify_locations(ca_file) if cipher_list: try: ctx.set_cipher_list(cipher_list) @@ -450,7 +452,7 @@ class TCPServer(object): if ex[0] == EINTR: continue else: - raise + raise if self.socket in r: connection, client_address = self.socket.accept() t = threading.Thread( -- cgit v1.2.3 From dc071c4ea7c77b640cb733d769f06631dceb8477 Mon Sep 17 00:00:00 2001 From: Pritam Baral Date: Wed, 28 May 2014 07:10:10 +0530 Subject: Ignore username:password part in url --- netlib/http.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/netlib/http.py b/netlib/http.py index f5b8118a..d000b802 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -43,6 +43,8 @@ def parse_url(url): return None if not scheme: return None + if '@' in netloc: + _, netloc = string.rsplit(netloc, '@', maxsplit=1) if ':' in netloc: host, port = string.rsplit(netloc, ':', maxsplit=1) try: -- cgit v1.2.3 From 217660f5db8f91fa351c188e1e61903e9f54e94d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 25 Jun 2014 14:30:42 +0200 Subject: add socks module --- netlib/socks.py | 142 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 netlib/socks.py diff --git a/netlib/socks.py b/netlib/socks.py new file mode 100644 index 00000000..daebe577 --- /dev/null +++ b/netlib/socks.py @@ -0,0 +1,142 @@ +import socket +import struct +from array import array +from .tcp import Address + + +class SocksError(Exception): + def __init__(self, code, message): + super(SocksError, self).__init__(message) + self.code = code + +class VERSION: + SOCKS4 = 0x04 + SOCKS5 = 0x05 + + +class CMD: + CONNECT = 0x01 + BIND = 0x02 + UDP_ASSOCIATE = 0x03 + + +class ATYP: + IPV4_ADDRESS = 0x01 + DOMAINNAME = 0x03 + IPV6_ADDRESS = 0x04 + +class REP: + SUCCEEDED = 0x00 + GENERAL_SOCKS_SERVER_FAILURE = 0x01 + CONNECTION_NOT_ALLOWED_BY_RULESET = 0x02 + NETWORK_UNREACHABLE = 0x03 + HOST_UNREACHABLE = 0x04 + CONNECTION_REFUSED = 0x05 + TTL_EXPIRED = 0x06 + COMMAND_NOT_SUPPORTED = 0x07 + ADDRESS_TYPE_NOT_SUPPORTED = 0x08 + +class METHOD: + NO_AUTHENTICATION_REQUIRED = 0x00 + GSSAPI = 0x01 + USERNAME_PASSWORD = 0x02 + NO_ACCEPTABLE_METHODS = 0xFF + + +class ClientGreeting(object): + __slots__ = ("ver", "methods") + + def __init__(self, ver, methods): + self.ver = ver + self.methods = methods + + @classmethod + def from_file(cls, f): + ver, nmethods = struct.unpack_from("!BB", f) + methods = array("B") + methods.fromfile(f, nmethods) + return cls(ver, methods) + + def to_file(self, f): + struct.pack_into("!BB", f, 0, self.ver, len(self.methods)) + self.methods.tofile(f) + + +class ServerGreeting(object): + __slots__ = ("ver", "method") + + def __init__(self, ver, method): + self.ver = ver + self.method = method + + @classmethod + def from_file(cls, f): + ver, method = struct.unpack_from("!BB", f) + return cls(ver, method) + + def to_file(self, f): + struct.pack_into("!BB", f, 0, self.ver, self.method) + + +class Request(object): + __slots__ = ("ver", "cmd", "atyp", "dst") + + def __init__(self, ver, cmd, atyp, dst): + self.ver = ver + self.cmd = cmd + self.atyp = atyp + self.dst = dst + + @classmethod + def from_file(cls, f): + ver, cmd, rsv, atyp = struct.unpack_from("!BBBB", f) + if rsv != 0x00: + raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, + "Socks Request: Invalid reserved byte: %s" % rsv) + + if atyp == ATYP.IPV4_ADDRESS: + host = socket.inet_ntoa(f.read(4)) # We use tnoa here as ntop is not commonly available on Windows. + use_ipv6 = False + elif atyp == ATYP.IPV6_ADDRESS: + host = socket.inet_ntop(socket.AF_INET6, f.read(16)) + use_ipv6 = True + elif atyp == ATYP.DOMAINNAME: + length = struct.unpack_from("!B", f) + host = f.read(length) + use_ipv6 = False + else: + raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, + "Socks Request: Unknown ATYP: %s" % atyp) + + port = struct.unpack_from("!H", f) + dst = Address(host, port, use_ipv6=use_ipv6) + return Request(ver, cmd, atyp, dst) + + def to_file(self, f): + raise NotImplementedError() + +class Reply(object): + __slots__ = ("ver", "rep", "atyp", "bnd") + + def __init__(self, ver, rep, atyp, bnd): + self.ver = ver + self.rep = rep + self.atyp = atyp + self.bnd = bnd + + @classmethod + def from_file(cls, f): + raise NotImplementedError() + + def to_file(self, f): + struct.pack_into("!BBBB", f, 0, self.ver, self.rep, 0x00, self.atyp) + if self.atyp == ATYP.IPV4_ADDRESS: + f.write(socket.inet_aton(self.bnd.host)) + elif self.atyp == ATYP.IPV6_ADDRESS: + f.write(socket.inet_pton(socket.AF_INET6, self.bnd.host)) + elif self.atyp == ATYP.DOMAINNAME: + struct.pack_into("!B", f, 0, len(self.bnd.host)) + f.write(self.bnd.host) + else: + raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, "Unknown ATYP: %s" % self.atyp) + struct.pack_into("!H", f, 0, self.bnd.port) \ No newline at end of file -- cgit v1.2.3 From dc3d3e5f0a8c4de734187c39888af5fbdb63d8a0 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 25 Jun 2014 20:31:10 +0200 Subject: add inet_ntop/inet_pton functions --- netlib/utils.py | 29 ++++++++++++++++++++++++++--- netlib/version.py | 1 + test/test_utils.py | 19 ++++++++++++++++++- 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/netlib/utils.py b/netlib/utils.py index 61fd54ae..00e1cd12 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,3 +1,5 @@ +import socket + def isascii(s): try: @@ -32,9 +34,9 @@ def hexdump(s): """ parts = [] for i in range(0, len(s), 16): - o = "%.10x"%i - part = s[i:i+16] - x = " ".join("%.2x"%ord(i) for i in part) + o = "%.10x" % i + part = s[i:i + 16] + x = " ".join("%.2x" % ord(i) for i in part) if len(part) < 16: x += " " x += " ".join(" " for i in range(16 - len(part))) @@ -42,3 +44,24 @@ def hexdump(s): (o, x, cleanBin(part, True)) ) return parts + + +def inet_ntop(address_family, packed_ip): + if hasattr(socket, "inet_ntop"): + return socket.inet_ntop(address_family, packed_ip) + # Windows Fallbacks + if address_family == socket.AF_INET: + return socket.inet_ntoa(packed_ip) + if address_family == socket.AF_INET6: + ip = packed_ip.encode("hex") + return ":".join([ip[i:i + 4] for i in range(0, len(ip), 4)]) + + +def inet_pton(address_family, ip_string): + if hasattr(socket, "inet_pton"): + return socket.inet_pton(address_family, ip_string) + # Windows Fallbacks + if address_family == socket.AF_INET: + return socket.inet_aton(ip_string) + if address_family == socket.AF_INET6: + return ip_string.replace(":", "").decode("hex") \ No newline at end of file diff --git a/netlib/version.py b/netlib/version.py index 1d3250e1..25565d40 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,4 +1,5 @@ IVERSION = (0, 11) VERSION = ".".join(str(i) for i in IVERSION) +MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION diff --git a/test/test_utils.py b/test/test_utils.py index 61820a81..a9a48cd0 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,5 +1,5 @@ from netlib import utils - +import socket def test_hexdump(): assert utils.hexdump("one\0"*10) @@ -11,3 +11,20 @@ def test_cleanBin(): assert utils.cleanBin("\nne") == "\nne" assert utils.cleanBin("\nne", True) == ".ne" +def test_ntop_pton(): + for family, ip_string, packed_ip in ( + (socket.AF_INET, + "127.0.0.1", + "\x7f\x00\x00\x01"), + (socket.AF_INET6, + "2001:0db8:85a3:08d3:1319:8a2e:0370:7344", + " \x01\r\xb8\x85\xa3\x08\xd3\x13\x19\x8a.\x03psD")): + assert ip_string == utils.inet_ntop(family, packed_ip) + assert packed_ip == utils.inet_pton(family, ip_string) + if hasattr(socket, "inet_ntop"): + ntop, pton = socket.inet_ntop, socket.inet_pton + delattr(socket,"inet_ntop") + delattr(socket,"inet_pton") + assert ip_string == utils.inet_ntop(family, packed_ip) + assert packed_ip == utils.inet_pton(family, ip_string) + socket.inet_ntop, socket.inet_pton = ntop, pton \ No newline at end of file -- cgit v1.2.3 From 6405595ae8593a52f6b81d7f311044f113476d82 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 25 Jun 2014 20:31:28 +0200 Subject: socks module: polish, add tests --- netlib/socks.py | 71 +++++++++++++++++++++--------------------------------- test/test_socks.py | 55 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 43 deletions(-) create mode 100644 test/test_socks.py diff --git a/netlib/socks.py b/netlib/socks.py index daebe577..01f54859 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -1,7 +1,7 @@ import socket import struct from array import array -from .tcp import Address +from . import tcp, utils class SocksError(Exception): @@ -9,6 +9,7 @@ class SocksError(Exception): super(SocksError, self).__init__(message) self.code = code + class VERSION: SOCKS4 = 0x04 SOCKS5 = 0x05 @@ -25,6 +26,7 @@ class ATYP: DOMAINNAME = 0x03 IPV6_ADDRESS = 0x04 + class REP: SUCCEEDED = 0x00 GENERAL_SOCKS_SERVER_FAILURE = 0x01 @@ -36,6 +38,7 @@ class REP: COMMAND_NOT_SUPPORTED = 0x07 ADDRESS_TYPE_NOT_SUPPORTED = 0x08 + class METHOD: NO_AUTHENTICATION_REQUIRED = 0x00 GSSAPI = 0x01 @@ -52,15 +55,14 @@ class ClientGreeting(object): @classmethod def from_file(cls, f): - ver, nmethods = struct.unpack_from("!BB", f) + ver, nmethods = struct.unpack("!BB", f.read(2)) methods = array("B") - methods.fromfile(f, nmethods) + methods.fromstring(f.read(nmethods)) return cls(ver, methods) def to_file(self, f): - struct.pack_into("!BB", f, 0, self.ver, len(self.methods)) - self.methods.tofile(f) - + f.write(struct.pack("!BB", self.ver, len(self.methods))) + f.write(self.methods.tostring()) class ServerGreeting(object): __slots__ = ("ver", "method") @@ -71,72 +73,55 @@ class ServerGreeting(object): @classmethod def from_file(cls, f): - ver, method = struct.unpack_from("!BB", f) + ver, method = struct.unpack("!BB", f.read(2)) return cls(ver, method) def to_file(self, f): - struct.pack_into("!BB", f, 0, self.ver, self.method) + f.write(struct.pack("!BB", self.ver, self.method)) +class Message(object): + __slots__ = ("ver", "msg", "atyp", "addr") -class Request(object): - __slots__ = ("ver", "cmd", "atyp", "dst") - - def __init__(self, ver, cmd, atyp, dst): + def __init__(self, ver, msg, atyp, addr): self.ver = ver - self.cmd = cmd + self.msg = msg self.atyp = atyp - self.dst = dst + self.addr = addr @classmethod def from_file(cls, f): - ver, cmd, rsv, atyp = struct.unpack_from("!BBBB", f) + ver, msg, rsv, atyp = struct.unpack("!BBBB", f.read(4)) if rsv != 0x00: raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Socks Request: Invalid reserved byte: %s" % rsv) if atyp == ATYP.IPV4_ADDRESS: - host = socket.inet_ntoa(f.read(4)) # We use tnoa here as ntop is not commonly available on Windows. + host = utils.inet_ntop(socket.AF_INET, f.read(4)) # We use tnoa here as ntop is not commonly available on Windows. use_ipv6 = False elif atyp == ATYP.IPV6_ADDRESS: - host = socket.inet_ntop(socket.AF_INET6, f.read(16)) + host = utils.inet_ntop(socket.AF_INET6, f.read(16)) use_ipv6 = True elif atyp == ATYP.DOMAINNAME: - length = struct.unpack_from("!B", f) + length, = struct.unpack("!B", f.read(1)) host = f.read(length) use_ipv6 = False else: raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, "Socks Request: Unknown ATYP: %s" % atyp) - port = struct.unpack_from("!H", f) - dst = Address(host, port, use_ipv6=use_ipv6) - return Request(ver, cmd, atyp, dst) - - def to_file(self, f): - raise NotImplementedError() - -class Reply(object): - __slots__ = ("ver", "rep", "atyp", "bnd") - - def __init__(self, ver, rep, atyp, bnd): - self.ver = ver - self.rep = rep - self.atyp = atyp - self.bnd = bnd - - @classmethod - def from_file(cls, f): - raise NotImplementedError() + port, = struct.unpack("!H", f.read(2)) + addr = tcp.Address((host, port), use_ipv6=use_ipv6) + return cls(ver, msg, atyp, addr) def to_file(self, f): - struct.pack_into("!BBBB", f, 0, self.ver, self.rep, 0x00, self.atyp) + f.write(struct.pack("!BBBB", self.ver, self.msg, 0x00, self.atyp)) if self.atyp == ATYP.IPV4_ADDRESS: - f.write(socket.inet_aton(self.bnd.host)) + f.write(utils.inet_pton(socket.AF_INET, self.addr.host)) elif self.atyp == ATYP.IPV6_ADDRESS: - f.write(socket.inet_pton(socket.AF_INET6, self.bnd.host)) + f.write(utils.inet_pton(socket.AF_INET6, self.addr.host)) elif self.atyp == ATYP.DOMAINNAME: - struct.pack_into("!B", f, 0, len(self.bnd.host)) - f.write(self.bnd.host) + f.write(struct.pack("!B", len(self.addr.host))) + f.write(self.addr.host) else: raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, "Unknown ATYP: %s" % self.atyp) - struct.pack_into("!H", f, 0, self.bnd.port) \ No newline at end of file + f.write(struct.pack("!H", self.addr.port)) \ No newline at end of file diff --git a/test/test_socks.py b/test/test_socks.py new file mode 100644 index 00000000..3771df62 --- /dev/null +++ b/test/test_socks.py @@ -0,0 +1,55 @@ +from cStringIO import StringIO +import socket +from netlib import socks, utils +import tutils + + +def test_client_greeting(): + raw = StringIO("\x05\x02\x00\xBE\xEF") + out = StringIO() + msg = socks.ClientGreeting.from_file(raw) + msg.to_file(out) + + assert out.getvalue() == raw.getvalue()[:-1] + assert msg.ver == 5 + assert len(msg.methods) == 2 + assert 0xBE in msg.methods + assert 0xEF not in msg.methods + + +def test_server_greeting(): + raw = StringIO("\x05\x02") + out = StringIO() + msg = socks.ServerGreeting.from_file(raw) + msg.to_file(out) + + assert out.getvalue() == raw.getvalue() + assert msg.ver == 5 + assert msg.method == 0x02 + + +def test_message(): + raw = StringIO("\x05\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") + out = StringIO() + msg = socks.Message.from_file(raw) + assert raw.read(2) == "\xBE\xEF" + msg.to_file(out) + + assert out.getvalue() == raw.getvalue()[:-2] + assert msg.ver == 5 + assert msg.msg == 0x01 + assert msg.atyp == 0x03 + assert msg.addr == ("example.com", 0xDEAD) + + # Test ATYP=0x01 (IPV4) + raw = StringIO("\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") + msg = socks.Message.from_file(raw) + assert raw.read(2) == "\xBE\xEF" + assert msg.addr == ("127.0.0.1", 0xDEAD) + + # Test ATYP=0x04 (IPV6) + ipv6_addr = "2001:0db8:85a3:08d3:1319:8a2e:0370:7344" + raw = StringIO("\x05\x01\x00\x04" + utils.inet_pton(socket.AF_INET6, ipv6_addr) + "\xDE\xAD\xBE\xEF") + msg = socks.Message.from_file(raw) + assert raw.read(2) == "\xBE\xEF" + assert msg.addr.host == ipv6_addr \ No newline at end of file -- cgit v1.2.3 From e69133f98c513a99c017ad561ea9195280e3f7c5 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 25 Jun 2014 21:16:47 +0200 Subject: remove ntop windows workaround --- netlib/socks.py | 8 ++++---- netlib/utils.py | 23 +---------------------- test/test_socks.py | 7 ++++++- test/test_utils.py | 18 ------------------ 4 files changed, 11 insertions(+), 45 deletions(-) diff --git a/netlib/socks.py b/netlib/socks.py index 01f54859..97df3478 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -96,10 +96,10 @@ class Message(object): "Socks Request: Invalid reserved byte: %s" % rsv) if atyp == ATYP.IPV4_ADDRESS: - host = utils.inet_ntop(socket.AF_INET, f.read(4)) # We use tnoa here as ntop is not commonly available on Windows. + host = socket.inet_ntoa(f.read(4)) # We use tnoa here as ntop is not commonly available on Windows. use_ipv6 = False elif atyp == ATYP.IPV6_ADDRESS: - host = utils.inet_ntop(socket.AF_INET6, f.read(16)) + host = socket.inet_ntop(socket.AF_INET6, f.read(16)) use_ipv6 = True elif atyp == ATYP.DOMAINNAME: length, = struct.unpack("!B", f.read(1)) @@ -116,9 +116,9 @@ class Message(object): def to_file(self, f): f.write(struct.pack("!BBBB", self.ver, self.msg, 0x00, self.atyp)) if self.atyp == ATYP.IPV4_ADDRESS: - f.write(utils.inet_pton(socket.AF_INET, self.addr.host)) + f.write(socket.inet_aton(self.addr.host)) elif self.atyp == ATYP.IPV6_ADDRESS: - f.write(utils.inet_pton(socket.AF_INET6, self.addr.host)) + f.write(socket.inet_pton(socket.AF_INET6, self.addr.host)) elif self.atyp == ATYP.DOMAINNAME: f.write(struct.pack("!B", len(self.addr.host))) f.write(self.addr.host) diff --git a/netlib/utils.py b/netlib/utils.py index 00e1cd12..69ba456a 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -43,25 +43,4 @@ def hexdump(s): parts.append( (o, x, cleanBin(part, True)) ) - return parts - - -def inet_ntop(address_family, packed_ip): - if hasattr(socket, "inet_ntop"): - return socket.inet_ntop(address_family, packed_ip) - # Windows Fallbacks - if address_family == socket.AF_INET: - return socket.inet_ntoa(packed_ip) - if address_family == socket.AF_INET6: - ip = packed_ip.encode("hex") - return ":".join([ip[i:i + 4] for i in range(0, len(ip), 4)]) - - -def inet_pton(address_family, ip_string): - if hasattr(socket, "inet_pton"): - return socket.inet_pton(address_family, ip_string) - # Windows Fallbacks - if address_family == socket.AF_INET: - return socket.inet_aton(ip_string) - if address_family == socket.AF_INET6: - return ip_string.replace(":", "").decode("hex") \ No newline at end of file + return parts \ No newline at end of file diff --git a/test/test_socks.py b/test/test_socks.py index 3771df62..4787e309 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -1,5 +1,6 @@ from cStringIO import StringIO import socket +from nose.plugins.skip import SkipTest from netlib import socks, utils import tutils @@ -47,9 +48,13 @@ def test_message(): assert raw.read(2) == "\xBE\xEF" assert msg.addr == ("127.0.0.1", 0xDEAD) + +def test_message_ipv6(): + if not hasattr(socket, "inet_ntop"): + raise SkipTest("Skipped because inet_ntop is not available") # Test ATYP=0x04 (IPV6) ipv6_addr = "2001:0db8:85a3:08d3:1319:8a2e:0370:7344" - raw = StringIO("\x05\x01\x00\x04" + utils.inet_pton(socket.AF_INET6, ipv6_addr) + "\xDE\xAD\xBE\xEF") + raw = StringIO("\x05\x01\x00\x04" + socket.inet_pton(socket.AF_INET6, ipv6_addr) + "\xDE\xAD\xBE\xEF") msg = socks.Message.from_file(raw) assert raw.read(2) == "\xBE\xEF" assert msg.addr.host == ipv6_addr \ No newline at end of file diff --git a/test/test_utils.py b/test/test_utils.py index a9a48cd0..971e5076 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -10,21 +10,3 @@ def test_cleanBin(): assert utils.cleanBin("\00ne") == ".ne" assert utils.cleanBin("\nne") == "\nne" assert utils.cleanBin("\nne", True) == ".ne" - -def test_ntop_pton(): - for family, ip_string, packed_ip in ( - (socket.AF_INET, - "127.0.0.1", - "\x7f\x00\x00\x01"), - (socket.AF_INET6, - "2001:0db8:85a3:08d3:1319:8a2e:0370:7344", - " \x01\r\xb8\x85\xa3\x08\xd3\x13\x19\x8a.\x03psD")): - assert ip_string == utils.inet_ntop(family, packed_ip) - assert packed_ip == utils.inet_pton(family, ip_string) - if hasattr(socket, "inet_ntop"): - ntop, pton = socket.inet_ntop, socket.inet_pton - delattr(socket,"inet_ntop") - delattr(socket,"inet_pton") - assert ip_string == utils.inet_ntop(family, packed_ip) - assert packed_ip == utils.inet_pton(family, ip_string) - socket.inet_ntop, socket.inet_pton = ntop, pton \ No newline at end of file -- cgit v1.2.3 From 896e1a5524863b657292807037952ff2d574901b Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 25 Jun 2014 21:31:10 +0200 Subject: fix overly restrictive tests --- test/test_socks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_socks.py b/test/test_socks.py index 4787e309..964678de 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -53,7 +53,7 @@ def test_message_ipv6(): if not hasattr(socket, "inet_ntop"): raise SkipTest("Skipped because inet_ntop is not available") # Test ATYP=0x04 (IPV6) - ipv6_addr = "2001:0db8:85a3:08d3:1319:8a2e:0370:7344" + ipv6_addr = "2001:db8:85a3:8d3:1319:8a2e:370:7344" raw = StringIO("\x05\x01\x00\x04" + socket.inet_pton(socket.AF_INET6, ipv6_addr) + "\xDE\xAD\xBE\xEF") msg = socks.Message.from_file(raw) assert raw.read(2) == "\xBE\xEF" -- cgit v1.2.3 From dfabe165d46b726fc38e73b37b42c09bdc709795 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 25 Jun 2014 21:45:45 +0200 Subject: socks: 100% test coverage --- test/test_socks.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/test/test_socks.py b/test/test_socks.py index 964678de..740fdb9c 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -1,7 +1,7 @@ from cStringIO import StringIO import socket from nose.plugins.skip import SkipTest -from netlib import socks, utils +from netlib import socks, tcp import tutils @@ -42,10 +42,16 @@ def test_message(): assert msg.atyp == 0x03 assert msg.addr == ("example.com", 0xDEAD) + +def test_message_ipv4(): # Test ATYP=0x01 (IPV4) raw = StringIO("\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") + out = StringIO() msg = socks.Message.from_file(raw) assert raw.read(2) == "\xBE\xEF" + msg.to_file(out) + + assert out.getvalue() == raw.getvalue()[:-2] assert msg.addr == ("127.0.0.1", 0xDEAD) @@ -54,7 +60,25 @@ def test_message_ipv6(): raise SkipTest("Skipped because inet_ntop is not available") # Test ATYP=0x04 (IPV6) ipv6_addr = "2001:db8:85a3:8d3:1319:8a2e:370:7344" + raw = StringIO("\x05\x01\x00\x04" + socket.inet_pton(socket.AF_INET6, ipv6_addr) + "\xDE\xAD\xBE\xEF") + out = StringIO() msg = socks.Message.from_file(raw) assert raw.read(2) == "\xBE\xEF" - assert msg.addr.host == ipv6_addr \ No newline at end of file + msg.to_file(out) + + assert out.getvalue() == raw.getvalue()[:-2] + assert msg.addr.host == ipv6_addr + + +def test_message_invalid_rsv(): + raw = StringIO("\x05\x01\xFF\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") + tutils.raises(socks.SocksError, socks.Message.from_file, raw) + + +def test_message_unknown_atyp(): + raw = StringIO("\x05\x02\x00\x02\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") + tutils.raises(socks.SocksError, socks.Message.from_file, raw) + + m = socks.Message(5, 1, 0x02, tcp.Address(("example.com", 5050))) + tutils.raises(socks.SocksError, m.to_file, StringIO()) \ No newline at end of file -- cgit v1.2.3 From 4d5d8b65114d061da4f6a41673011ce643c29aab Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 29 Jun 2014 13:10:07 +0200 Subject: mark nsCertType non-critical, fix #39 --- netlib/certutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 187abfae..8aec5e82 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -29,7 +29,7 @@ def create_ca(o, cn, exp): cert.add_extensions([ OpenSSL.crypto.X509Extension("basicConstraints", True, "CA:TRUE"), - OpenSSL.crypto.X509Extension("nsCertType", True, + OpenSSL.crypto.X509Extension("nsCertType", False, "sslCA"), OpenSSL.crypto.X509Extension("extendedKeyUsage", True, "serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" -- cgit v1.2.3 From 273c25a705c7784ed3fbe15faa11effe05809519 Mon Sep 17 00:00:00 2001 From: Brad Peabody Date: Sat, 12 Jul 2014 22:42:06 -0700 Subject: added option for read_response to only read the headers, beginnings of implementing streamed result in mitmproxy --- netlib/http.py | 6 ++++-- test/test_http.py | 12 ++++++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index f5b8118a..21cde538 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -292,7 +292,7 @@ def parse_response_line(line): return (proto, code, msg) -def read_response(rfile, method, body_size_limit): +def read_response(rfile, method, body_size_limit, include_body=True): """ Return an (httpversion, code, msg, headers, content) tuple. """ @@ -315,8 +315,10 @@ def read_response(rfile, method, body_size_limit): # Parse response body according to http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 if method in ["HEAD", "CONNECT"] or (code in [204, 304]) or 100 <= code <= 199: content = "" - else: + elif include_body: content = read_http_body(rfile, headers, body_size_limit, False) + else: + content = None # if include_body==False then a None content means the body should be read separately return httpversion, code, msg, headers, content diff --git a/test/test_http.py b/test/test_http.py index e80e4b8f..df351dc7 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -229,10 +229,10 @@ class TestReadResponseNoContentLength(test.ServerTestBase): assert content == "bar\r\n\r\n" def test_read_response(): - def tst(data, method, limit): + def tst(data, method, limit, include_body=True): data = textwrap.dedent(data) r = cStringIO.StringIO(data) - return http.read_response(r, method, limit) + return http.read_response(r, method, limit, include_body=include_body) tutils.raises("server disconnect", tst, "", "GET", None) tutils.raises("invalid server response", tst, "foo", "GET", None) @@ -277,6 +277,14 @@ def test_read_response(): """ tutils.raises("invalid headers", tst, data, "GET", None) + data = """ + HTTP/1.1 200 OK + Content-Length: 3 + + foo + """ + assert tst(data, "GET", None, include_body=False)[4] == None + def test_parse_url(): assert not http.parse_url("") -- cgit v1.2.3 From 24ef9c61a39f24c8f5ec4414a4a9d0b6a2bc4283 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 14 Jul 2014 17:38:49 +0200 Subject: improve docs --- netlib/http.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index 21cde538..413c73a1 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -292,12 +292,17 @@ def parse_response_line(line): return (proto, code, msg) -def read_response(rfile, method, body_size_limit, include_body=True): +def read_response(rfile, request_method, body_size_limit, include_body=True): """ Return an (httpversion, code, msg, headers, content) tuple. + + By default, both response header and body are read. + If include_body=False is specified, content may be one of the following: + - None, if the response is technically allowed to have a response body + - "", if the response must not have a response body (e.g. it's a response to a HEAD request) """ line = rfile.readline() - if line == "\r\n" or line == "\n": # Possible leftover from previous message + if line == "\r\n" or line == "\n": # Possible leftover from previous message line = rfile.readline() if not line: raise HttpErrorConnClosed(502, "Server disconnect.") @@ -312,13 +317,13 @@ def read_response(rfile, method, body_size_limit, include_body=True): if headers is None: raise HttpError(502, "Invalid headers.") - # Parse response body according to http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 - if method in ["HEAD", "CONNECT"] or (code in [204, 304]) or 100 <= code <= 199: + # Parse response body according to http://tools.ietf.org/html/rfc7230#section-3.3 + if request_method in ["HEAD", "CONNECT"] or (code in [204, 304]) or 100 <= code <= 199: content = "" elif include_body: content = read_http_body(rfile, headers, body_size_limit, False) else: - content = None # if include_body==False then a None content means the body should be read separately + content = None # if include_body==False then a None content means the body should be read separately return httpversion, code, msg, headers, content -- cgit v1.2.3 From 55c2133b69bc39ad43c6ce1ab14b32019878e56a Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 17 Jul 2014 01:47:24 +0200 Subject: add test case for mitmproxy/mitmproxy#295 --- test/test_certutils.py | 7 +++++++ test/test_tcp.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/test/test_certutils.py b/test/test_certutils.py index 176575ea..2d8c7841 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -67,6 +67,13 @@ class TestCertStore: c3 = ca.get_cert("bar.com", []) assert not c1 == c3 + def test_sans_change(self): + with tutils.tmpdir() as d: + ca = certutils.CertStore.from_store(d, "test") + _ = ca.get_cert("foo.com", ["*.bar.com"]) + cert, key = ca.get_cert("foo.bar.com", ["*.baz.com"]) + assert "*.baz.com" in cert.altnames + def test_overrides(self): with tutils.tmpdir() as d: ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test") diff --git a/test/test_tcp.py b/test/test_tcp.py index 77146829..b8837655 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -1,5 +1,5 @@ import cStringIO, Queue, time, socket, random -from netlib import tcp, certutils, test +from netlib import tcp, certutils, test, certffi import mock import tutils from OpenSSL import SSL -- cgit v1.2.3 From 280d9b862575d79b391e28c80156697d2d674c48 Mon Sep 17 00:00:00 2001 From: Brad Peabody Date: Thu, 17 Jul 2014 22:34:29 -0700 Subject: added some additional functions for dealing with chunks - needed for mitmproxy streaming capability --- netlib/http.py | 63 ++++++++++++++++++++++++++++++++++++++++++++++++- test/test_http.py | 70 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+), 1 deletion(-) diff --git a/netlib/http.py b/netlib/http.py index 21cde538..736c2c88 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -136,6 +136,49 @@ def read_chunked(fp, headers, limit, is_request): break return content +def read_next_chunk(fp, headers, is_request): + """ + Read next piece of a chunked HTTP body. Returns next piece of + content as a string or None if we hit the end. + """ + # TODO: see and understand the FIXME in read_chunked and + # see if we need to apply here? + content = "" + code = 400 if is_request else 502 + line = fp.readline(128) + if line == "": + raise HttpErrorConnClosed(code, "Connection closed prematurely") + try: + length = int(line, 16) + except ValueError: + # TODO: see note in this part of read_chunked() + raise HttpError(code, "Invalid chunked encoding length: %s"%line) + if length > 0: + content += fp.read(length) + print "read content: '%s'" % content + line = fp.readline(5) + if line == '': + raise HttpErrorConnClosed(code, "Connection closed prematurely") + if line != '\r\n': + raise HttpError(code, "Malformed chunked body: '%s' (len=%d)" % (line, length)) + if content == "": + content = None # normalize zero length to None, meaning end of chunked stream + return content # return this chunk + +def write_chunk(fp, content): + """ + Write a chunk with chunked encoding format, returns True + if there should be more chunks or False if you passed + None, meaning this was the last chunk. + """ + if content == None or content == "": + fp.write("0\r\n\r\n") + return False + fp.write("%x\r\n" % len(content)) + fp.write(content) + fp.write("\r\n") + return True + def get_header_tokens(headers, key): """ @@ -350,4 +393,22 @@ def read_http_body(rfile, headers, limit, is_request): not_done = rfile.read(1) if not_done: raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s," % limit) - return content \ No newline at end of file + return content + +def expected_http_body_size(headers, is_request): + """ + Returns length of body expected or -1 if not + known and we should just read until end of + stream. + """ + if "content-length" in headers: + try: + l = int(headers["content-length"][0]) + if l < 0: + raise ValueError() + return l + except ValueError: + raise HttpError(400 if is_request else 502, "Invalid content-length header: %s"%headers["content-length"]) + elif is_request: + return 0 + return -1 diff --git a/test/test_http.py b/test/test_http.py index df351dc7..e1dffab8 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -38,6 +38,57 @@ def test_read_chunked(): tutils.raises("too large", http.read_chunked, s, None, 2, True) +def test_read_next_chunk(): + s = cStringIO.StringIO( + "4\r\n" + + "mitm\r\n" + + "5\r\n" + + "proxy\r\n" + + "e\r\n" + + " in\r\n\r\nchunks.\r\n" + + "0\r\n" + + "\r\n") + assert http.read_next_chunk(s, None, False) == "mitm" + assert http.read_next_chunk(s, None, False) == "proxy" + assert http.read_next_chunk(s, None, False) == " in\r\n\r\nchunks." + assert http.read_next_chunk(s, None, False) == None + + s = cStringIO.StringIO("") + tutils.raises("closed prematurely", http.read_next_chunk, s, None, False) + + s = cStringIO.StringIO("1\r\na\r\n0\r\n") + http.read_next_chunk(s, None, False) + tutils.raises("closed prematurely", http.read_next_chunk, s, None, False) + + s = cStringIO.StringIO("1\r\nfoo") + tutils.raises("malformed chunked body", http.read_next_chunk, s, None, False) + + s = cStringIO.StringIO("foo\r\nfoo") + tutils.raises(http.HttpError, http.read_next_chunk, s, None, False) + +def test_write_chunk(): + + expected = ("" + + "4\r\n" + + "mitm\r\n" + + "5\r\n" + + "proxy\r\n" + + "e\r\n" + + " in\r\n\r\nchunks.\r\n" + + "0\r\n" + + "\r\n") + + s = cStringIO.StringIO() + http.write_chunk(s, "mitm") + http.write_chunk(s, "proxy") + http.write_chunk(s, " in\r\n\r\nchunks.") + http.write_chunk(s, None) + + print len(s.getvalue()) + print len(expected) + + assert s.getvalue() == expected + def test_connection_close(): h = odict.ODictCaseless() assert http.connection_close((1, 0), h) @@ -111,6 +162,25 @@ def test_read_http_body(): s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") assert http.read_http_body(s, h, 100, False) == "aaaaa" +def test_expected_http_body_size(): + # gibber in the content-length field + h = odict.ODictCaseless() + h["content-length"] = ["foo"] + tutils.raises(http.HttpError, http.expected_http_body_size, h, False) + # negative number in the content-length field + h = odict.ODictCaseless() + h["content-length"] = ["-7"] + tutils.raises(http.HttpError, http.expected_http_body_size, h, False) + # explicit length + h = odict.ODictCaseless() + h["content-length"] = ["5"] + assert http.expected_http_body_size(h, False) == 5 + # no length + h = odict.ODictCaseless() + assert http.expected_http_body_size(h, False) == -1 + # no length request + h = odict.ODictCaseless() + assert http.expected_http_body_size(h, True) == 0 def test_parse_http_protocol(): assert http.parse_http_protocol("HTTP/1.1") == (1, 1) -- cgit v1.2.3 From a7837846a2c20f3fc48406fc63845aec1a7efae0 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 18 Jul 2014 22:55:25 +0200 Subject: temporarily replace DNTree with a simpler cert lookup mechanism, fix mitmproxy/mitmproxy#295 --- netlib/certutils.py | 99 +++++++++++++++++++++++++++----------------------- test/test_certutils.py | 58 ++++++++++++++--------------- 2 files changed, 82 insertions(+), 75 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 8aec5e82..87fb99c3 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -1,4 +1,5 @@ import os, ssl, time, datetime +import itertools from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error @@ -73,42 +74,44 @@ def dummy_cert(privkey, cacert, commonname, sans): return SSLCert(cert) -class _Node(UserDict.UserDict): - def __init__(self): - UserDict.UserDict.__init__(self) - self.value = None - - -class DNTree: - """ - Domain store that knows about wildcards. DNS wildcards are very - restricted - the only valid variety is an asterisk on the left-most - domain component, i.e.: - - *.foo.com - """ - def __init__(self): - self.d = _Node() - - def add(self, dn, cert): - parts = dn.split(".") - parts.reverse() - current = self.d - for i in parts: - current = current.setdefault(i, _Node()) - current.value = cert - - def get(self, dn): - parts = dn.split(".") - current = self.d - for i in reversed(parts): - if i in current: - current = current[i] - elif "*" in current: - return current["*"].value - else: - return None - return current.value +# DNTree did not pass TestCertStore.test_sans_change and is temporarily replaced by a simple dict. +# +# class _Node(UserDict.UserDict): +# def __init__(self): +# UserDict.UserDict.__init__(self) +# self.value = None +# +# +# class DNTree: +# """ +# Domain store that knows about wildcards. DNS wildcards are very +# restricted - the only valid variety is an asterisk on the left-most +# domain component, i.e.: +# +# *.foo.com +# """ +# def __init__(self): +# self.d = _Node() +# +# def add(self, dn, cert): +# parts = dn.split(".") +# parts.reverse() +# current = self.d +# for i in parts: +# current = current.setdefault(i, _Node()) +# current.value = cert +# +# def get(self, dn): +# parts = dn.split(".") +# current = self.d +# for i in reversed(parts): +# if i in current: +# current = current[i] +# elif "*" in current: +# return current["*"].value +# else: +# return None +# return current.value @@ -119,7 +122,7 @@ class CertStore: def __init__(self, privkey, cacert, dhparams=None): self.privkey, self.cacert = privkey, cacert self.dhparams = dhparams - self.certs = DNTree() + self.certs = dict() @classmethod def load_dhparam(klass, path): @@ -206,11 +209,11 @@ class CertStore: any SANs, and also the list of names provided as an argument. """ if cert.cn: - self.certs.add(cert.cn, (cert, privkey)) + self.certs[cert.cn] = (cert, privkey) for i in cert.altnames: - self.certs.add(i, (cert, privkey)) + self.certs[i] = (cert, privkey) for i in names: - self.certs.add(i, (cert, privkey)) + self.certs[i] = (cert, privkey) def get_cert(self, commonname, sans): """ @@ -223,12 +226,16 @@ class CertStore: Return None if the certificate could not be found or generated. """ - c = self.certs.get(commonname) - if not c: - c = dummy_cert(self.privkey, self.cacert, commonname, sans) - self.add_cert(c, None) - c = (c, None) - return (c[0], c[1] or self.privkey) + + potential_keys = [commonname] + sans + [(commonname, tuple(sans))] + name = next(itertools.ifilter(lambda key: key in self.certs, potential_keys), None) + if name: + c = self.certs[name] + else: + c = dummy_cert(self.privkey, self.cacert, commonname, sans), None + self.certs[(commonname, tuple(sans))] = c + + return c[0], (c[1] or self.privkey) def gen_pkey(self, cert): import certffi diff --git a/test/test_certutils.py b/test/test_certutils.py index 2d8c7841..95a7280e 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -3,34 +3,34 @@ from netlib import certutils, certffi import OpenSSL import tutils -class TestDNTree: - def test_simple(self): - d = certutils.DNTree() - d.add("foo.com", "foo") - d.add("bar.com", "bar") - assert d.get("foo.com") == "foo" - assert d.get("bar.com") == "bar" - assert not d.get("oink.com") - assert not d.get("oink") - assert not d.get("") - assert not d.get("oink.oink") - - d.add("*.match.org", "match") - assert not d.get("match.org") - assert d.get("foo.match.org") == "match" - assert d.get("foo.foo.match.org") == "match" - - def test_wildcard(self): - d = certutils.DNTree() - d.add("foo.com", "foo") - assert not d.get("*.foo.com") - d.add("*.foo.com", "wild") - - d = certutils.DNTree() - d.add("*", "foo") - assert d.get("foo.com") == "foo" - assert d.get("*.foo.com") == "foo" - assert d.get("com") == "foo" +# class TestDNTree: +# def test_simple(self): +# d = certutils.DNTree() +# d.add("foo.com", "foo") +# d.add("bar.com", "bar") +# assert d.get("foo.com") == "foo" +# assert d.get("bar.com") == "bar" +# assert not d.get("oink.com") +# assert not d.get("oink") +# assert not d.get("") +# assert not d.get("oink.oink") +# +# d.add("*.match.org", "match") +# assert not d.get("match.org") +# assert d.get("foo.match.org") == "match" +# assert d.get("foo.foo.match.org") == "match" +# +# def test_wildcard(self): +# d = certutils.DNTree() +# d.add("foo.com", "foo") +# assert not d.get("*.foo.com") +# d.add("*.foo.com", "wild") +# +# d = certutils.DNTree() +# d.add("*", "foo") +# assert d.get("foo.com") == "foo" +# assert d.get("*.foo.com") == "foo" +# assert d.get("com") == "foo" class TestCertStore: @@ -63,7 +63,7 @@ class TestCertStore: ca = certutils.CertStore.from_store(d, "test") c1 = ca.get_cert("foo.com", ["*.bar.com"]) c2 = ca.get_cert("foo.bar.com", []) - assert c1 == c2 + # assert c1 == c2 c3 = ca.get_cert("bar.com", []) assert not c1 == c3 -- cgit v1.2.3 From cba927885e8c683752f3042ce9f1746336f90168 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 18 Jul 2014 23:08:29 +0200 Subject: fix tests --- test/test_tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_tcp.py b/test/test_tcp.py index b8837655..911beccc 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -419,7 +419,7 @@ class TestPrivkeyGenNoFlags(test.ServerTestBase): def test_privkey(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - tutils.raises("unexpected eof", c.convert_to_ssl) + tutils.raises("sslv3 alert handshake failure", c.convert_to_ssl) -- cgit v1.2.3 From d382bb27bf4732def621cddb46fc4cc1d2143ab4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 19 Jul 2014 00:02:31 +0200 Subject: certstore: add support for asterisk form to DNTree replacement --- netlib/certutils.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 87fb99c3..308d6cf8 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -215,6 +215,19 @@ class CertStore: for i in names: self.certs[i] = (cert, privkey) + @staticmethod + def asterisk_forms(dn): + parts = dn.split(".") + parts.reverse() + curr_dn = "" + dn_forms = ["*"] + for part in parts[:-1]: + curr_dn = "." + part + curr_dn # .example.com + dn_forms.append("*" + curr_dn) # *.example.com + if parts[-1] != "*": + dn_forms.append(parts[-1] + curr_dn) + return dn_forms + def get_cert(self, commonname, sans): """ Returns an (cert, privkey) tuple. @@ -227,7 +240,11 @@ class CertStore: Return None if the certificate could not be found or generated. """ - potential_keys = [commonname] + sans + [(commonname, tuple(sans))] + potential_keys = self.asterisk_forms(commonname) + for s in sans: + potential_keys.extend(self.asterisk_forms(s)) + potential_keys.append((commonname, tuple(sans))) + name = next(itertools.ifilter(lambda key: key in self.certs, potential_keys), None) if name: c = self.certs[name] -- cgit v1.2.3 From 6bd5df79f82a33b7e725afb5f279bda4cba41935 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 21 Jul 2014 14:01:24 +0200 Subject: refactor response length handling --- netlib/http.py | 183 ++++++++++++++++++++++++------------------------------ test/test_http.py | 99 ++++++++--------------------- 2 files changed, 107 insertions(+), 175 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index 736c2c88..f88e6652 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -1,4 +1,5 @@ import string, urlparse, binascii +import sys import odict, utils @@ -88,14 +89,14 @@ def read_headers(fp): # We're being liberal in what we accept, here. if i > 0: name = line[:i] - value = line[i+1:].strip() + value = line[i + 1:].strip() ret.append([name, value]) else: return None return odict.ODictCaseless(ret) -def read_chunked(fp, headers, limit, is_request): +def read_chunked(fp, limit, is_request): """ Read a chunked HTTP body. @@ -103,10 +104,9 @@ def read_chunked(fp, headers, limit, is_request): """ # FIXME: Should check if chunked is the final encoding in the headers # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 3.3 2. - content = "" total = 0 code = 400 if is_request else 502 - while 1: + while True: line = fp.readline(128) if line == "": raise HttpErrorConnClosed(code, "Connection closed prematurely") @@ -114,70 +114,19 @@ def read_chunked(fp, headers, limit, is_request): try: length = int(line, 16) except ValueError: - # FIXME: Not strictly correct - this could be from the server, in which - # case we should send a 502. - raise HttpError(code, "Invalid chunked encoding length: %s"%line) - if not length: - break + raise HttpError(code, "Invalid chunked encoding length: %s" % line) total += length if limit is not None and total > limit: - msg = "HTTP Body too large."\ - " Limit is %s, chunked content length was at least %s"%(limit, total) + msg = "HTTP Body too large." \ + " Limit is %s, chunked content length was at least %s" % (limit, total) raise HttpError(code, msg) - content += fp.read(length) - line = fp.readline(5) - if line != '\r\n': + chunk = fp.read(length) + suffix = fp.readline(5) + if suffix != '\r\n': raise HttpError(code, "Malformed chunked body") - while 1: - line = fp.readline() - if line == "": - raise HttpErrorConnClosed(code, "Connection closed prematurely") - if line == '\r\n' or line == '\n': - break - return content - -def read_next_chunk(fp, headers, is_request): - """ - Read next piece of a chunked HTTP body. Returns next piece of - content as a string or None if we hit the end. - """ - # TODO: see and understand the FIXME in read_chunked and - # see if we need to apply here? - content = "" - code = 400 if is_request else 502 - line = fp.readline(128) - if line == "": - raise HttpErrorConnClosed(code, "Connection closed prematurely") - try: - length = int(line, 16) - except ValueError: - # TODO: see note in this part of read_chunked() - raise HttpError(code, "Invalid chunked encoding length: %s"%line) - if length > 0: - content += fp.read(length) - print "read content: '%s'" % content - line = fp.readline(5) - if line == '': - raise HttpErrorConnClosed(code, "Connection closed prematurely") - if line != '\r\n': - raise HttpError(code, "Malformed chunked body: '%s' (len=%d)" % (line, length)) - if content == "": - content = None # normalize zero length to None, meaning end of chunked stream - return content # return this chunk - -def write_chunk(fp, content): - """ - Write a chunk with chunked encoding format, returns True - if there should be more chunks or False if you passed - None, meaning this was the last chunk. - """ - if content == None or content == "": - fp.write("0\r\n\r\n") - return False - fp.write("%x\r\n" % len(content)) - fp.write(content) - fp.write("\r\n") - return True + yield line, chunk, '\r\n' + if length == 0: + return def get_header_tokens(headers, key): @@ -307,6 +256,7 @@ def parse_init_http(line): def connection_close(httpversion, headers): """ Checks the message to see if the client connection should be closed according to RFC 2616 Section 8.1 + Note that a connection should be closed as well if the response has been read until end of the stream. """ # At first, check if we have an explicit Connection header. if "connection" in headers: @@ -323,7 +273,7 @@ def connection_close(httpversion, headers): def parse_response_line(line): parts = line.strip().split(" ", 2) - if len(parts) == 2: # handle missing message gracefully + if len(parts) == 2: # handle missing message gracefully parts.append("") if len(parts) != 3: return None @@ -335,37 +285,38 @@ def parse_response_line(line): return (proto, code, msg) -def read_response(rfile, method, body_size_limit, include_body=True): +def read_response(rfile, request_method, body_size_limit, include_body=True): """ Return an (httpversion, code, msg, headers, content) tuple. """ line = rfile.readline() - if line == "\r\n" or line == "\n": # Possible leftover from previous message + if line == "\r\n" or line == "\n": # Possible leftover from previous message line = rfile.readline() if not line: raise HttpErrorConnClosed(502, "Server disconnect.") parts = parse_response_line(line) if not parts: - raise HttpError(502, "Invalid server response: %s"%repr(line)) + raise HttpError(502, "Invalid server response: %s" % repr(line)) proto, code, msg = parts httpversion = parse_http_protocol(proto) if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s"%repr(proto)) + raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) headers = read_headers(rfile) if headers is None: raise HttpError(502, "Invalid headers.") - # Parse response body according to http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 - if method in ["HEAD", "CONNECT"] or (code in [204, 304]) or 100 <= code <= 199: - content = "" - elif include_body: - content = read_http_body(rfile, headers, body_size_limit, False) + if include_body: + content = read_http_body(rfile, headers, body_size_limit, request_method, code, False) else: - content = None # if include_body==False then a None content means the body should be read separately + content = None # if include_body==False then a None content means the body should be read separately return httpversion, code, msg, headers, content -def read_http_body(rfile, headers, limit, is_request): +def read_http_body(*args, **kwargs): + return "".join(content for _, content, _ in read_http_body_chunked(*args, **kwargs)) + + +def read_http_body_chunked(rfile, headers, limit, request_method, response_code, is_request, max_chunk_size=None): """ Read an HTTP message body: @@ -374,41 +325,69 @@ def read_http_body(rfile, headers, limit, is_request): limit: Size limit. is_request: True if the body to read belongs to a request, False otherwise """ - if has_chunked_encoding(headers): - content = read_chunked(rfile, headers, limit, is_request) - elif "content-length" in headers: - try: - l = int(headers["content-length"][0]) - if l < 0: - raise ValueError() - except ValueError: - raise HttpError(400 if is_request else 502, "Invalid content-length header: %s"%headers["content-length"]) - if limit is not None and l > limit: - raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) - content = rfile.read(l) - elif is_request: - content = "" + if max_chunk_size is None: + max_chunk_size = limit or sys.maxint + + expected_size = expected_http_body_size(headers, is_request, request_method, response_code) + + if expected_size is None: + if has_chunked_encoding(headers): + # Python 3: yield from + for x in read_chunked(rfile, limit, is_request): + yield x + else: # pragma: nocover + raise HttpError(400 if is_request else 502, "Content-Length unknown but no chunked encoding") + elif expected_size >= 0: + if limit is not None and expected_size > limit: + raise HttpError(400 if is_request else 509, + "HTTP Body too large. Limit is %s, content-length was %s" % (limit, expected_size)) + bytes_left = expected_size + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + yield "", rfile.read(chunk_size), "" + bytes_left -= chunk_size else: - content = rfile.read(limit if limit else -1) + bytes_left = limit or -1 + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = rfile.read(chunk_size) + if not content: + return + yield "", content, "" + bytes_left -= chunk_size not_done = rfile.read(1) if not_done: raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s," % limit) - return content -def expected_http_body_size(headers, is_request): + +def expected_http_body_size(headers, is_request, request_method, response_code): """ - Returns length of body expected or -1 if not - known and we should just read until end of - stream. + Returns the expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding) + - -1, if all data should be read until end of stream. """ + + # Determine response size according to http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 + if request_method: + request_method = request_method.upper() + + if (not is_request and ( + request_method == "HEAD" or + (request_method == "CONNECT" and response_code == 200) or + response_code in [204, 304] or + 100 <= response_code <= 199)): + return 0 + if has_chunked_encoding(headers): + return None if "content-length" in headers: try: - l = int(headers["content-length"][0]) - if l < 0: + size = int(headers["content-length"][0]) + if size < 0: raise ValueError() - return l + return size except ValueError: - raise HttpError(400 if is_request else 502, "Invalid content-length header: %s"%headers["content-length"]) - elif is_request: + raise HttpError(400 if is_request else 502, "Invalid content-length header: %s" % headers["content-length"]) + if is_request: return 0 - return -1 + return -1 \ No newline at end of file diff --git a/test/test_http.py b/test/test_http.py index e1dffab8..497e80e2 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -16,78 +16,31 @@ def test_has_chunked_encoding(): def test_read_chunked(): + + h = odict.ODictCaseless() + h["transfer-encoding"] = ["chunked"] s = cStringIO.StringIO("1\r\na\r\n0\r\n") - tutils.raises("closed prematurely", http.read_chunked, s, None, None, True) + + tutils.raises("malformed chunked body", http.read_http_body, s, h, None, "GET", None, True) s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") - assert http.read_chunked(s, None, None, True) == "a" + assert http.read_http_body(s, h, None, "GET", None, True) == "a" s = cStringIO.StringIO("\r\n\r\n1\r\na\r\n0\r\n\r\n") - assert http.read_chunked(s, None, None, True) == "a" + assert http.read_http_body(s, h, None, "GET", None, True) == "a" s = cStringIO.StringIO("\r\n") - tutils.raises("closed prematurely", http.read_chunked, s, None, None, True) + tutils.raises("closed prematurely", http.read_http_body, s, h, None, "GET", None, True) s = cStringIO.StringIO("1\r\nfoo") - tutils.raises("malformed chunked body", http.read_chunked, s, None, None, True) + tutils.raises("malformed chunked body", http.read_http_body, s, h, None, "GET", None, True) s = cStringIO.StringIO("foo\r\nfoo") - tutils.raises(http.HttpError, http.read_chunked, s, None, None, True) + tutils.raises(http.HttpError, http.read_http_body, s, h, None, "GET", None, True) s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") - tutils.raises("too large", http.read_chunked, s, None, 2, True) - - -def test_read_next_chunk(): - s = cStringIO.StringIO( - "4\r\n" + - "mitm\r\n" + - "5\r\n" + - "proxy\r\n" + - "e\r\n" + - " in\r\n\r\nchunks.\r\n" + - "0\r\n" + - "\r\n") - assert http.read_next_chunk(s, None, False) == "mitm" - assert http.read_next_chunk(s, None, False) == "proxy" - assert http.read_next_chunk(s, None, False) == " in\r\n\r\nchunks." - assert http.read_next_chunk(s, None, False) == None - - s = cStringIO.StringIO("") - tutils.raises("closed prematurely", http.read_next_chunk, s, None, False) - - s = cStringIO.StringIO("1\r\na\r\n0\r\n") - http.read_next_chunk(s, None, False) - tutils.raises("closed prematurely", http.read_next_chunk, s, None, False) - - s = cStringIO.StringIO("1\r\nfoo") - tutils.raises("malformed chunked body", http.read_next_chunk, s, None, False) - - s = cStringIO.StringIO("foo\r\nfoo") - tutils.raises(http.HttpError, http.read_next_chunk, s, None, False) - -def test_write_chunk(): - - expected = ("" + - "4\r\n" + - "mitm\r\n" + - "5\r\n" + - "proxy\r\n" + - "e\r\n" + - " in\r\n\r\nchunks.\r\n" + - "0\r\n" + - "\r\n") - - s = cStringIO.StringIO() - http.write_chunk(s, "mitm") - http.write_chunk(s, "proxy") - http.write_chunk(s, " in\r\n\r\nchunks.") - http.write_chunk(s, None) - - print len(s.getvalue()) - print len(expected) + tutils.raises("too large", http.read_http_body, s, h, 2, "GET", None, True) - assert s.getvalue() == expected def test_connection_close(): h = odict.ODictCaseless() @@ -114,73 +67,73 @@ def test_get_header_tokens(): def test_read_http_body_request(): h = odict.ODictCaseless() r = cStringIO.StringIO("testing") - assert http.read_http_body(r, h, None, True) == "" + assert http.read_http_body(r, h, None, "GET", None, True) == "" def test_read_http_body_response(): h = odict.ODictCaseless() s = cStringIO.StringIO("testing") - assert http.read_http_body(s, h, None, False) == "testing" + assert http.read_http_body(s, h, None, "GET", 200, False) == "testing" def test_read_http_body(): # test default case h = odict.ODictCaseless() h["content-length"] = [7] s = cStringIO.StringIO("testing") - assert http.read_http_body(s, h, None, False) == "testing" + assert http.read_http_body(s, h, None, "GET", 200, False) == "testing" # test content length: invalid header h["content-length"] = ["foo"] s = cStringIO.StringIO("testing") - tutils.raises(http.HttpError, http.read_http_body, s, h, None, False) + tutils.raises(http.HttpError, http.read_http_body, s, h, None, "GET", 200, False) # test content length: invalid header #2 h["content-length"] = [-1] s = cStringIO.StringIO("testing") - tutils.raises(http.HttpError, http.read_http_body, s, h, None, False) + tutils.raises(http.HttpError, http.read_http_body, s, h, None, "GET", 200, False) # test content length: content length > actual content h["content-length"] = [5] s = cStringIO.StringIO("testing") - tutils.raises(http.HttpError, http.read_http_body, s, h, 4, False) + tutils.raises(http.HttpError, http.read_http_body, s, h, 4, "GET", 200, False) # test content length: content length < actual content s = cStringIO.StringIO("testing") - assert len(http.read_http_body(s, h, None, False)) == 5 + assert len(http.read_http_body(s, h, None, "GET", 200, False)) == 5 # test no content length: limit > actual content h = odict.ODictCaseless() s = cStringIO.StringIO("testing") - assert len(http.read_http_body(s, h, 100, False)) == 7 + assert len(http.read_http_body(s, h, 100, "GET", 200, False)) == 7 # test no content length: limit < actual content s = cStringIO.StringIO("testing") - tutils.raises(http.HttpError, http.read_http_body, s, h, 4, False) + tutils.raises(http.HttpError, http.read_http_body, s, h, 4, "GET", 200, False) # test chunked h = odict.ODictCaseless() h["transfer-encoding"] = ["chunked"] s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") - assert http.read_http_body(s, h, 100, False) == "aaaaa" + assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa" def test_expected_http_body_size(): # gibber in the content-length field h = odict.ODictCaseless() h["content-length"] = ["foo"] - tutils.raises(http.HttpError, http.expected_http_body_size, h, False) + tutils.raises(http.HttpError, http.expected_http_body_size, h, False, "GET", 200) # negative number in the content-length field h = odict.ODictCaseless() h["content-length"] = ["-7"] - tutils.raises(http.HttpError, http.expected_http_body_size, h, False) + tutils.raises(http.HttpError, http.expected_http_body_size, h, False, "GET", 200) # explicit length h = odict.ODictCaseless() h["content-length"] = ["5"] - assert http.expected_http_body_size(h, False) == 5 + assert http.expected_http_body_size(h, False, "GET", 200) == 5 # no length h = odict.ODictCaseless() - assert http.expected_http_body_size(h, False) == -1 + assert http.expected_http_body_size(h, False, "GET", 200) == -1 # no length request h = odict.ODictCaseless() - assert http.expected_http_body_size(h, True) == 0 + assert http.expected_http_body_size(h, True, "GET", None) == 0 def test_parse_http_protocol(): assert http.parse_http_protocol("HTTP/1.1") == (1, 1) -- cgit v1.2.3 From 197dae918388b53fde6f79dcec9613a0ac1d4ba1 Mon Sep 17 00:00:00 2001 From: kronick Date: Tue, 29 Jul 2014 15:12:13 +0200 Subject: Made attribute optional (as it is in pyOpenSSL) See https://github.com/pyca/pyopenssl/commit/0d7e8a1af28ab22950b21afa3fd451cec7dd5fdc -- It looks like this constant isn't set on some platforms (including Raspberry Pi's libssl) --- netlib/tcp.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index c5bb7c4b..9c92ce38 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -17,7 +17,10 @@ OP_DONT_INSERT_EMPTY_FRAGMENTS = SSL.OP_DONT_INSERT_EMPTY_FRAGMENTS OP_EPHEMERAL_RSA = SSL.OP_EPHEMERAL_RSA OP_MICROSOFT_BIG_SSLV3_BUFFER = SSL.OP_MICROSOFT_BIG_SSLV3_BUFFER OP_MICROSOFT_SESS_ID_BUG = SSL.OP_MICROSOFT_SESS_ID_BUG -OP_MSIE_SSLV2_RSA_PADDING = SSL.OP_MSIE_SSLV2_RSA_PADDING +try: + OP_MSIE_SSLV2_RSA_PADDING = SSL.OP_MSIE_SSLV2_RSA_PADDING +except AttributeError: + pass OP_NETSCAPE_CA_DN_BUG = SSL.OP_NETSCAPE_CA_DN_BUG OP_NETSCAPE_CHALLENGE_BUG = SSL.OP_NETSCAPE_CHALLENGE_BUG OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG = SSL.OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG -- cgit v1.2.3 From 1c1167eda0a2757b8fb6588f0400d47020fdb1ab Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 16 Aug 2014 15:28:09 +0200 Subject: use passlib instead of md5crypt --- netlib/contrib/__init__.py | 0 netlib/contrib/md5crypt.py | 94 ---------------------------------------------- netlib/http_auth.py | 29 +++----------- setup.py | 2 +- test/test_http_auth.py | 8 +--- 5 files changed, 8 insertions(+), 125 deletions(-) delete mode 100644 netlib/contrib/__init__.py delete mode 100644 netlib/contrib/md5crypt.py diff --git a/netlib/contrib/__init__.py b/netlib/contrib/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/netlib/contrib/md5crypt.py b/netlib/contrib/md5crypt.py deleted file mode 100644 index d64ea8ac..00000000 --- a/netlib/contrib/md5crypt.py +++ /dev/null @@ -1,94 +0,0 @@ -# Based on FreeBSD src/lib/libcrypt/crypt.c 1.2 -# http://www.freebsd.org/cgi/cvsweb.cgi/~checkout~/src/lib/libcrypt/crypt.c?rev=1.2&content-type=text/plain - -# Original license: -# * "THE BEER-WARE LICENSE" (Revision 42): -# * wrote this file. As long as you retain this notice you -# * can do whatever you want with this stuff. If we meet some day, and you think -# * this stuff is worth it, you can buy me a beer in return. Poul-Henning Kamp - -# This port adds no further stipulations. I forfeit any copyright interest. - -import md5 - -def md5crypt(password, salt, magic='$1$'): - # /* The password first, since that is what is most unknown */ /* Then our magic string */ /* Then the raw salt */ - m = md5.new() - m.update(password + magic + salt) - - # /* Then just as many characters of the MD5(pw,salt,pw) */ - mixin = md5.md5(password + salt + password).digest() - for i in range(0, len(password)): - m.update(mixin[i % 16]) - - # /* Then something really weird... */ - # Also really broken, as far as I can tell. -m - i = len(password) - while i: - if i & 1: - m.update('\x00') - else: - m.update(password[0]) - i >>= 1 - - final = m.digest() - - # /* and now, just to make sure things don't run too fast */ - for i in range(1000): - m2 = md5.md5() - if i & 1: - m2.update(password) - else: - m2.update(final) - - if i % 3: - m2.update(salt) - - if i % 7: - m2.update(password) - - if i & 1: - m2.update(final) - else: - m2.update(password) - - final = m2.digest() - - # This is the bit that uses to64() in the original code. - - itoa64 = './0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' - - rearranged = '' - for a, b, c in ((0, 6, 12), (1, 7, 13), (2, 8, 14), (3, 9, 15), (4, 10, 5)): - v = ord(final[a]) << 16 | ord(final[b]) << 8 | ord(final[c]) - for i in range(4): - rearranged += itoa64[v & 0x3f]; v >>= 6 - - v = ord(final[11]) - for i in range(2): - rearranged += itoa64[v & 0x3f]; v >>= 6 - - return magic + salt + '$' + rearranged - -if __name__ == '__main__': - - def test(clear_password, the_hash): - magic, salt = the_hash[1:].split('$')[:2] - magic = '$' + magic + '$' - return md5crypt(clear_password, salt, magic) == the_hash - - test_cases = ( - (' ', '$1$yiiZbNIH$YiCsHZjcTkYd31wkgW8JF.'), - ('pass', '$1$YeNsbWdH$wvOF8JdqsoiLix754LTW90'), - ('____fifteen____', '$1$s9lUWACI$Kk1jtIVVdmT01p0z3b/hw1'), - ('____sixteen_____', '$1$dL3xbVZI$kkgqhCanLdxODGq14g/tW1'), - ('____seventeen____', '$1$NaH5na7J$j7y8Iss0hcRbu3kzoJs5V.'), - ('__________thirty-three___________', '$1$HO7Q6vzJ$yGwp2wbL5D7eOVzOmxpsy.'), - ('apache', '$apr1$J.w5a/..$IW9y6DR0oO/ADuhlMF5/X1') - ) - - for clearpw, hashpw in test_cases: - if test(clearpw, hashpw): - print '%s: pass' % clearpw - else: - print '%s: FAIL' % clearpw diff --git a/netlib/http_auth.py b/netlib/http_auth.py index b0451e3b..937b66f0 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -1,4 +1,4 @@ -from .contrib import md5crypt +from passlib.apache import HtpasswdFile import http from argparse import Action, ArgumentTypeError @@ -78,32 +78,14 @@ class PassManHtpasswd: """ Read usernames and passwords from an htpasswd file """ - def __init__(self, fp): + def __init__(self, path): """ Raises ValueError if htpasswd file is invalid. """ - self.usernames = {} - for l in fp: - l = l.strip().split(':') - if len(l) != 2: - raise ValueError("Invalid htpasswd file.") - parts = l[1].split('$') - if len(parts) != 4: - raise ValueError("Invalid htpasswd file.") - self.usernames[l[0]] = dict( - token = l[1], - dummy = parts[0], - magic = parts[1], - salt = parts[2], - hashed_password = parts[3] - ) + self.htpasswd = HtpasswdFile(path) def test(self, username, password_token): - ui = self.usernames.get(username) - if not ui: - return False - expected = md5crypt.md5crypt(password_token, ui["salt"], '$'+ui["magic"]+'$') - return expected==ui["token"] + return bool(self.htpasswd.check_password(username, password_token)) class PassManSingleUser: @@ -149,6 +131,5 @@ class NonanonymousAuthAction(AuthAction): class HtpasswdAuthAction(AuthAction): def getPasswordManager(self, s): - with open(s, "r") as f: - return PassManHtpasswd(f) + return PassManHtpasswd(s) diff --git a/setup.py b/setup.py index 5ba9f824..2dcfa248 100644 --- a/setup.py +++ b/setup.py @@ -88,5 +88,5 @@ setup( "Topic :: Software Development :: Testing :: Traffic Generation", "Topic :: Internet :: WWW/HTTP", ], - install_requires=["pyasn1>0.1.2", "pyopenssl>=0.14"], + install_requires=["pyasn1>0.1.2", "pyopenssl>=0.14", "passlib>=1.6.2"], ) diff --git a/test/test_http_auth.py b/test/test_http_auth.py index dd0273fe..176aa3ff 100644 --- a/test/test_http_auth.py +++ b/test/test_http_auth.py @@ -12,14 +12,10 @@ class TestPassManNonAnon: class TestPassManHtpasswd: def test_file_errors(self): - s = cStringIO.StringIO("foo") - tutils.raises("invalid htpasswd", http_auth.PassManHtpasswd, s) - s = cStringIO.StringIO("foo:bar$foo") - tutils.raises("invalid htpasswd", http_auth.PassManHtpasswd, s) + tutils.raises("malformed htpasswd file", http_auth.PassManHtpasswd, tutils.test_data.path("data/server.crt")) def test_simple(self): - f = open(tutils.test_data.path("data/htpasswd"),"rb") - pm = http_auth.PassManHtpasswd(f) + pm = http_auth.PassManHtpasswd(tutils.test_data.path("data/htpasswd")) vals = ("basic", "test", "test") p = http.assemble_http_basic_auth(*vals) -- cgit v1.2.3 From 6d1b601ddf070ef1335be1804386fa0f4a2fcbd4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 16 Aug 2014 15:53:07 +0200 Subject: minor cleanups --- netlib/__init__.py | 1 + netlib/certffi.py | 9 +++++++-- netlib/certutils.py | 15 +++------------ netlib/http.py | 3 ++- netlib/http_auth.py | 3 ++- netlib/http_status.py | 1 + netlib/http_uastrings.py | 2 ++ netlib/odict.py | 1 + netlib/socks.py | 17 +++++++++-------- netlib/tcp.py | 3 ++- netlib/test.py | 3 ++- netlib/utils.py | 2 +- netlib/version.py | 2 ++ netlib/wsgi.py | 3 ++- test/test_tcp.py | 3 --- tools/getcertnames | 15 +++++++++++++-- 16 files changed, 50 insertions(+), 33 deletions(-) diff --git a/netlib/__init__.py b/netlib/__init__.py index e69de29b..9b4faa33 100644 --- a/netlib/__init__.py +++ b/netlib/__init__.py @@ -0,0 +1 @@ +from __future__ import (absolute_import, print_function, division) diff --git a/netlib/certffi.py b/netlib/certffi.py index c5d7c95e..81dc72e8 100644 --- a/netlib/certffi.py +++ b/netlib/certffi.py @@ -1,7 +1,9 @@ +from __future__ import (absolute_import, print_function, division) import cffi import OpenSSL + xffi = cffi.FFI() -xffi.cdef (""" +xffi.cdef(""" struct rsa_meth_st { int flags; ...; @@ -18,6 +20,7 @@ xffi.verify( extra_compile_args=['-w'] ) + def handle(privkey): new = xffi.new("struct rsa_st*") newbuf = xffi.buffer(new) @@ -26,11 +29,13 @@ def handle(privkey): newbuf[:] = oldbuf[:] return new + def set_flags(privkey, val): hdl = handle(privkey) - hdl.meth.flags = val + hdl.meth.flags = val return privkey + def get_flags(privkey): hdl = handle(privkey) return hdl.meth.flags diff --git a/netlib/certutils.py b/netlib/certutils.py index 308d6cf8..18179917 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -1,11 +1,10 @@ +from __future__ import (absolute_import, print_function, division) import os, ssl, time, datetime import itertools from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL -import tcp -import UserDict DEFAULT_EXP = 62208000 # =24 * 60 * 60 * 720 # Generated with "openssl dhparam". It's too slow to generate this on startup. @@ -255,7 +254,7 @@ class CertStore: return c[0], (c[1] or self.privkey) def gen_pkey(self, cert): - import certffi + from . import certffi certffi.set_flags(self.privkey, 1) return self.privkey @@ -360,12 +359,4 @@ class SSLCert: continue for i in dec[0]: altnames.append(i[0].asOctets()) - return altnames - - - -def get_remote_cert(host, port, sni): - c = tcp.TCPClient((host, port)) - c.connect() - c.convert_to_ssl(sni=sni) - return c.cert + return altnames \ No newline at end of file diff --git a/netlib/http.py b/netlib/http.py index 774bac6c..a49f0588 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -1,6 +1,7 @@ +from __future__ import (absolute_import, print_function, division) import string, urlparse, binascii import sys -import odict, utils +from . import odict, utils class HttpError(Exception): diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 937b66f0..49f5925f 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -1,6 +1,7 @@ +from __future__ import (absolute_import, print_function, division) from passlib.apache import HtpasswdFile -import http from argparse import Action, ArgumentTypeError +from . import http class NullProxyAuth(): diff --git a/netlib/http_status.py b/netlib/http_status.py index 9f3f7e15..7dba2d56 100644 --- a/netlib/http_status.py +++ b/netlib/http_status.py @@ -1,3 +1,4 @@ +from __future__ import (absolute_import, print_function, division) CONTINUE = 100 SWITCHING = 101 diff --git a/netlib/http_uastrings.py b/netlib/http_uastrings.py index 826c31a5..d0d145da 100644 --- a/netlib/http_uastrings.py +++ b/netlib/http_uastrings.py @@ -1,3 +1,5 @@ +from __future__ import (absolute_import, print_function, division) + """ A small collection of useful user-agent header strings. These should be kept reasonably current to reflect common usage. diff --git a/netlib/odict.py b/netlib/odict.py index ea95a586..a0e1f694 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,3 +1,4 @@ +from __future__ import (absolute_import, print_function, division) import re, copy diff --git a/netlib/socks.py b/netlib/socks.py index 97df3478..1da5b6cc 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -1,7 +1,8 @@ +from __future__ import (absolute_import, print_function, division) import socket import struct -from array import array -from . import tcp, utils +import array +from . import tcp class SocksError(Exception): @@ -10,24 +11,24 @@ class SocksError(Exception): self.code = code -class VERSION: +class VERSION(object): SOCKS4 = 0x04 SOCKS5 = 0x05 -class CMD: +class CMD(object): CONNECT = 0x01 BIND = 0x02 UDP_ASSOCIATE = 0x03 -class ATYP: +class ATYP(object): IPV4_ADDRESS = 0x01 DOMAINNAME = 0x03 IPV6_ADDRESS = 0x04 -class REP: +class REP(object): SUCCEEDED = 0x00 GENERAL_SOCKS_SERVER_FAILURE = 0x01 CONNECTION_NOT_ALLOWED_BY_RULESET = 0x02 @@ -39,7 +40,7 @@ class REP: ADDRESS_TYPE_NOT_SUPPORTED = 0x08 -class METHOD: +class METHOD(object): NO_AUTHENTICATION_REQUIRED = 0x00 GSSAPI = 0x01 USERNAME_PASSWORD = 0x02 @@ -56,7 +57,7 @@ class ClientGreeting(object): @classmethod def from_file(cls, f): ver, nmethods = struct.unpack("!BB", f.read(2)) - methods = array("B") + methods = array.array("B") methods.fromstring(f.read(nmethods)) return cls(ver, methods) diff --git a/netlib/tcp.py b/netlib/tcp.py index 9c92ce38..f49346a1 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,6 +1,7 @@ +from __future__ import (absolute_import, print_function, division) import select, socket, threading, sys, time, traceback from OpenSSL import SSL -import certutils +from . import certutils EINTR = 4 diff --git a/netlib/test.py b/netlib/test.py index bb0012ad..31a848a6 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -1,6 +1,7 @@ +from __future__ import (absolute_import, print_function, division) import threading, Queue, cStringIO -import tcp, certutils import OpenSSL +from . import tcp, certutils class ServerThread(threading.Thread): def __init__(self, server): diff --git a/netlib/utils.py b/netlib/utils.py index 69ba456a..79077ac6 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,4 +1,4 @@ -import socket +from __future__ import (absolute_import, print_function, division) def isascii(s): diff --git a/netlib/version.py b/netlib/version.py index 25565d40..913f753a 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,3 +1,5 @@ +from __future__ import (absolute_import, print_function, division) + IVERSION = (0, 11) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index b576bdff..492803ab 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,5 +1,6 @@ +from __future__ import (absolute_import, print_function, division) import cStringIO, urllib, time, traceback -import odict, tcp +from . import odict, tcp class ClientConn: diff --git a/test/test_tcp.py b/test/test_tcp.py index 911beccc..bf681811 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -129,9 +129,6 @@ class TestServerSSL(test.ServerTestBase): c.wfile.flush() assert c.rfile.readline() == testval - def test_get_remote_cert(self): - assert certutils.get_remote_cert("127.0.0.1", self.port, None).digest("sha1") - def test_get_current_cipher(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() diff --git a/tools/getcertnames b/tools/getcertnames index f39fc635..d22f4980 100755 --- a/tools/getcertnames +++ b/tools/getcertnames @@ -1,14 +1,25 @@ #!/usr/bin/env python import sys sys.path.insert(0, "../../") -from netlib import certutils +from netlib import tcp + + +def get_remote_cert(host, port, sni): + c = tcp.TCPClient((host, port)) + c.connect() + c.convert_to_ssl(sni=sni) + return c.cert if len(sys.argv) > 2: port = int(sys.argv[2]) else: port = 443 +if len(sys.argv) > 3: + sni = sys.argv[3] +else: + sni = None -cert = certutils.get_remote_cert(sys.argv[1], port, None) +cert = get_remote_cert(sys.argv[1], port, sni) print "CN:", cert.cn if cert.altnames: print "SANs:", -- cgit v1.2.3 From f93cd6a33505095aae3aaae4e14d5f9cb30184f6 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 16 Aug 2014 18:35:58 +0200 Subject: always use with statement to open files --- test/test_certutils.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/test/test_certutils.py b/test/test_certutils.py index 95a7280e..55fcc1dc 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -116,11 +116,15 @@ class TestDummyCert: class TestSSLCert: def test_simple(self): - c = certutils.SSLCert.from_pem(file(tutils.test_data.path("data/text_cert"), "rb").read()) + with open(tutils.test_data.path("data/text_cert"), "rb") as f: + d = f.read() + c = certutils.SSLCert.from_pem(d) assert c.cn == "google.com" assert len(c.altnames) == 436 - c = certutils.SSLCert.from_pem(file(tutils.test_data.path("data/text_cert_2"), "rb").read()) + with open(tutils.test_data.path("data/text_cert_2"), "rb") as f: + d = f.read() + c = certutils.SSLCert.from_pem(d) assert c.cn == "www.inode.co.nz" assert len(c.altnames) == 2 assert c.digest("sha1") @@ -134,12 +138,15 @@ class TestSSLCert: c.has_expired def test_err_broken_sans(self): - c = certutils.SSLCert.from_pem(file(tutils.test_data.path("data/text_cert_weird1"), "rb").read()) + with open(tutils.test_data.path("data/text_cert_weird1"), "rb") as f: + d = f.read() + c = certutils.SSLCert.from_pem(d) # This breaks unless we ignore a decoding error. c.altnames def test_der(self): - d = file(tutils.test_data.path("data/dercert"),"rb").read() + with open(tutils.test_data.path("data/dercert"), "rb") as f: + d = f.read() s = certutils.SSLCert.from_der(d) assert s.cn -- cgit v1.2.3 From ef0e501877e74ba659be08d7d8b0781baff08598 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 19 Aug 2014 13:48:52 +0200 Subject: fix #46 --- README.mkd | 6 ++++++ setup.py | 1 + 2 files changed, 7 insertions(+) diff --git a/README.mkd b/README.mkd index f0a26dd5..7c96d396 100644 --- a/README.mkd +++ b/README.mkd @@ -6,3 +6,9 @@ respects, because both pathod and mitmproxy often need to violate standards. This means that protocols are implemented as small, well-contained and flexible functions, and are designed to allow misbehaviour when needed. + +Requirements +------------ + +* [Python](http://www.python.org) 2.7.x. +* Third-party packages listed in [setup.py](https://github.com/mitmproxy/netlib/blob/master/setup.py) \ No newline at end of file diff --git a/setup.py b/setup.py index 2dcfa248..80582772 100644 --- a/setup.py +++ b/setup.py @@ -82,6 +82,7 @@ setup( "Development Status :: 3 - Alpha", "Operating System :: POSIX", "Programming Language :: Python", + "Programming Language :: Python :: 2", "Topic :: Internet", "Topic :: Internet :: WWW/HTTP :: HTTP Servers", "Topic :: Software Development :: Testing", -- cgit v1.2.3 From 3d489f3bb7db6dda7b8476f6daa2177048c911ff Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 3 Sep 2014 17:15:50 +0200 Subject: adapt netlib.wsgi to changes in mitmproxy/mitmproxy#341 --- netlib/tcp.py | 8 ++++---- netlib/wsgi.py | 32 ++++++++++++++++---------------- test/test_wsgi.py | 30 +++++++++++++++--------------- 3 files changed, 35 insertions(+), 35 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index f49346a1..b386603c 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -486,10 +486,10 @@ class TCPServer(object): # none. if traceback: exc = traceback.format_exc() - print >> fp, '-'*40 - print >> fp, "Error in processing of request from %s:%s" % (client_address.host, client_address.port) - print >> fp, exc - print >> fp, '-'*40 + print('-' * 40, file=fp) + print("Error in processing of request from %s:%s" % (client_address.host, client_address.port), file=fp) + print(exc, file=fp) + print('-' * 40, file=fp) def handle_client_connection(self, conn, client_address): # pragma: no cover """ diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 492803ab..568b1f9c 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -9,15 +9,15 @@ class ClientConn: class Flow: - def __init__(self, client_conn): - self.client_conn = client_conn + def __init__(self, address, request): + self.client_conn = ClientConn(address) + self.request = request class Request: - def __init__(self, client_conn, scheme, method, path, headers, content): + def __init__(self, scheme, method, path, headers, content): self.scheme, self.method, self.path = scheme, method, path self.headers, self.content = headers, content - self.flow = Flow(client_conn) def date_time_string(): @@ -39,37 +39,37 @@ class WSGIAdaptor: def __init__(self, app, domain, port, sversion): self.app, self.domain, self.port, self.sversion = app, domain, port, sversion - def make_environ(self, request, errsoc, **extra): - if '?' in request.path: - path_info, query = request.path.split('?', 1) + def make_environ(self, flow, errsoc, **extra): + if '?' in flow.request.path: + path_info, query = flow.request.path.split('?', 1) else: - path_info = request.path + path_info = flow.request.path query = '' environ = { 'wsgi.version': (1, 0), - 'wsgi.url_scheme': request.scheme, - 'wsgi.input': cStringIO.StringIO(request.content), + 'wsgi.url_scheme': flow.request.scheme, + 'wsgi.input': cStringIO.StringIO(flow.request.content), 'wsgi.errors': errsoc, 'wsgi.multithread': True, 'wsgi.multiprocess': False, 'wsgi.run_once': False, 'SERVER_SOFTWARE': self.sversion, - 'REQUEST_METHOD': request.method, + 'REQUEST_METHOD': flow.request.method, 'SCRIPT_NAME': '', 'PATH_INFO': urllib.unquote(path_info), 'QUERY_STRING': query, - 'CONTENT_TYPE': request.headers.get('Content-Type', [''])[0], - 'CONTENT_LENGTH': request.headers.get('Content-Length', [''])[0], + 'CONTENT_TYPE': flow.request.headers.get('Content-Type', [''])[0], + 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', [''])[0], 'SERVER_NAME': self.domain, 'SERVER_PORT': str(self.port), # FIXME: We need to pick up the protocol read from the request. 'SERVER_PROTOCOL': "HTTP/1.1", } environ.update(extra) - if request.flow.client_conn.address: - environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.flow.client_conn.address() + if flow.client_conn.address: + environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = flow.client_conn.address() - for key, value in request.headers.items(): + for key, value in flow.request.headers.items(): key = 'HTTP_' + key.upper().replace('-', '_') if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'): environ[key] = value diff --git a/test/test_wsgi.py b/test/test_wsgi.py index 91a8ff7a..6e1fb146 100644 --- a/test/test_wsgi.py +++ b/test/test_wsgi.py @@ -2,11 +2,11 @@ import cStringIO, sys from netlib import wsgi, odict -def treq(): - cc = wsgi.ClientConn(("127.0.0.1", 8888)) +def tflow(): h = odict.ODictCaseless() h["test"] = ["value"] - return wsgi.Request(cc, "http", "GET", "/", h, "") + req = wsgi.Request("http", "GET", "/", h, "") + return wsgi.Flow(("127.0.0.1", 8888), req) class TestApp: @@ -24,22 +24,22 @@ class TestApp: class TestWSGI: def test_make_environ(self): w = wsgi.WSGIAdaptor(None, "foo", 80, "version") - tr = treq() - assert w.make_environ(tr, None) + tf = tflow() + assert w.make_environ(tf, None) - tr.path = "/foo?bar=voing" - r = w.make_environ(tr, None) + tf.request.path = "/foo?bar=voing" + r = w.make_environ(tf, None) assert r["QUERY_STRING"] == "bar=voing" def test_serve(self): ta = TestApp() w = wsgi.WSGIAdaptor(ta, "foo", 80, "version") - r = treq() - r.host = "foo" - r.port = 80 + f = tflow() + f.request.host = "foo" + f.request.port = 80 wfile = cStringIO.StringIO() - err = w.serve(r, wfile) + err = w.serve(f, wfile) assert ta.called assert not err @@ -49,11 +49,11 @@ class TestWSGI: def _serve(self, app): w = wsgi.WSGIAdaptor(app, "foo", 80, "version") - r = treq() - r.host = "foo" - r.port = 80 + f = tflow() + f.request.host = "foo" + f.request.port = 80 wfile = cStringIO.StringIO() - err = w.serve(r, wfile) + err = w.serve(f, wfile) return wfile.getvalue() def test_serve_empty_body(self): -- cgit v1.2.3 From ec628bc37d173b622e905e8012a08a7328cf7215 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 4 Sep 2014 01:10:44 +0200 Subject: fix tcp.Address inequality comparison --- netlib/tcp.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/netlib/tcp.py b/netlib/tcp.py index b386603c..5ecfca9d 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -216,10 +216,16 @@ class Address(object): def use_ipv6(self, b): self.family = socket.AF_INET6 if b else socket.AF_INET + def __repr__(self): + return repr(self.address) + def __eq__(self, other): other = Address.wrap(other) return (self.address, self.family) == (other.address, other.family) + def __ne__(self, other): + return not self.__eq__(other) + class _Connection(object): def get_current_cipher(self): -- cgit v1.2.3 From 4bf7f3c0ff5158cd178756bc2a414f506fb34e05 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 4 Sep 2014 16:55:02 +0200 Subject: set source_address if not manually specified --- netlib/tcp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/netlib/tcp.py b/netlib/tcp.py index 5ecfca9d..ede8682b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -319,6 +319,8 @@ class TCPClient(_Connection): if self.source_address: connection.bind(self.source_address()) connection.connect(self.address()) + if not self.source_address: + self.source_address = Address(connection.getsockname()) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError), err: -- cgit v1.2.3 From d9a731b23a930474adc35d6b4ebee68cd05a0940 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 4 Sep 2014 19:18:43 +0200 Subject: make inequality comparison work --- netlib/certutils.py | 3 +++ netlib/odict.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/netlib/certutils.py b/netlib/certutils.py index 18179917..84316882 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -285,6 +285,9 @@ class SSLCert: def __eq__(self, other): return self.digest("sha1") == other.digest("sha1") + def __ne__(self, other): + return not self.__eq__(other) + @classmethod def from_pem(klass, txt): x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) diff --git a/netlib/odict.py b/netlib/odict.py index a0e1f694..1e51bb3f 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -24,6 +24,9 @@ class ODict: def __eq__(self, other): return self.lst == other.lst + def __ne__(self, other): + return not self.__eq__(other) + def __iter__(self): return self.lst.__iter__() -- cgit v1.2.3 From 3b81d678c4ff6ae8be563c3d087c4786648c24af Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 7 Sep 2014 11:24:41 +1200 Subject: Use print function after future import --- netlib/tcp.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index f49346a1..a5b9af22 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -486,10 +486,13 @@ class TCPServer(object): # none. if traceback: exc = traceback.format_exc() - print >> fp, '-'*40 - print >> fp, "Error in processing of request from %s:%s" % (client_address.host, client_address.port) - print >> fp, exc - print >> fp, '-'*40 + print('-'*40, file=fp) + print( + "Error in processing of request from %s:%s" % ( + client_address.host, client_address.port + ), file=fp) + print(exc, file=fp) + print('-'*40, file=fp) def handle_client_connection(self, conn, client_address): # pragma: no cover """ -- cgit v1.2.3 From f4013dcd406c731c08c02789f80ccb364844c0ff Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 7 Sep 2014 12:47:17 +1200 Subject: Add a FIXME note for discarded credentials --- netlib/http.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/netlib/http.py b/netlib/http.py index 53a47d50..35e959cd 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -46,6 +46,9 @@ def parse_url(url): if not scheme: return None if '@' in netloc: + # FIXME: Consider what to do with the discarded credentials here Most + # probably we should extend the signature to return these as a separate + # value. _, netloc = string.rsplit(netloc, '@', maxsplit=1) if ':' in netloc: host, port = string.rsplit(netloc, ':', maxsplit=1) -- cgit v1.2.3 From 07990fdcc231f28c7ce3b3486cf7b423f77dcc67 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 8 Sep 2014 18:59:25 +1200 Subject: Better MANIFEST.in --- MANIFEST.in | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/MANIFEST.in b/MANIFEST.in index 2c1bf265..52d3398f 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,5 @@ include README.mkd recursive-include test * recursive-include netlib * -recursive-exclude test *.swo *.swp *.pyc +recursive-exclude * *.pyc *.pyo +prune */__pycache__ -- cgit v1.2.3 From f90ea89e69b3ff9fb612b0ee6024f5546f198ca6 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Sep 2014 18:38:05 +0200 Subject: more verbose errors --- netlib/tcp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 2704eeae..0a3c4ff9 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -308,7 +308,7 @@ class TCPClient(_Connection): try: self.connection.do_handshake() except SSL.Error, v: - raise NetLibError("SSL handshake error: %s"%str(v)) + raise NetLibError("SSL handshake error: %s"%repr(v)) self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) @@ -417,7 +417,7 @@ class BaseHandler(_Connection): try: self.connection.do_handshake() except SSL.Error, v: - raise NetLibError("SSL handshake error: %s"%str(v)) + raise NetLibError("SSL handshake error: %s"%repr(v)) self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) -- cgit v1.2.3 From dd2adc791d7d7d6be91789ec83a9be87c10fef24 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Sep 2014 18:58:07 +0200 Subject: improve distribution --- .travis.yml | 5 +---- requirements.txt | 4 ++-- setup.py | 62 +++++++++++++++++++++++++++++++++++--------------------- 3 files changed, 42 insertions(+), 29 deletions(-) diff --git a/.travis.yml b/.travis.yml index 653f0d8c..7c4dca92 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,10 +3,7 @@ python: - "2.7" # command to install dependencies, e.g. pip install -r requirements.txt --use-mirrors install: - - "pip install -r requirements.txt --use-mirrors" - - "pip install ." - - "pip install --upgrade git+https://github.com/mitmproxy/pathod.git" - - "pip install -r test/requirements.txt --use-mirrors" + - "pip install --src .. -r requirements.txt" # command to run tests, e.g. python setup.py test script: - "nosetests --with-cov --cov-report term-missing" diff --git a/requirements.txt b/requirements.txt index 7b45f7c3..e3ef3a23 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -pyasn1>=0.1.7 -pyOpenSSL>=0.14 \ No newline at end of file +-e git+https://github.com/mitmproxy/pathod.git#egg=pathod +-e .[dev] \ No newline at end of file diff --git a/setup.py b/setup.py index 80582772..8609662f 100644 --- a/setup.py +++ b/setup.py @@ -2,6 +2,7 @@ from distutils.core import setup import fnmatch, os.path from netlib import version + def _fnmatch(name, patternList): for i in patternList: if fnmatch.fnmatch(name, i): @@ -65,29 +66,44 @@ def findPackages(path, dataExclude=[]): return packages, package_data -long_description = file("README.mkd", "rb").read() +with open("README.mkd", "rb") as f: + long_description = f.read() + packages, package_data = findPackages("netlib") setup( - name = "netlib", - version = version.VERSION, - description = "A collection of network utilities used by pathod and mitmproxy.", - long_description = long_description, - author = "Aldo Cortesi", - author_email = "aldo@corte.si", - url = "http://github.com/mitmproxy/netlib", - packages = packages, - package_data = package_data, - classifiers = [ - "License :: OSI Approved :: MIT License", - "Development Status :: 3 - Alpha", - "Operating System :: POSIX", - "Programming Language :: Python", - "Programming Language :: Python :: 2", - "Topic :: Internet", - "Topic :: Internet :: WWW/HTTP :: HTTP Servers", - "Topic :: Software Development :: Testing", - "Topic :: Software Development :: Testing :: Traffic Generation", - "Topic :: Internet :: WWW/HTTP", - ], - install_requires=["pyasn1>0.1.2", "pyopenssl>=0.14", "passlib>=1.6.2"], + name="netlib", + version=version.VERSION, + description="A collection of network utilities used by pathod and mitmproxy.", + long_description=long_description, + author="Aldo Cortesi", + author_email="aldo@corte.si", + url="http://github.com/mitmproxy/netlib", + packages=packages, + package_data=package_data, + classifiers=[ + "License :: OSI Approved :: MIT License", + "Development Status :: 3 - Alpha", + "Operating System :: POSIX", + "Programming Language :: Python", + "Programming Language :: Python :: 2", + "Topic :: Internet", + "Topic :: Internet :: WWW/HTTP", + "Topic :: Internet :: WWW/HTTP :: HTTP Servers", + "Topic :: Software Development :: Testing", + "Topic :: Software Development :: Testing :: Traffic Generation", + ], + install_requires=[ + "pyasn1>=0.1.7", + "pyOpenSSL>=0.14", + "passlib>=1.6.2" + ], + extras_require={ + 'dev': [ + "mock>=1.0.1", + "nose>=1.3.0", + "nose-cov>=1.6", + "coveralls>=0.4.1", + "pathod>=0.10" + ] + } ) -- cgit v1.2.3 From 63c1efd3946ce672640b43b005d12f8f117d670a Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 9 Sep 2014 10:08:56 +1200 Subject: Remove avoidable imports from OpenSSL Fixes #38 --- netlib/tcp.py | 59 +++++++++++++++++++++----------------------------------- test/test_tcp.py | 2 +- 2 files changed, 23 insertions(+), 38 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 2704eeae..080797b4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,6 +1,12 @@ from __future__ import (absolute_import, print_function, division) -import select, socket, threading, sys, time, traceback +import select +import socket +import sys +import threading +import time +import traceback from OpenSSL import SSL + from . import certutils @@ -11,35 +17,6 @@ SSLv3_METHOD = SSL.SSLv3_METHOD SSLv23_METHOD = SSL.SSLv23_METHOD TLSv1_METHOD = SSL.TLSv1_METHOD -OP_ALL = SSL.OP_ALL -OP_CIPHER_SERVER_PREFERENCE = SSL.OP_CIPHER_SERVER_PREFERENCE -OP_COOKIE_EXCHANGE = SSL.OP_COOKIE_EXCHANGE -OP_DONT_INSERT_EMPTY_FRAGMENTS = SSL.OP_DONT_INSERT_EMPTY_FRAGMENTS -OP_EPHEMERAL_RSA = SSL.OP_EPHEMERAL_RSA -OP_MICROSOFT_BIG_SSLV3_BUFFER = SSL.OP_MICROSOFT_BIG_SSLV3_BUFFER -OP_MICROSOFT_SESS_ID_BUG = SSL.OP_MICROSOFT_SESS_ID_BUG -try: - OP_MSIE_SSLV2_RSA_PADDING = SSL.OP_MSIE_SSLV2_RSA_PADDING -except AttributeError: - pass -OP_NETSCAPE_CA_DN_BUG = SSL.OP_NETSCAPE_CA_DN_BUG -OP_NETSCAPE_CHALLENGE_BUG = SSL.OP_NETSCAPE_CHALLENGE_BUG -OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG = SSL.OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG -OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG = SSL.OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG -OP_NO_QUERY_MTU = SSL.OP_NO_QUERY_MTU -OP_NO_SSLv2 = SSL.OP_NO_SSLv2 -OP_NO_SSLv3 = SSL.OP_NO_SSLv3 -OP_NO_TICKET = SSL.OP_NO_TICKET -OP_NO_TLSv1 = SSL.OP_NO_TLSv1 -OP_PKCS1_CHECK_1 = SSL.OP_PKCS1_CHECK_1 -OP_PKCS1_CHECK_2 = SSL.OP_PKCS1_CHECK_2 -OP_SINGLE_DH_USE = SSL.OP_SINGLE_DH_USE -OP_SSLEAY_080_CLIENT_DH_BUG = SSL.OP_SSLEAY_080_CLIENT_DH_BUG -OP_SSLREF2_REUSE_CERT_TYPE_BUG = SSL.OP_SSLREF2_REUSE_CERT_TYPE_BUG -OP_TLS_BLOCK_PADDING_BUG = SSL.OP_TLS_BLOCK_PADDING_BUG -OP_TLS_D5_BUG = SSL.OP_TLS_D5_BUG -OP_TLS_ROLLBACK_BUG = SSL.OP_TLS_ROLLBACK_BUG - class NetLibError(Exception): pass class NetLibDisconnect(NetLibError): pass @@ -251,7 +228,8 @@ class _Connection(object): def close(self): """ - Does a hard close of the socket, i.e. a shutdown, followed by a close. + Does a hard close of the socket, i.e. a shutdown, followed by a + close. """ try: if self.ssl_established: @@ -273,6 +251,7 @@ class _Connection(object): class TCPClient(_Connection): rbufsize = -1 wbufsize = -1 + def __init__(self, address, source_address=None): self.address = Address.wrap(address) self.source_address = Address.wrap(source_address) if source_address else None @@ -284,6 +263,8 @@ class TCPClient(_Connection): def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None, cipher_list=None): """ cert: Path to a file containing both client cert and private key. + + options: A bit field consisting of OpenSSL.SSL.OP_* values """ context = SSL.Context(method) if cipher_list: @@ -358,18 +339,22 @@ class BaseHandler(_Connection): dhparams=None, ca_file=None): """ cert: A certutils.SSLCert object. + method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD + handle_sni: SNI handler, should take a connection object. Server name can be retrieved like this: - connection.get_servername() + connection.get_servername() + + options: A bit field consisting of OpenSSL.SSL.OP_* values - And you can specify the connection keys as follows: + And you can specify the connection keys as follows: - new_context = Context(TLSv1_METHOD) - new_context.use_privatekey(key) - new_context.use_certificate(cert) - connection.set_context(new_context) + new_context = Context(TLSv1_METHOD) + new_context.use_privatekey(key) + new_context.use_certificate(cert) + connection.set_context(new_context) The request_client_cert argument requires some explanation. We're supposed to be able to do this with no negative effects - if the diff --git a/test/test_tcp.py b/test/test_tcp.py index bf681811..78278909 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -123,7 +123,7 @@ class TestServerSSL(test.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(sni="foo.com", options=tcp.OP_ALL) + c.convert_to_ssl(sni="foo.com", options=SSL.OP_ALL) testval = "echo!\n" c.wfile.write(testval) c.wfile.flush() -- cgit v1.2.3 From 414a0a1602b27e9ed1d5aae42ad06d781a5461a6 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 17 Sep 2014 11:47:07 +1200 Subject: Adjust for state object protocol changes in mitmproxy. --- netlib/odict.py | 22 ++++++++++++---------- test/test_odict.py | 6 +++--- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/netlib/odict.py b/netlib/odict.py index 1e51bb3f..3fb38d85 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -101,16 +101,6 @@ class ODict: def items(self): return self.lst[:] - def _get_state(self): - return [tuple(i) for i in self.lst] - - def _load_state(self, state): - self.list = [list(i) for i in state] - - @classmethod - def _from_state(klass, state): - return klass([list(i) for i in state]) - def copy(self): """ Returns a copy of this object. @@ -171,6 +161,18 @@ class ODict: self.lst = nlst return count + # Implement the StateObject protocol from mitmproxy + def get_state(self): + return [tuple(i) for i in self.lst] + + def load_state(self, state): + self.list = [list(i) for i in state] + + @classmethod + def from_state(klass, state): + return klass([list(i) for i in state]) + + class ODictCaseless(ODict): """ diff --git a/test/test_odict.py b/test/test_odict.py index 794956be..a682d7eb 100644 --- a/test/test_odict.py +++ b/test/test_odict.py @@ -28,10 +28,10 @@ class TestODict: self.od.add("foo", 1) self.od.add("foo", 2) self.od.add("bar", 3) - state = self.od._get_state() - nd = odict.ODict._from_state(state) + state = self.od.get_state() + nd = odict.ODict.from_state(state) assert nd == self.od - nd._load_state(state) + nd.load_state(state) def test_dictToHeader2(self): self.od["one"] = ["uno"] -- cgit v1.2.3 From 0e307964698379a973e8a1f96e3145188b9c0b8d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 17 Sep 2014 14:04:26 +1200 Subject: Short-form getstate --- netlib/odict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/odict.py b/netlib/odict.py index 3fb38d85..61448e6d 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -162,7 +162,7 @@ class ODict: return count # Implement the StateObject protocol from mitmproxy - def get_state(self): + def get_state(self, short=False): return [tuple(i) for i in self.lst] def load_state(self, state): -- cgit v1.2.3 From e73a2dbab12296d9787164b5b33320b6d31784d5 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 28 Sep 2014 03:15:26 +0200 Subject: minor changes --- .gitignore | 3 ++- netlib/tcp.py | 6 +++--- netlib/test.py | 6 +++--- setup.py | 2 +- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 26c449d1..ef830f75 100644 --- a/.gitignore +++ b/.gitignore @@ -7,5 +7,6 @@ MANIFEST *.swp *.swo .coverage -.idea +.idea/ __pycache__ +netlib.egg-info/ \ No newline at end of file diff --git a/netlib/tcp.py b/netlib/tcp.py index c8a02ab4..4f5423e4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -471,7 +471,7 @@ class TCPServer(object): self.socket.close() self.handle_shutdown() - def handle_error(self, request, client_address, fp=sys.stderr): + def handle_error(self, connection, client_address, fp=sys.stderr): """ Called when handle_client_connection raises an exception. """ @@ -479,13 +479,13 @@ class TCPServer(object): # none. if traceback: exc = traceback.format_exc() - print('-'*40, file=fp) + print('-' * 40, file=fp) print( "Error in processing of request from %s:%s" % ( client_address.host, client_address.port ), file=fp) print(exc, file=fp) - print('-'*40, file=fp) + print('-' * 40, file=fp) def handle_client_connection(self, conn, client_address): # pragma: no cover """ diff --git a/netlib/test.py b/netlib/test.py index 31a848a6..fb468907 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -64,7 +64,7 @@ class TServer(tcp.TCPServer): key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) if self.ssl["v3_only"]: method = tcp.SSLv3_METHOD - options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1 + options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 else: method = tcp.SSLv23_METHOD options = None @@ -80,7 +80,7 @@ class TServer(tcp.TCPServer): h.handle() h.finish() - def handle_error(self, request, client_address): + def handle_error(self, connection, client_address, fp=None): s = cStringIO.StringIO() - tcp.TCPServer.handle_error(self, request, client_address, s) + tcp.TCPServer.handle_error(self, connection, client_address, s) self.q.put(s.getvalue()) diff --git a/setup.py b/setup.py index 8609662f..d144855f 100644 --- a/setup.py +++ b/setup.py @@ -103,7 +103,7 @@ setup( "nose>=1.3.0", "nose-cov>=1.6", "coveralls>=0.4.1", - "pathod>=0.10" + "pathod>=0.%s" % version.MINORVERSION ] } ) -- cgit v1.2.3 From aee8acbec672a918bb8733159654f57a0c3ab8e2 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 1 Oct 2014 23:22:53 +0200 Subject: distutils -> setuptools --- MANIFEST.in | 6 ++--- setup.py | 84 +++++++++++-------------------------------------------------- 2 files changed, 16 insertions(+), 74 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 52d3398f..bd59f003 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,3 @@ -include README.mkd +include LICENSE README.mkd recursive-include test * -recursive-include netlib * -recursive-exclude * *.pyc *.pyo -prune */__pycache__ +recursive-exclude * *.pyc *.pyo *.swo *.swp \ No newline at end of file diff --git a/setup.py b/setup.py index d144855f..ad3073fc 100644 --- a/setup.py +++ b/setup.py @@ -1,85 +1,25 @@ -from distutils.core import setup -import fnmatch, os.path +from setuptools import setup, find_packages +from codecs import open +import os from netlib import version +# Based on https://github.com/pypa/sampleproject/blob/master/setup.py +# and https://python-packaging-user-guide.readthedocs.org/ -def _fnmatch(name, patternList): - for i in patternList: - if fnmatch.fnmatch(name, i): - return True - return False +here = os.path.abspath(os.path.dirname(__file__)) - -def _splitAll(path): - parts = [] - h = path - while 1: - if not h: - break - h, t = os.path.split(h) - parts.append(t) - parts.reverse() - return parts - - -def findPackages(path, dataExclude=[]): - """ - Recursively find all packages and data directories rooted at path. Note - that only data _directories_ and their contents are returned - - non-Python files at module scope are not, and should be manually - included. - - dataExclude is a list of fnmatch-compatible expressions for files and - directories that should not be included in pakcage_data. - - Returns a (packages, package_data) tuple, ready to be passed to the - corresponding distutils.core.setup arguments. - """ - packages = [] - datadirs = [] - for root, dirs, files in os.walk(path, topdown=True): - if "__init__.py" in files: - p = _splitAll(root) - packages.append(".".join(p)) - else: - dirs[:] = [] - if packages: - datadirs.append(root) - - # Now we recurse into the data directories - package_data = {} - for i in datadirs: - if not _fnmatch(i, dataExclude): - parts = _splitAll(i) - module = ".".join(parts[:-1]) - acc = package_data.get(module, []) - for root, dirs, files in os.walk(i, topdown=True): - sub = os.path.join(*_splitAll(root)[1:]) - if not _fnmatch(sub, dataExclude): - for fname in files: - path = os.path.join(sub, fname) - if not _fnmatch(path, dataExclude): - acc.append(path) - else: - dirs[:] = [] - package_data[module] = acc - return packages, package_data - - -with open("README.mkd", "rb") as f: +with open(os.path.join(here, 'README.mkd'), encoding='utf-8') as f: long_description = f.read() -packages, package_data = findPackages("netlib") setup( name="netlib", version=version.VERSION, description="A collection of network utilities used by pathod and mitmproxy.", long_description=long_description, + url="http://github.com/mitmproxy/netlib", author="Aldo Cortesi", author_email="aldo@corte.si", - url="http://github.com/mitmproxy/netlib", - packages=packages, - package_data=package_data, + license="MIT", classifiers=[ "License :: OSI Approved :: MIT License", "Development Status :: 3 - Alpha", @@ -92,6 +32,10 @@ setup( "Topic :: Software Development :: Testing", "Topic :: Software Development :: Testing :: Traffic Generation", ], + + packages=find_packages(), + include_package_data=True, + install_requires=[ "pyasn1>=0.1.7", "pyOpenSSL>=0.14", @@ -106,4 +50,4 @@ setup( "pathod>=0.%s" % version.MINORVERSION ] } -) +) \ No newline at end of file -- cgit v1.2.3 From 274688172d62131ddf30cf67e6c084e0e928d4bf Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 8 Oct 2014 18:40:46 +0200 Subject: fix mitmproxy/mitmproxy#373 --- netlib/tcp.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 4f5423e4..aca4bd1b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -232,16 +232,25 @@ class _Connection(object): close. """ try: - if self.ssl_established: + if type(self.connection) == SSL.Connection: self.connection.shutdown() self.connection.sock_shutdown(socket.SHUT_WR) else: self.connection.shutdown(socket.SHUT_WR) - #Section 4.2.2.13 of RFC 1122 tells us that a close() with any - # pending readable data could lead to an immediate RST being sent. - #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html - while self.connection.recv(4096): # pragma: no cover - pass + + # Section 4.2.2.13 of RFC 1122 tells us that a close() with any + # pending readable data could lead to an immediate RST being sent (which is the case on Windows). + # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html + # + # Do not call this for an SSL.Connection: + # If the SSL handshake failed at the first place, OpenSSL's SSL_read tries to negotiate the connection + # again at this point, calls the SNI handler and segfaults. + # https://github.com/mitmproxy/mitmproxy/issues/373#issuecomment-58383499 + # (if this turns out to be an issue for successful SSL connections, + # we should check for ssl_established or access the socket directly) + + while self.connection.recv(4096): # pragma: no cover + pass self.connection.close() except (socket.error, SSL.Error, IOError): # Socket probably already closed @@ -281,7 +290,6 @@ class TCPClient(_Connection): except SSL.Error, v: raise NetLibError("SSL client certificate error: %s"%str(v)) self.connection = SSL.Connection(context, self.connection) - self.ssl_established = True if sni: self.sni = sni self.connection.set_tlsext_host_name(sni) @@ -290,6 +298,7 @@ class TCPClient(_Connection): self.connection.do_handshake() except SSL.Error, v: raise NetLibError("SSL handshake error: %s"%repr(v)) + self.ssl_established = True self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) @@ -397,12 +406,12 @@ class BaseHandler(_Connection): """ ctx = self._create_ssl_context(cert, key, **sslctx_kwargs) self.connection = SSL.Connection(ctx, self.connection) - self.ssl_established = True self.connection.set_accept_state() try: self.connection.do_handshake() except SSL.Error, v: raise NetLibError("SSL handshake error: %s"%repr(v)) + self.ssl_established = True self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) -- cgit v1.2.3 From fdb6f5552d43d7ab02320ccd7e6d58750e33c4c4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 8 Oct 2014 20:46:30 +0200 Subject: CertStore: add support for cert chains --- netlib/certutils.py | 70 +++++++++++++++++++++++++++++--------------------- netlib/tcp.py | 6 ++--- test/test_certutils.py | 14 +++++----- test/test_tcp.py | 8 +++--- 4 files changed, 55 insertions(+), 43 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index fe067ca1..c9e6df26 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -113,13 +113,21 @@ def dummy_cert(privkey, cacert, commonname, sans): # return current.value +class CertStoreEntry(object): + def __init__(self, cert, pkey=None, chain_file=None): + self.cert = cert + self.pkey = pkey + self.chain_file = chain_file + class CertStore: """ Implements an in-memory certificate store. """ - def __init__(self, privkey, cacert, dhparams=None): - self.privkey, self.cacert = privkey, cacert + def __init__(self, default_pkey, default_ca, default_chain_file, dhparams=None): + self.default_pkey = default_pkey + self.default_ca = default_ca + self.default_chain_file = default_chain_file self.dhparams = dhparams self.certs = dict() @@ -142,21 +150,21 @@ class CertStore: return dh @classmethod - def from_store(klass, path, basename): - p = os.path.join(path, basename + "-ca.pem") - if not os.path.exists(p): - key, ca = klass.create_store(path, basename) + def from_store(cls, path, basename): + ca_path = os.path.join(path, basename + "-ca.pem") + if not os.path.exists(ca_path): + key, ca = cls.create_store(path, basename) else: - p = os.path.join(path, basename + "-ca.pem") - raw = file(p, "rb").read() + with open(ca_path, "rb") as f: + raw = f.read() ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) - dhp = os.path.join(path, basename + "-dhparam.pem") - dh = klass.load_dhparam(dhp) - return klass(key, ca, dh) + dh_path = os.path.join(path, basename + "-dhparam.pem") + dh = cls.load_dhparam(dh_path) + return cls(key, ca, ca_path, dh) @classmethod - def create_store(klass, path, basename, o=None, cn=None, expiry=DEFAULT_EXP): + def create_store(cls, path, basename, o=None, cn=None, expiry=DEFAULT_EXP): if not os.path.exists(path): os.makedirs(path) @@ -194,25 +202,29 @@ class CertStore: return key, ca def add_cert_file(self, spec, path): - raw = file(path, "rb").read() - cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) + with open(path, "rb") as f: + raw = f.read() + cert = SSLCert(OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)) try: - privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) + pkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) except Exception: - privkey = None - self.add_cert(SSLCert(cert), privkey, spec) + pkey = None + self.add_cert( + CertStoreEntry(cert, pkey, path), + spec + ) - def add_cert(self, cert, privkey, *names): + def add_cert(self, entry, *names): """ Adds a cert to the certstore. We register the CN in the cert plus any SANs, and also the list of names provided as an argument. """ - if cert.cn: - self.certs[cert.cn] = (cert, privkey) - for i in cert.altnames: - self.certs[i] = (cert, privkey) + if entry.cert.cn: + self.certs[entry.cert.cn] = entry + for i in entry.cert.altnames: + self.certs[i] = entry for i in names: - self.certs[i] = (cert, privkey) + self.certs[i] = entry @staticmethod def asterisk_forms(dn): @@ -246,17 +258,17 @@ class CertStore: name = next(itertools.ifilter(lambda key: key in self.certs, potential_keys), None) if name: - c = self.certs[name] + entry = self.certs[name] else: - c = dummy_cert(self.privkey, self.cacert, commonname, sans), None - self.certs[(commonname, tuple(sans))] = c + entry = CertStoreEntry(cert=dummy_cert(self.default_pkey, self.default_ca, commonname, sans)) + self.certs[(commonname, tuple(sans))] = entry - return c[0], (c[1] or self.privkey) + return entry.cert, (entry.pkey or self.default_pkey), (entry.chain_file or self.default_chain_file) def gen_pkey(self, cert): from . import certffi - certffi.set_flags(self.privkey, 1) - return self.privkey + certffi.set_flags(self.default_pkey, 1) + return self.default_pkey class _GeneralName(univ.Choice): diff --git a/netlib/tcp.py b/netlib/tcp.py index aca4bd1b..8e87bec8 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -345,7 +345,7 @@ class BaseHandler(_Connection): def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=None, cipher_list=None, - dhparams=None, ca_file=None): + dhparams=None, chain_file=None): """ cert: A certutils.SSLCert object. @@ -377,8 +377,8 @@ class BaseHandler(_Connection): ctx = SSL.Context(method) if not options is None: ctx.set_options(options) - if ca_file: - ctx.load_verify_locations(ca_file) + if chain_file: + ctx.load_verify_locations(chain_file) if cipher_list: try: ctx.set_cipher_list(cipher_list) diff --git a/test/test_certutils.py b/test/test_certutils.py index 55fcc1dc..f68751ec 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -42,7 +42,7 @@ class TestCertStore: ca2 = certutils.CertStore.from_store(d, "test") assert ca2.get_cert("foo", []) - assert ca.cacert.get_serial_number() == ca2.cacert.get_serial_number() + assert ca.default_ca.get_serial_number() == ca2.default_ca.get_serial_number() def test_create_tmp(self): with tutils.tmpdir() as d: @@ -52,7 +52,7 @@ class TestCertStore: assert ca.get_cert("*.foo.com", []) r = ca.get_cert("*.foo.com", []) - assert r[1] == ca.privkey + assert r[1] == ca.default_pkey def test_add_cert(self): with tutils.tmpdir() as d: @@ -71,14 +71,14 @@ class TestCertStore: with tutils.tmpdir() as d: ca = certutils.CertStore.from_store(d, "test") _ = ca.get_cert("foo.com", ["*.bar.com"]) - cert, key = ca.get_cert("foo.bar.com", ["*.baz.com"]) + cert, key, chain_file = ca.get_cert("foo.bar.com", ["*.baz.com"]) assert "*.baz.com" in cert.altnames def test_overrides(self): with tutils.tmpdir() as d: ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test") ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test") - assert not ca1.cacert.get_serial_number() == ca2.cacert.get_serial_number() + assert not ca1.default_ca.get_serial_number() == ca2.default_ca.get_serial_number() dc = ca2.get_cert("foo.com", []) dcp = os.path.join(d, "dc") @@ -98,7 +98,7 @@ class TestCertStore: cert = ca1.get_cert("foo.com", []) assert certffi.get_flags(ca2.gen_pkey(cert[0])) == 1 finally: - certffi.set_flags(ca2.privkey, 0) + certffi.set_flags(ca2.default_pkey, 0) class TestDummyCert: @@ -106,8 +106,8 @@ class TestDummyCert: with tutils.tmpdir() as d: ca = certutils.CertStore.from_store(d, "test") r = certutils.dummy_cert( - ca.privkey, - ca.cacert, + ca.default_pkey, + ca.default_ca, "foo.com", ["one.com", "two.com", "*.three.com"] ) diff --git a/test/test_tcp.py b/test/test_tcp.py index 78278909..0eadac47 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -393,7 +393,7 @@ class TestPrivkeyGen(test.ServerTestBase): with tutils.tmpdir() as d: ca1 = certutils.CertStore.from_store(d, "test2") ca2 = certutils.CertStore.from_store(d, "test3") - cert, _ = ca1.get_cert("foo.com", []) + cert, _, _ = ca1.get_cert("foo.com", []) key = ca2.gen_pkey(cert) self.convert_to_ssl(cert, key) @@ -409,9 +409,9 @@ class TestPrivkeyGenNoFlags(test.ServerTestBase): with tutils.tmpdir() as d: ca1 = certutils.CertStore.from_store(d, "test2") ca2 = certutils.CertStore.from_store(d, "test3") - cert, _ = ca1.get_cert("foo.com", []) - certffi.set_flags(ca2.privkey, 0) - self.convert_to_ssl(cert, ca2.privkey) + cert, _, _ = ca1.get_cert("foo.com", []) + certffi.set_flags(ca2.default_pkey, 0) + self.convert_to_ssl(cert, ca2.default_pkey) def test_privkey(self): c = tcp.TCPClient(("127.0.0.1", self.port)) -- cgit v1.2.3 From 9ef84ccc1cdd0d8da890ba012812c760e31f2fab Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 9 Oct 2014 00:15:39 +0200 Subject: clean up code --- netlib/certutils.py | 73 +++++++++++++++++++++++++------------------------- test/test_certutils.py | 6 ++--- test/test_tcp.py | 4 +-- 3 files changed, 42 insertions(+), 41 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index c9e6df26..af6177d8 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -114,9 +114,9 @@ def dummy_cert(privkey, cacert, commonname, sans): class CertStoreEntry(object): - def __init__(self, cert, pkey=None, chain_file=None): + def __init__(self, cert, privatekey, chain_file): self.cert = cert - self.pkey = pkey + self.privatekey = privatekey self.chain_file = chain_file @@ -124,15 +124,15 @@ class CertStore: """ Implements an in-memory certificate store. """ - def __init__(self, default_pkey, default_ca, default_chain_file, dhparams=None): - self.default_pkey = default_pkey + def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams=None): + self.default_privatekey = default_privatekey self.default_ca = default_ca self.default_chain_file = default_chain_file self.dhparams = dhparams self.certs = dict() - @classmethod - def load_dhparam(klass, path): + @staticmethod + def load_dhparam(path): # netlib<=0.10 doesn't generate a dhparam file. # Create it now if neccessary. @@ -163,8 +163,8 @@ class CertStore: dh = cls.load_dhparam(dh_path) return cls(key, ca, ca_path, dh) - @classmethod - def create_store(cls, path, basename, o=None, cn=None, expiry=DEFAULT_EXP): + @staticmethod + def create_store(path, basename, o=None, cn=None, expiry=DEFAULT_EXP): if not os.path.exists(path): os.makedirs(path) @@ -173,32 +173,28 @@ class CertStore: key, ca = create_ca(o=o, cn=cn, exp=expiry) # Dump the CA plus private key - f = open(os.path.join(path, basename + "-ca.pem"), "wb") - f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() + with open(os.path.join(path, basename + "-ca.pem"), "wb") as f: + f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) # Dump the certificate in PEM format - f = open(os.path.join(path, basename + "-ca-cert.pem"), "wb") - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() + with open(os.path.join(path, basename + "-ca-cert.pem"), "wb") as f: + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) # Create a .cer file with the same contents for Android - f = open(os.path.join(path, basename + "-ca-cert.cer"), "wb") - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() + with open(os.path.join(path, basename + "-ca-cert.cer"), "wb") as f: + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) # Dump the certificate in PKCS12 format for Windows devices - f = open(os.path.join(path, basename + "-ca-cert.p12"), "wb") - p12 = OpenSSL.crypto.PKCS12() - p12.set_certificate(ca) - p12.set_privatekey(key) - f.write(p12.export()) - f.close() - - f = open(os.path.join(path, basename + "-dhparam.pem"), "wb") - f.write(DEFAULT_DHPARAM) - f.close() + with open(os.path.join(path, basename + "-ca-cert.p12"), "wb") as f: + p12 = OpenSSL.crypto.PKCS12() + p12.set_certificate(ca) + p12.set_privatekey(key) + f.write(p12.export()) + + with open(os.path.join(path, basename + "-dhparam.pem"), "wb") as f: + f.write(DEFAULT_DHPARAM) + return key, ca def add_cert_file(self, spec, path): @@ -206,11 +202,11 @@ class CertStore: raw = f.read() cert = SSLCert(OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)) try: - pkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) + privatekey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) except Exception: - pkey = None + privatekey = self.default_privatekey self.add_cert( - CertStoreEntry(cert, pkey, path), + CertStoreEntry(cert, privatekey, path), spec ) @@ -241,7 +237,7 @@ class CertStore: def get_cert(self, commonname, sans): """ - Returns an (cert, privkey) tuple. + Returns an (cert, privkey, cert_chain) tuple. commonname: Common name for the generated certificate. Must be a valid, plain-ASCII, IDNA-encoded domain name. @@ -260,15 +256,20 @@ class CertStore: if name: entry = self.certs[name] else: - entry = CertStoreEntry(cert=dummy_cert(self.default_pkey, self.default_ca, commonname, sans)) + entry = CertStoreEntry( + cert=dummy_cert(self.default_privatekey, self.default_ca, commonname, sans), + privatekey=self.default_privatekey, + chain_file=self.default_chain_file + ) self.certs[(commonname, tuple(sans))] = entry - return entry.cert, (entry.pkey or self.default_pkey), (entry.chain_file or self.default_chain_file) + return entry.cert, entry.privatekey, entry.chain_file def gen_pkey(self, cert): + # FIXME: We should do something with cert here? from . import certffi - certffi.set_flags(self.default_pkey, 1) - return self.default_pkey + certffi.set_flags(self.default_privatekey, 1) + return self.default_privatekey class _GeneralName(univ.Choice): diff --git a/test/test_certutils.py b/test/test_certutils.py index f68751ec..59c9dcd5 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -52,7 +52,7 @@ class TestCertStore: assert ca.get_cert("*.foo.com", []) r = ca.get_cert("*.foo.com", []) - assert r[1] == ca.default_pkey + assert r[1] == ca.default_privatekey def test_add_cert(self): with tutils.tmpdir() as d: @@ -98,7 +98,7 @@ class TestCertStore: cert = ca1.get_cert("foo.com", []) assert certffi.get_flags(ca2.gen_pkey(cert[0])) == 1 finally: - certffi.set_flags(ca2.default_pkey, 0) + certffi.set_flags(ca2.default_privatekey, 0) class TestDummyCert: @@ -106,7 +106,7 @@ class TestDummyCert: with tutils.tmpdir() as d: ca = certutils.CertStore.from_store(d, "test") r = certutils.dummy_cert( - ca.default_pkey, + ca.default_privatekey, ca.default_ca, "foo.com", ["one.com", "two.com", "*.three.com"] diff --git a/test/test_tcp.py b/test/test_tcp.py index 0eadac47..bf3d46bf 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -410,8 +410,8 @@ class TestPrivkeyGenNoFlags(test.ServerTestBase): ca1 = certutils.CertStore.from_store(d, "test2") ca2 = certutils.CertStore.from_store(d, "test3") cert, _, _ = ca1.get_cert("foo.com", []) - certffi.set_flags(ca2.default_pkey, 0) - self.convert_to_ssl(cert, ca2.default_pkey) + certffi.set_flags(ca2.default_privatekey, 0) + self.convert_to_ssl(cert, ca2.default_privatekey) def test_privkey(self): c = tcp.TCPClient(("127.0.0.1", self.port)) -- cgit v1.2.3 From 987fa22e646e2ab79cf93adf7966b5a27273685a Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 9 Oct 2014 01:46:08 +0200 Subject: make socks reading more bulletproof --- netlib/socks.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/netlib/socks.py b/netlib/socks.py index 1da5b6cc..5b05b397 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -47,6 +47,17 @@ class METHOD(object): NO_ACCEPTABLE_METHODS = 0xFF +def _read(f, n): + try: + d = f.read(n) + if len(d) == n: + return d + else: + raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Incomplete Read") + except socket.error as e: + raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, str(e)) + + class ClientGreeting(object): __slots__ = ("ver", "methods") @@ -56,9 +67,9 @@ class ClientGreeting(object): @classmethod def from_file(cls, f): - ver, nmethods = struct.unpack("!BB", f.read(2)) + ver, nmethods = struct.unpack("!BB", _read(f, 2)) methods = array.array("B") - methods.fromstring(f.read(nmethods)) + methods.fromstring(_read(f, nmethods)) return cls(ver, methods) def to_file(self, f): @@ -74,7 +85,7 @@ class ServerGreeting(object): @classmethod def from_file(cls, f): - ver, method = struct.unpack("!BB", f.read(2)) + ver, method = struct.unpack("!BB", _read(f, 2)) return cls(ver, method) def to_file(self, f): @@ -91,26 +102,26 @@ class Message(object): @classmethod def from_file(cls, f): - ver, msg, rsv, atyp = struct.unpack("!BBBB", f.read(4)) + ver, msg, rsv, atyp = struct.unpack("!BBBB", _read(f, 4)) if rsv != 0x00: raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Socks Request: Invalid reserved byte: %s" % rsv) if atyp == ATYP.IPV4_ADDRESS: - host = socket.inet_ntoa(f.read(4)) # We use tnoa here as ntop is not commonly available on Windows. + host = socket.inet_ntoa(_read(f, 4)) # We use tnoa here as ntop is not commonly available on Windows. use_ipv6 = False elif atyp == ATYP.IPV6_ADDRESS: - host = socket.inet_ntop(socket.AF_INET6, f.read(16)) + host = socket.inet_ntop(socket.AF_INET6, _read(f, 16)) use_ipv6 = True elif atyp == ATYP.DOMAINNAME: - length, = struct.unpack("!B", f.read(1)) - host = f.read(length) + length, = struct.unpack("!B", _read(f, 1)) + host = _read(f, length) use_ipv6 = False else: raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, "Socks Request: Unknown ATYP: %s" % atyp) - port, = struct.unpack("!H", f.read(2)) + port, = struct.unpack("!H", _read(f, 2)) addr = tcp.Address((host, port), use_ipv6=use_ipv6) return cls(ver, msg, atyp, addr) -- cgit v1.2.3 From e6a8730f98d61583f31ac530e2a1c8da2fa181ed Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 9 Oct 2014 04:42:39 +0200 Subject: fix tcp closing for ssled connections --- netlib/tcp.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 8e87bec8..7a970be6 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -238,19 +238,18 @@ class _Connection(object): else: self.connection.shutdown(socket.SHUT_WR) + if type(self.connection) != SSL.Connection or self.ssl_established: # Section 4.2.2.13 of RFC 1122 tells us that a close() with any # pending readable data could lead to an immediate RST being sent (which is the case on Windows). # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html # - # Do not call this for an SSL.Connection: + # Do not call this for every SSL.Connection: # If the SSL handshake failed at the first place, OpenSSL's SSL_read tries to negotiate the connection # again at this point, calls the SNI handler and segfaults. # https://github.com/mitmproxy/mitmproxy/issues/373#issuecomment-58383499 - # (if this turns out to be an issue for successful SSL connections, - # we should check for ssl_established or access the socket directly) - while self.connection.recv(4096): # pragma: no cover pass + self.connection.close() except (socket.error, SSL.Error, IOError): # Socket probably already closed -- cgit v1.2.3 From 29a4e9105053118aa8c0b458bcb8f10f0bc333d1 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 17 Oct 2014 18:48:30 +0200 Subject: fix mitmproxy/mitmproxy#375 --- netlib/tcp.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/netlib/tcp.py b/netlib/tcp.py index 7a970be6..4705f6df 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -243,12 +243,21 @@ class _Connection(object): # pending readable data could lead to an immediate RST being sent (which is the case on Windows). # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html # + # However, we cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above: + # Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that + # recv() would block infinitely. + # As a workaround, we set a timeout here even if we were in blocking mode. + # Please let us know if you have a better solution to this problem. + # # Do not call this for every SSL.Connection: # If the SSL handshake failed at the first place, OpenSSL's SSL_read tries to negotiate the connection # again at this point, calls the SNI handler and segfaults. # https://github.com/mitmproxy/mitmproxy/issues/373#issuecomment-58383499 + timeout = self.connection.gettimeout() + self.connection.settimeout(timeout or 60) while self.connection.recv(4096): # pragma: no cover pass + self.connection.settimeout(timeout) self.connection.close() except (socket.error, SSL.Error, IOError): -- cgit v1.2.3 From ed5e6855652cd3a41579f700d2fb81169c60c3ea Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 22 Oct 2014 17:54:20 +0200 Subject: refactor tcp close, fix mitmproxy/mitmproxy#376 --- netlib/tcp.py | 99 ++++++++++++++++++++++++++++++----------------------------- 1 file changed, 51 insertions(+), 48 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 4705f6df..46c28cd9 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -204,6 +204,37 @@ class Address(object): return not self.__eq__(other) +def close_socket(sock): + """ + Does a hard close of a socket, without emitting a RST. + """ + try: + # We already indicate that we close our end. + # If we close RD, any further received bytes would result in a RST being set, which we want to avoid + # for our purposes + sock.shutdown(socket.SHUT_WR) # may raise "Transport endpoint is not connected" on Linux + + # Section 4.2.2.13 of RFC 1122 tells us that a close() with any + # pending readable data could lead to an immediate RST being sent (which is the case on Windows). + # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html + # + # However, we cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above: + # Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that + # recv() would block infinitely. + # As a workaround, we set a timeout here even if we are in blocking mode. + # Please let us know if you have a better solution to this problem. + + sock.settimeout(sock.gettimeout() or 20) + # may raise a timeout/disconnect exception. + while sock.recv(4096): # pragma: no cover + pass + + except socket.error: + pass + + sock.close() + + class _Connection(object): def get_current_cipher(self): if not self.ssl_established: @@ -216,59 +247,36 @@ class _Connection(object): def finish(self): self.finished = True - try: + + # If we have an SSL connection, wfile.close == connection.close + # (We call _FileLike.set_descriptor(conn)) + # Closing the socket is not our task, therefore we don't call close then. + if type(self.connection) != SSL.Connection: if not getattr(self.wfile, "closed", False): self.wfile.flush() - self.close() + self.wfile.close() self.rfile.close() - except (socket.error, NetLibDisconnect): - # Remote has disconnected - pass - - def close(self): - """ - Does a hard close of the socket, i.e. a shutdown, followed by a - close. - """ - try: - if type(self.connection) == SSL.Connection: + else: + try: self.connection.shutdown() - self.connection.sock_shutdown(socket.SHUT_WR) - else: - self.connection.shutdown(socket.SHUT_WR) - - if type(self.connection) != SSL.Connection or self.ssl_established: - # Section 4.2.2.13 of RFC 1122 tells us that a close() with any - # pending readable data could lead to an immediate RST being sent (which is the case on Windows). - # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html - # - # However, we cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above: - # Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that - # recv() would block infinitely. - # As a workaround, we set a timeout here even if we were in blocking mode. - # Please let us know if you have a better solution to this problem. - # - # Do not call this for every SSL.Connection: - # If the SSL handshake failed at the first place, OpenSSL's SSL_read tries to negotiate the connection - # again at this point, calls the SNI handler and segfaults. - # https://github.com/mitmproxy/mitmproxy/issues/373#issuecomment-58383499 - timeout = self.connection.gettimeout() - self.connection.settimeout(timeout or 60) - while self.connection.recv(4096): # pragma: no cover - pass - self.connection.settimeout(timeout) - - self.connection.close() - except (socket.error, SSL.Error, IOError): - # Socket probably already closed - pass + except SSL.Error: + pass class TCPClient(_Connection): rbufsize = -1 wbufsize = -1 + def close(self): + # Make sure to close the real socket, not the SSL proxy. + # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, + # it tries to renegotiate... + if type(self.connection) == SSL.Connection: + close_socket(self.connection._socket) + else: + close_socket(self.connection) + def __init__(self, address, source_address=None): self.address = Address.wrap(address) self.source_address = Address.wrap(source_address) if source_address else None @@ -430,7 +438,6 @@ class BaseHandler(_Connection): self.connection.settimeout(n) - class TCPServer(object): request_queue_size = 20 def __init__(self, address): @@ -450,11 +457,7 @@ class TCPServer(object): except: self.handle_error(connection, client_address) finally: - try: - connection.shutdown(socket.SHUT_RDWR) - except: - pass - connection.close() + close_socket(connection) def serve_forever(self, poll_interval=0.1): self.__is_shut_down.clear() -- cgit v1.2.3 From 74a560019080b22c8f578860654ec071141b7ca7 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 23 Oct 2014 15:31:42 +0200 Subject: fix tests --- test/test_tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_tcp.py b/test/test_tcp.py index bf3d46bf..ce96f16f 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -277,7 +277,7 @@ class TestClientCipherListError(test.ServerTestBase): class TestSSLDisconnect(test.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): - self.close() + self.finish() ssl = dict( cert = tutils.test_data.path("data/server.crt"), key = tutils.test_data.path("data/server.key"), -- cgit v1.2.3 From ba468f12b8f59f63ce85b221f0cb2d9e004efe6e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 26 Oct 2014 17:30:26 +1300 Subject: Whitespace and legibility --- netlib/http.py | 80 +++++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 57 insertions(+), 23 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index 35e959cd..9268418c 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -120,11 +120,14 @@ def read_chunked(fp, limit, is_request): try: length = int(line, 16) except ValueError: - raise HttpError(code, "Invalid chunked encoding length: %s" % line) + raise HttpError( + code, + "Invalid chunked encoding length: %s" % line + ) total += length if limit is not None and total > limit: - msg = "HTTP Body too large." \ - " Limit is %s, chunked content length was at least %s" % (limit, total) + msg = "HTTP Body too large. Limit is %s," \ + " chunked content longer than %s" % (limit, total) raise HttpError(code, msg) chunk = fp.read(length) suffix = fp.readline(5) @@ -149,7 +152,9 @@ def get_header_tokens(headers, key): def has_chunked_encoding(headers): - return "chunked" in [i.lower() for i in get_header_tokens(headers, "transfer-encoding")] + return "chunked" in [ + i.lower() for i in get_header_tokens(headers, "transfer-encoding") + ] def parse_http_protocol(s): @@ -261,8 +266,9 @@ def parse_init_http(line): def connection_close(httpversion, headers): """ - Checks the message to see if the client connection should be closed according to RFC 2616 Section 8.1 - Note that a connection should be closed as well if the response has been read until end of the stream. + Checks the message to see if the client connection should be closed + according to RFC 2616 Section 8.1 Note that a connection should be + closed as well if the response has been read until end of the stream. """ # At first, check if we have an explicit Connection header. if "connection" in headers: @@ -271,7 +277,8 @@ def connection_close(httpversion, headers): return True elif "keep-alive" in toks: return False - # If we don't have a Connection header, HTTP 1.1 connections are assumed to be persistent + # If we don't have a Connection header, HTTP 1.1 connections are assumed to + # be persistent if httpversion == (1, 1): return False return True @@ -317,14 +324,25 @@ def read_response(rfile, request_method, body_size_limit, include_body=True): raise HttpError(502, "Invalid headers.") if include_body: - content = read_http_body(rfile, headers, body_size_limit, request_method, code, False) + content = read_http_body( + rfile, + headers, + body_size_limit, + request_method, + code, + False + ) else: - content = None # if include_body==False then a None content means the body should be read separately + # if include_body==False then a None content means the body should be + # read separately + content = None return httpversion, code, msg, headers, content def read_http_body(*args, **kwargs): - return "".join(content for _, content, _ in read_http_body_chunked(*args, **kwargs)) + return "".join( + content for _, content, _ in read_http_body_chunked(*args, **kwargs) + ) def read_http_body_chunked(rfile, headers, limit, request_method, response_code, is_request, max_chunk_size=None): @@ -334,12 +352,15 @@ def read_http_body_chunked(rfile, headers, limit, request_method, response_code, rfile: A file descriptor to read from headers: An ODictCaseless object limit: Size limit. - is_request: True if the body to read belongs to a request, False otherwise + is_request: True if the body to read belongs to a request, False + otherwise """ if max_chunk_size is None: max_chunk_size = limit or sys.maxint - expected_size = expected_http_body_size(headers, is_request, request_method, response_code) + expected_size = expected_http_body_size( + headers, is_request, request_method, response_code + ) if expected_size is None: if has_chunked_encoding(headers): @@ -347,11 +368,18 @@ def read_http_body_chunked(rfile, headers, limit, request_method, response_code, for x in read_chunked(rfile, limit, is_request): yield x else: # pragma: nocover - raise HttpError(400 if is_request else 502, "Content-Length unknown but no chunked encoding") + raise HttpError( + 400 if is_request else 502, + "Content-Length unknown but no chunked encoding" + ) elif expected_size >= 0: if limit is not None and expected_size > limit: - raise HttpError(400 if is_request else 509, - "HTTP Body too large. Limit is %s, content-length was %s" % (limit, expected_size)) + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s, content-length was %s" % ( + limit, expected_size + ) + ) bytes_left = expected_size while bytes_left: chunk_size = min(bytes_left, max_chunk_size) @@ -368,7 +396,10 @@ def read_http_body_chunked(rfile, headers, limit, request_method, response_code, bytes_left -= chunk_size not_done = rfile.read(1) if not_done: - raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s," % limit) + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s," % limit + ) def expected_http_body_size(headers, is_request, request_method, response_code): @@ -378,16 +409,16 @@ def expected_http_body_size(headers, is_request, request_method, response_code): - None, if the size in unknown in advance (chunked encoding) - -1, if all data should be read until end of stream. """ - - # Determine response size according to http://tools.ietf.org/html/rfc7230#section-3.3 + # Determine response size according to + # http://tools.ietf.org/html/rfc7230#section-3.3 if request_method: request_method = request_method.upper() if (not is_request and ( - request_method == "HEAD" or - (request_method == "CONNECT" and response_code == 200) or - response_code in [204, 304] or - 100 <= response_code <= 199)): + request_method == "HEAD" or + (request_method == "CONNECT" and response_code == 200) or + response_code in [204, 304] or + 100 <= response_code <= 199)): return 0 if has_chunked_encoding(headers): return None @@ -398,7 +429,10 @@ def expected_http_body_size(headers, is_request, request_method, response_code): raise ValueError() return size except ValueError: - raise HttpError(400 if is_request else 502, "Invalid content-length header: %s" % headers["content-length"]) + raise HttpError( + 400 if is_request else 502, + "Invalid content-length header: %s" % headers["content-length"] + ) if is_request: return 0 return -1 -- cgit v1.2.3 From 9ce2f473f6febf3738dca77b20ab9a7d3092d3d0 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 7 Nov 2014 15:59:00 +1300 Subject: Simplify expected_http_body_size signature, fixing a traceback found in fuzzing --- netlib/http.py | 10 +++++----- netlib/http_auth.py | 4 ++-- netlib/socks.py | 18 ++++++++++++++---- test/test_http.py | 4 ++-- 4 files changed, 23 insertions(+), 13 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index 9268418c..d2fc6343 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -406,8 +406,11 @@ def expected_http_body_size(headers, is_request, request_method, response_code): """ Returns the expected body length: - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding) + - None, if the size in unknown in advance (chunked encoding or invalid + data) - -1, if all data should be read until end of stream. + + May raise HttpError. """ # Determine response size according to # http://tools.ietf.org/html/rfc7230#section-3.3 @@ -429,10 +432,7 @@ def expected_http_body_size(headers, is_request, request_method, response_code): raise ValueError() return size except ValueError: - raise HttpError( - 400 if is_request else 502, - "Invalid content-length header: %s" % headers["content-length"] - ) + return None if is_request: return 0 return -1 diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 49f5925f..dca6e2f3 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -1,5 +1,4 @@ from __future__ import (absolute_import, print_function, division) -from passlib.apache import HtpasswdFile from argparse import Action, ArgumentTypeError from . import http @@ -83,7 +82,8 @@ class PassManHtpasswd: """ Raises ValueError if htpasswd file is invalid. """ - self.htpasswd = HtpasswdFile(path) + import passlib.apache + self.htpasswd = passlib.apache.HtpasswdFile(path) def test(self, username, password_token): return bool(self.htpasswd.check_password(username, password_token)) diff --git a/netlib/socks.py b/netlib/socks.py index 5b05b397..a3c4e9a2 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -53,7 +53,10 @@ def _read(f, n): if len(d) == n: return d else: - raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Incomplete Read") + raise SocksError( + REP.GENERAL_SOCKS_SERVER_FAILURE, + "Incomplete Read" + ) except socket.error as e: raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, str(e)) @@ -76,6 +79,7 @@ class ClientGreeting(object): f.write(struct.pack("!BB", self.ver, len(self.methods))) f.write(self.methods.tostring()) + class ServerGreeting(object): __slots__ = ("ver", "method") @@ -91,6 +95,7 @@ class ServerGreeting(object): def to_file(self, f): f.write(struct.pack("!BB", self.ver, self.method)) + class Message(object): __slots__ = ("ver", "msg", "atyp", "addr") @@ -108,7 +113,8 @@ class Message(object): "Socks Request: Invalid reserved byte: %s" % rsv) if atyp == ATYP.IPV4_ADDRESS: - host = socket.inet_ntoa(_read(f, 4)) # We use tnoa here as ntop is not commonly available on Windows. + # We use tnoa here as ntop is not commonly available on Windows. + host = socket.inet_ntoa(_read(f, 4)) use_ipv6 = False elif atyp == ATYP.IPV6_ADDRESS: host = socket.inet_ntop(socket.AF_INET6, _read(f, 16)) @@ -135,5 +141,9 @@ class Message(object): f.write(struct.pack("!B", len(self.addr.host))) f.write(self.addr.host) else: - raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, "Unknown ATYP: %s" % self.atyp) - f.write(struct.pack("!H", self.addr.port)) \ No newline at end of file + raise SocksError( + REP.ADDRESS_TYPE_NOT_SUPPORTED, + "Unknown ATYP: %s" % self.atyp + ) + f.write(struct.pack("!H", self.addr.port)) + diff --git a/test/test_http.py b/test/test_http.py index 497e80e2..e3e92a1e 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -119,11 +119,11 @@ def test_expected_http_body_size(): # gibber in the content-length field h = odict.ODictCaseless() h["content-length"] = ["foo"] - tutils.raises(http.HttpError, http.expected_http_body_size, h, False, "GET", 200) + assert http.expected_http_body_size(h, False, "GET", 200) is None # negative number in the content-length field h = odict.ODictCaseless() h["content-length"] = ["-7"] - tutils.raises(http.HttpError, http.expected_http_body_size, h, False, "GET", 200) + assert http.expected_http_body_size(h, False, "GET", 200) is None # explicit length h = odict.ODictCaseless() h["content-length"] = ["5"] -- cgit v1.2.3 From 0811a9ebde4975d4e934cf4752376dd0db9bb7e4 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 7 Nov 2014 16:01:41 +1300 Subject: .flush can raise NetlibDisconnect. This fixes a traceback found in fuzzing. --- netlib/tcp.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 46c28cd9..6b7540aa 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -253,7 +253,10 @@ class _Connection(object): # Closing the socket is not our task, therefore we don't call close then. if type(self.connection) != SSL.Connection: if not getattr(self.wfile, "closed", False): - self.wfile.flush() + try: + self.wfile.flush() + except NetLibDisconnect: + pass self.wfile.close() self.rfile.close() -- cgit v1.2.3 From 3b468849e6de501b8d04c2f8c043dd2960387dae Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 11 Nov 2014 14:02:13 +1300 Subject: Update pathod version number in requirements --- test/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/requirements.txt b/test/requirements.txt index 89e4aa0a..b13080b4 100644 --- a/test/requirements.txt +++ b/test/requirements.txt @@ -2,4 +2,4 @@ mock>=1.0.1 nose>=1.3.0 nose-cov>=1.6 coveralls>=0.4.1 -pathod>=0.10 \ No newline at end of file +pathod>=0.11 -- cgit v1.2.3 From 60584387ff860befe38ada5ec9d35f3c529d0238 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 11 Nov 2014 12:26:20 +0100 Subject: be more explicit about requirements --- netlib/version.py | 4 ++++ setup.py | 2 +- test/requirements.txt | 5 ----- 3 files changed, 5 insertions(+), 6 deletions(-) delete mode 100644 test/requirements.txt diff --git a/netlib/version.py b/netlib/version.py index 913f753a..15a8edf9 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -5,3 +5,7 @@ VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" NAMEVERSION = NAME + " " + VERSION + +NEXT_MINORVERSION = list(IVERSION) +NEXT_MINORVERSION[1] += 1 +NEXT_MINORVERSION = ".".join(str(i) for i in NEXT_MINORVERSION[:2]) \ No newline at end of file diff --git a/setup.py b/setup.py index ad3073fc..f6f8907c 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ setup( "nose>=1.3.0", "nose-cov>=1.6", "coveralls>=0.4.1", - "pathod>=0.%s" % version.MINORVERSION + "pathod>=%s, <%s" % (version.MINORVERSION, version.NEXT_MINORVERSION) ] } ) \ No newline at end of file diff --git a/test/requirements.txt b/test/requirements.txt deleted file mode 100644 index b13080b4..00000000 --- a/test/requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -mock>=1.0.1 -nose>=1.3.0 -nose-cov>=1.6 -coveralls>=0.4.1 -pathod>=0.11 -- cgit v1.2.3 From c56e7a90d886d7169a75246de062f0f90028ae6c Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 15 Nov 2014 12:31:13 +1300 Subject: Fix tracebacks in connection finish --- netlib/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 6b7540aa..1c3bf230 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -255,10 +255,10 @@ class _Connection(object): if not getattr(self.wfile, "closed", False): try: self.wfile.flush() + self.wfile.close() except NetLibDisconnect: pass - self.wfile.close() self.rfile.close() else: try: -- cgit v1.2.3 From 7098c90a6dceddda20de4d7a7dabf836247a38af Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 15 Nov 2014 12:45:06 +1300 Subject: Bump version to 0.11.1 --- netlib/version.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netlib/version.py b/netlib/version.py index 15a8edf9..f67d06b3 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 11) +IVERSION = (0, 11, 1) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" @@ -8,4 +8,4 @@ NAMEVERSION = NAME + " " + VERSION NEXT_MINORVERSION = list(IVERSION) NEXT_MINORVERSION[1] += 1 -NEXT_MINORVERSION = ".".join(str(i) for i in NEXT_MINORVERSION[:2]) \ No newline at end of file +NEXT_MINORVERSION = ".".join(str(i) for i in NEXT_MINORVERSION[:2]) -- cgit v1.2.3 From 438c1fbc7dddcbddd234db3806a4d6b5770d9904 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 15 Dec 2014 12:32:36 +0100 Subject: TCPClient: Use TLS1.1+ where available, BaseHandler: disable SSLv2 --- netlib/tcp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 1c3bf230..7010eef0 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -16,6 +16,8 @@ SSLv2_METHOD = SSL.SSLv2_METHOD SSLv3_METHOD = SSL.SSLv3_METHOD SSLv23_METHOD = SSL.SSLv23_METHOD TLSv1_METHOD = SSL.TLSv1_METHOD +OP_NO_SSLv2 = SSL.OP_NO_SSLv2 +OP_NO_SSLv3 = SSL.OP_NO_SSLv3 class NetLibError(Exception): pass @@ -288,7 +290,7 @@ class TCPClient(_Connection): self.ssl_established = False self.sni = None - def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None, cipher_list=None): + def convert_to_ssl(self, cert=None, sni=None, method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), cipher_list=None): """ cert: Path to a file containing both client cert and private key. @@ -362,7 +364,7 @@ class BaseHandler(_Connection): self.ssl_established = False self.clientcert = None - def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=None, + def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=OP_NO_SSLv2, handle_sni=None, request_client_cert=None, cipher_list=None, dhparams=None, chain_file=None): """ -- cgit v1.2.3 From 3c919631d40cef69dacd166dabafc238a753edc8 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 28 Dec 2014 22:46:19 +1300 Subject: Bump version --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/version.py b/netlib/version.py index f67d06b3..826c66fe 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 11, 1) +IVERSION = (0, 11, 2) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From b6af3fddf4ee2c45a722bd5087c86cee59b8cfa0 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 7 Feb 2015 01:43:25 +0100 Subject: pypy support, faster travis builds --- .travis.yml | 11 ++++++++--- README.mkd | 6 +++++- setup.py | 2 ++ 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index 7c4dca92..a2e8d5ff 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,9 +1,11 @@ language: python +sudo: false python: - "2.7" + - pypy # command to install dependencies, e.g. pip install -r requirements.txt --use-mirrors -install: - - "pip install --src .. -r requirements.txt" +install: + - "pip install --src . -r requirements.txt" # command to run tests, e.g. python setup.py test script: - "nosetests --with-cov --cov-report term-missing" @@ -15,4 +17,7 @@ notifications: - "irc.oftc.net#mitmproxy" on_success: change on_failure: always - +cache: + directories: + - /home/travis/virtualenv/python2.7.9/lib/python2.7/site-packages + - /home/travis/virtualenv/pypy-2.5.0/site-packages \ No newline at end of file diff --git a/README.mkd b/README.mkd index 7c96d396..79e7f803 100644 --- a/README.mkd +++ b/README.mkd @@ -1,4 +1,8 @@ -[![Build Status](https://travis-ci.org/mitmproxy/netlib.png?branch=master)](https://travis-ci.org/mitmproxy/netlib) [![Coverage Status](https://coveralls.io/repos/mitmproxy/netlib/badge.png?branch=master)](https://coveralls.io/r/mitmproxy/netlib) +[![Build Status](https://travis-ci.org/mitmproxy/netlib.svg?branch=master)](https://travis-ci.org/mitmproxy/netlib) +[![Coverage Status](https://coveralls.io/repos/mitmproxy/netlib/badge.svg?branch=master)](https://coveralls.io/r/mitmproxy/netlib) +[![Latest Version](https://pypip.in/version/netlib/badge.svg?style=flat)](https://pypi.python.org/pypi/netlib) +[![Supported Python versions](https://pypip.in/py_versions/netlib/badge.svg?style=flat)](https://pypi.python.org/pypi/netlib) +[![Supported Python implementations](https://pypip.in/implementation/netlib/badge.svg?style=flat)](https://pypi.python.org/pypi/netlib) Netlib is a collection of network utility classes, used by the pathod and mitmproxy projects. It differs from other projects in some fundamental diff --git a/setup.py b/setup.py index f6f8907c..8e3d51b8 100644 --- a/setup.py +++ b/setup.py @@ -26,6 +26,8 @@ setup( "Operating System :: POSIX", "Programming Language :: Python", "Programming Language :: Python :: 2", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", "Topic :: Internet", "Topic :: Internet :: WWW/HTTP", "Topic :: Internet :: WWW/HTTP :: HTTP Servers", -- cgit v1.2.3 From c9de3e770b8b8567cc3c233e9d0f82fd7a47e634 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 17 Feb 2015 11:59:07 +1300 Subject: By popular demand, bump dummy cert expiry to 5 years fixes #52 --- netlib/certutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index af6177d8..948eb85d 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -61,7 +61,7 @@ def dummy_cert(privkey, cacert, commonname, sans): cert = OpenSSL.crypto.X509() cert.gmtime_adj_notBefore(-3600*48) - cert.gmtime_adj_notAfter(60 * 60 * 24 * 30) + cert.gmtime_adj_notAfter(60 * 60 * 24 * 30 * 365 * 5) cert.set_issuer(cacert.get_subject()) cert.get_subject().CN = commonname cert.set_serial_number(int(time.time()*10000)) -- cgit v1.2.3 From 7e5bb74e7211dbe06b33847475854f54c56aa8d5 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 17 Feb 2015 12:03:52 +1300 Subject: 5 years is enough... --- netlib/certutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 948eb85d..3eb9846d 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -61,7 +61,7 @@ def dummy_cert(privkey, cacert, commonname, sans): cert = OpenSSL.crypto.X509() cert.gmtime_adj_notBefore(-3600*48) - cert.gmtime_adj_notAfter(60 * 60 * 24 * 30 * 365 * 5) + cert.gmtime_adj_notAfter(60 * 60 * 24 * 365 * 5) cert.set_issuer(cacert.get_subject()) cert.get_subject().CN = commonname cert.set_serial_number(int(time.time()*10000)) -- cgit v1.2.3 From 2a2402dfffc9f1a51869170793673eaf49207d0f Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 17 Feb 2015 00:10:10 +0100 Subject: ...two years is not enough. --- netlib/certutils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 3eb9846d..5d8a56b8 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -6,7 +6,7 @@ from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL -DEFAULT_EXP = 62208000 # =24 * 60 * 60 * 720 +DEFAULT_EXP = 157680000 # = 24 * 60 * 60 * 365 * 5 # Generated with "openssl dhparam". It's too slow to generate this on startup. DEFAULT_DHPARAM = """-----BEGIN DH PARAMETERS----- MIGHAoGBAOdPzMbYgoYfO3YBYauCLRlE8X1XypTiAjoeCFD0qWRx8YUsZ6Sj20W5 @@ -61,7 +61,7 @@ def dummy_cert(privkey, cacert, commonname, sans): cert = OpenSSL.crypto.X509() cert.gmtime_adj_notBefore(-3600*48) - cert.gmtime_adj_notAfter(60 * 60 * 24 * 365 * 5) + cert.gmtime_adj_notAfter(DEFAULT_EXP) cert.set_issuer(cacert.get_subject()) cert.get_subject().CN = commonname cert.set_serial_number(int(time.time()*10000)) -- cgit v1.2.3 From 224f737646a3f9d0d6540a295524806df7ed1943 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 27 Feb 2015 16:59:29 +0100 Subject: add option to log ssl keys refs mitmproxy/mitmproxy#475 --- netlib/tcp.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/netlib/tcp.py b/netlib/tcp.py index 7010eef0..c6e0075e 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,4 +1,5 @@ from __future__ import (absolute_import, print_function, division) +import os import select import socket import sys @@ -26,6 +27,37 @@ class NetLibTimeout(NetLibError): pass class NetLibSSLError(NetLibError): pass +class SSLKeyLogger(object): + def __init__(self, filename): + self.filename = filename + self.f = None + self.lock = threading.Lock() + + __name__ = "SSLKeyLogger" # required for functools.wraps, which pyOpenSSL uses. + + def __call__(self, connection, where, ret): + if where == SSL.SSL_CB_HANDSHAKE_DONE and ret == 1: + with self.lock: + if not self.f: + self.f = open(self.filename, "ab") + self.f.write("\r\n") + client_random = connection.client_random().encode("hex") + masterkey = connection.master_key().encode("hex") + self.f.write("CLIENT_RANDOM {} {}\r\n".format(client_random, masterkey)) + self.f.flush() + + def close(self): + with self.lock: + if self.f: + self.f.close() + +_logfile = os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE") +if _logfile: + log_ssl_key = SSLKeyLogger(_logfile) +else: + log_ssl_key = False + + class _FileLike: BLOCKSIZE = 1024 * 32 def __init__(self, o): @@ -314,6 +346,8 @@ class TCPClient(_Connection): if sni: self.sni = sni self.connection.set_tlsext_host_name(sni) + if log_ssl_key: + context.set_info_callback(log_ssl_key) self.connection.set_connect_state() try: self.connection.do_handshake() @@ -418,6 +452,8 @@ class BaseHandler(_Connection): # Return true to prevent cert verification error return True ctx.set_verify(SSL.VERIFY_PEER, ver) + if log_ssl_key: + ctx.set_info_callback(log_ssl_key) return ctx def convert_to_ssl(self, cert, key, **sslctx_kwargs): -- cgit v1.2.3 From 63fb43369029d33ce77cb2ce1df397e99494562c Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 27 Feb 2015 20:40:17 +0100 Subject: fix #53 --- netlib/odict.py | 2 +- test/test_odict.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/netlib/odict.py b/netlib/odict.py index 61448e6d..f97f074b 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -166,7 +166,7 @@ class ODict: return [tuple(i) for i in self.lst] def load_state(self, state): - self.list = [list(i) for i in state] + self.lst = [list(i) for i in state] @classmethod def from_state(klass, state): diff --git a/test/test_odict.py b/test/test_odict.py index a682d7eb..d90bc6e5 100644 --- a/test/test_odict.py +++ b/test/test_odict.py @@ -31,7 +31,9 @@ class TestODict: state = self.od.get_state() nd = odict.ODict.from_state(state) assert nd == self.od - nd.load_state(state) + b = odict.ODict() + b.load_state(state) + assert b == self.od def test_dictToHeader2(self): self.od["one"] = ["uno"] @@ -78,6 +80,7 @@ class TestODict: self.od.add("foo", 2) self.od.add("bar", 3) assert self.od == self.od.copy() + assert not self.od != self.od.copy() def test_del(self): self.od.add("foo", 1) -- cgit v1.2.3 From da1eb94ccd36b31ea7e05c6a4e01dd5a6cf20376 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 27 Feb 2015 22:02:52 +0100 Subject: 100% test coverage :tada: --- netlib/tcp.py | 26 +++++++------ netlib/test.py | 3 +- test/test_certutils.py | 37 ++++++++++-------- test/test_http.py | 6 +++ test/test_socks.py | 14 ++++++- test/test_tcp.py | 103 ++++++++++++++++++++++++++++++++++++------------- 6 files changed, 131 insertions(+), 58 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index c6e0075e..7f98b4f9 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -39,6 +39,9 @@ class SSLKeyLogger(object): if where == SSL.SSL_CB_HANDSHAKE_DONE and ret == 1: with self.lock: if not self.f: + d = os.path.dirname(self.filename) + if not os.path.isdir(d): + os.makedirs(d) self.f = open(self.filename, "ab") self.f.write("\r\n") client_random = connection.client_random().encode("hex") @@ -51,11 +54,13 @@ class SSLKeyLogger(object): if self.f: self.f.close() -_logfile = os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE") -if _logfile: - log_ssl_key = SSLKeyLogger(_logfile) -else: - log_ssl_key = False + @staticmethod + def create_logfun(filename): + if filename: + return SSLKeyLogger(filename) + return False + +log_ssl_key = SSLKeyLogger.create_logfun(os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE")) class _FileLike: @@ -161,9 +166,9 @@ class Reader(_FileLike): except SSL.SysCallError as e: if e.args == (-1, 'Unexpected EOF'): break - raise NetLibDisconnect - except SSL.Error, v: - raise NetLibSSLError(v.message) + raise NetLibSSLError(e.message) + except SSL.Error as e: + raise NetLibSSLError(e.message) self.first_byte_timestamp = self.first_byte_timestamp or time.time() if not data: break @@ -179,10 +184,7 @@ class Reader(_FileLike): while True: if size is not None and bytes_read >= size: break - try: - ch = self.read(1) - except NetLibDisconnect: - break + ch = self.read(1) bytes_read += 1 if not ch: break diff --git a/netlib/test.py b/netlib/test.py index fb468907..3a23ba8f 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -75,7 +75,8 @@ class TServer(tcp.TCPServer): handle_sni = getattr(h, "handle_sni", None), request_client_cert = self.ssl["request_client_cert"], cipher_list = self.ssl.get("cipher_list", None), - dhparams = self.ssl.get("dhparams", None) + dhparams = self.ssl.get("dhparams", None), + chain_file = self.ssl.get("chain_file", None) ) h.handle() h.finish() diff --git a/test/test_certutils.py b/test/test_certutils.py index 59c9dcd5..c96c5087 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -80,7 +80,7 @@ class TestCertStore: ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test") assert not ca1.default_ca.get_serial_number() == ca2.default_ca.get_serial_number() - dc = ca2.get_cert("foo.com", []) + dc = ca2.get_cert("foo.com", ["sans.example.com"]) dcp = os.path.join(d, "dc") f = open(dcp, "wb") f.write(dc[0].to_pem()) @@ -118,31 +118,34 @@ class TestSSLCert: def test_simple(self): with open(tutils.test_data.path("data/text_cert"), "rb") as f: d = f.read() - c = certutils.SSLCert.from_pem(d) - assert c.cn == "google.com" - assert len(c.altnames) == 436 + c1 = certutils.SSLCert.from_pem(d) + assert c1.cn == "google.com" + assert len(c1.altnames) == 436 with open(tutils.test_data.path("data/text_cert_2"), "rb") as f: d = f.read() - c = certutils.SSLCert.from_pem(d) - assert c.cn == "www.inode.co.nz" - assert len(c.altnames) == 2 - assert c.digest("sha1") - assert c.notbefore - assert c.notafter - assert c.subject - assert c.keyinfo == ("RSA", 2048) - assert c.serial - assert c.issuer - assert c.to_pem() - c.has_expired + c2 = certutils.SSLCert.from_pem(d) + assert c2.cn == "www.inode.co.nz" + assert len(c2.altnames) == 2 + assert c2.digest("sha1") + assert c2.notbefore + assert c2.notafter + assert c2.subject + assert c2.keyinfo == ("RSA", 2048) + assert c2.serial + assert c2.issuer + assert c2.to_pem() + assert c2.has_expired is not None + + assert not c1 == c2 + assert c1 != c2 def test_err_broken_sans(self): with open(tutils.test_data.path("data/text_cert_weird1"), "rb") as f: d = f.read() c = certutils.SSLCert.from_pem(d) # This breaks unless we ignore a decoding error. - c.altnames + assert c.altnames is not None def test_der(self): with open(tutils.test_data.path("data/dercert"), "rb") as f: diff --git a/test/test_http.py b/test/test_http.py index e3e92a1e..fed60946 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -325,6 +325,12 @@ def test_parse_url(): assert po == 80 assert pa == "/bar" + s, h, po, pa = http.parse_url("http://user:pass@foo/bar") + assert s == "http" + assert h == "foo" + assert po == 80 + assert pa == "/bar" + s, h, po, pa = http.parse_url("http://foo") assert pa == "/" diff --git a/test/test_socks.py b/test/test_socks.py index 740fdb9c..aa4f9c11 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -1,5 +1,6 @@ from cStringIO import StringIO import socket +import mock from nose.plugins.skip import SkipTest from netlib import socks, tcp import tutils @@ -81,4 +82,15 @@ def test_message_unknown_atyp(): tutils.raises(socks.SocksError, socks.Message.from_file, raw) m = socks.Message(5, 1, 0x02, tcp.Address(("example.com", 5050))) - tutils.raises(socks.SocksError, m.to_file, StringIO()) \ No newline at end of file + tutils.raises(socks.SocksError, m.to_file, StringIO()) + +def test_read(): + cs = StringIO("1234") + assert socks._read(cs, 3) == "123" + + cs = StringIO("123") + tutils.raises(socks.SocksError, socks._read, cs, 4) + + cs = mock.Mock() + cs.read = mock.Mock(side_effect=socket.error) + tutils.raises(socks.SocksError, socks._read, cs, 4) \ No newline at end of file diff --git a/test/test_tcp.py b/test/test_tcp.py index ce96f16f..21fea23e 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -1,4 +1,5 @@ import cStringIO, Queue, time, socket, random +import os from netlib import tcp, certutils, test, certffi import mock import tutils @@ -71,30 +72,6 @@ class TestServerIPv6(test.ServerTestBase): assert c.rfile.readline() == testval -class FinishFailHandler(tcp.BaseHandler): - def handle(self): - v = self.rfile.readline() - self.wfile.write(v) - self.wfile.flush() - self.wfile.close() - self.rfile.close() - self.close = mock.MagicMock(side_effect=socket.error) - - -class TestFinishFail(test.ServerTestBase): - """ - This tests a difficult-to-trigger exception in the .finish() method of - the handler. - """ - handler = FinishFailHandler - def test_disconnect_in_finish(self): - testval = "echo!\n" - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.wfile.write("foo\n") - c.wfile.flush() - c.rfile.read(4) - class TestDisconnect(test.ServerTestBase): handler = EchoHandler def test_echo(self): @@ -111,6 +88,20 @@ class HardDisconnectHandler(tcp.BaseHandler): self.connection.close() +class TestFinishFail(test.ServerTestBase): + """ + This tests a difficult-to-trigger exception in the .finish() method of + the handler. + """ + handler = EchoHandler + def test_disconnect_in_finish(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.wfile.write("foo\n") + c.wfile.flush = mock.Mock(side_effect=tcp.NetLibDisconnect) + c.finish() + + class TestServerSSL(test.ServerTestBase): handler = EchoHandler ssl = dict( @@ -118,7 +109,8 @@ class TestServerSSL(test.ServerTestBase): key = tutils.test_data.path("data/server.key"), request_client_cert = False, v3_only = False, - cipher_list = "AES256-SHA" + cipher_list = "AES256-SHA", + chain_file=tutils.test_data.path("data/server.crt") ) def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) @@ -150,7 +142,7 @@ class TestSSLv3Only(test.ServerTestBase): def test_failure(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com", method=tcp.TLSv1_METHOD) + tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com") class TestSSLClientCert(test.ServerTestBase): @@ -385,6 +377,11 @@ class TestDHParams(test.ServerTestBase): ret = c.get_current_cipher() assert ret[0] == "DHE-RSA-AES256-SHA" + def test_create_dhparams(self): + with tutils.tmpdir() as d: + filename = os.path.join(d, "dhparam.pem") + certutils.CertStore.load_dhparam(filename) + assert os.path.exists(filename) class TestPrivkeyGen(test.ServerTestBase): @@ -527,12 +524,22 @@ class TestFileLike: assert s.first_byte_timestamp == expected def test_read_ssl_error(self): - s = cStringIO.StringIO("foobar\nfoobar") s = mock.MagicMock() s.read = mock.MagicMock(side_effect=SSL.Error()) s = tcp.Reader(s) tutils.raises(tcp.NetLibSSLError, s.read, 1) + def test_read_syscall_ssl_error(self): + s = mock.MagicMock() + s.read = mock.MagicMock(side_effect=SSL.SysCallError()) + s = tcp.Reader(s) + tutils.raises(tcp.NetLibSSLError, s.read, 1) + + def test_reader_readline_disconnect(self): + o = mock.MagicMock() + o.read = mock.MagicMock(side_effect=socket.error) + s = tcp.Reader(o) + tutils.raises(tcp.NetLibDisconnect, s.readline, 10) class TestAddress: def test_simple(self): @@ -542,3 +549,45 @@ class TestAddress: assert not a == b c = tcp.Address("localhost", True) assert a == c + assert not a != c + assert repr(a) + + +class TestServer(test.ServerTestBase): + handler = EchoHandler + def test_echo(self): + testval = "echo!\n" + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.wfile.write(testval) + c.wfile.flush() + assert c.rfile.readline() == testval + +class TestSSLKeyLogger(test.ServerTestBase): + handler = EchoHandler + ssl = dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + request_client_cert = False, + v3_only = False, + cipher_list = "AES256-SHA" + ) + + def test_log(self): + _logfun = tcp.log_ssl_key + + with tutils.tmpdir() as d: + logfile = os.path.join(d, "foo", "bar", "logfile") + tcp.log_ssl_key = tcp.SSLKeyLogger(logfile) + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + tcp.log_ssl_key.close() + with open(logfile, "rb") as f: + assert f.read().count("CLIENT_RANDOM") == 2 + + tcp.log_ssl_key = _logfun + + def test_create_logfun(self): + assert isinstance(tcp.SSLKeyLogger.create_logfun("test"), tcp.SSLKeyLogger) + assert not tcp.SSLKeyLogger.create_logfun(False) \ No newline at end of file -- cgit v1.2.3 From d71f3b68fda688fec358b59fdcfaaa7031b3b80d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 27 Feb 2015 22:27:23 +0100 Subject: make tests more robust, fix coveralls --- .travis.yml | 4 +++- netlib/test.py | 2 +- test/test_tcp.py | 2 ++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index a2e8d5ff..aac6b272 100644 --- a/.travis.yml +++ b/.travis.yml @@ -20,4 +20,6 @@ notifications: cache: directories: - /home/travis/virtualenv/python2.7.9/lib/python2.7/site-packages - - /home/travis/virtualenv/pypy-2.5.0/site-packages \ No newline at end of file + - /home/travis/virtualenv/python2.7.9/bin + - /home/travis/virtualenv/pypy-2.5.0/site-packages + - /home/travis/virtualenv/pypy-2.5.0/bin \ No newline at end of file diff --git a/netlib/test.py b/netlib/test.py index 3a23ba8f..db30c0e6 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -15,7 +15,7 @@ class ServerThread(threading.Thread): self.server.shutdown() -class ServerTestBase: +class ServerTestBase(object): ssl = None handler = None addr = ("localhost", 0) diff --git a/test/test_tcp.py b/test/test_tcp.py index 21fea23e..2216e0d4 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -563,6 +563,7 @@ class TestServer(test.ServerTestBase): c.wfile.flush() assert c.rfile.readline() == testval + class TestSSLKeyLogger(test.ServerTestBase): handler = EchoHandler ssl = dict( @@ -582,6 +583,7 @@ class TestSSLKeyLogger(test.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() + c.finish() tcp.log_ssl_key.close() with open(logfile, "rb") as f: assert f.read().count("CLIENT_RANDOM") == 2 -- cgit v1.2.3 From 24a3dd59fec825afeff8cb62426f963653e8476a Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 27 Feb 2015 22:34:36 +0100 Subject: try harder to fix race condition in tests --- test/test_tcp.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/test_tcp.py b/test/test_tcp.py index 2216e0d4..b93b1e0a 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -575,15 +575,21 @@ class TestSSLKeyLogger(test.ServerTestBase): ) def test_log(self): + testval = "echo!\n" _logfun = tcp.log_ssl_key with tutils.tmpdir() as d: logfile = os.path.join(d, "foo", "bar", "logfile") tcp.log_ssl_key = tcp.SSLKeyLogger(logfile) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() + c.wfile.write(testval) + c.wfile.flush() + assert c.rfile.readline() == testval c.finish() + tcp.log_ssl_key.close() with open(logfile, "rb") as f: assert f.read().count("CLIENT_RANDOM") == 2 -- cgit v1.2.3 From dbadc1b61327d06bb176d0465ad5831a619126be Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 7 Mar 2015 01:22:02 +0100 Subject: clean up cert handling, fix mitmproxy/mitmproxy#472 --- netlib/tcp.py | 140 +++++++++++++++++++++++++++++++++++---------------------- test/tutils.py | 4 +- 2 files changed, 88 insertions(+), 56 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 7f98b4f9..ba4f008c 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -302,6 +302,43 @@ class _Connection(object): except SSL.Error: pass + """ + Creates an SSL Context. + """ + def _create_ssl_context(self, + method=SSLv23_METHOD, + options=(OP_NO_SSLv2 | OP_NO_SSLv3), + cipher_list=None + ): + """ + :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD + :param options: A bit field consisting of OpenSSL.SSL.OP_* values + :param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html + :rtype : SSL.Context + """ + context = SSL.Context(method) + # Options (NO_SSLv2/3) + if options is not None: + context.set_options(options) + + # Workaround for + # https://github.com/pyca/pyopenssl/issues/190 + # https://github.com/mitmproxy/mitmproxy/issues/472 + context.set_mode(SSL._lib.SSL_MODE_AUTO_RETRY) # Options already set before are not cleared. + + # Cipher List + if cipher_list: + try: + context.set_cipher_list(cipher_list) + except SSL.Error, v: + raise NetLibError("SSL cipher specification error: %s"%str(v)) + + # SSLKEYLOGFILE + if log_ssl_key: + context.set_info_callback(log_ssl_key) + + return context + class TCPClient(_Connection): rbufsize = -1 @@ -324,32 +361,28 @@ class TCPClient(_Connection): self.ssl_established = False self.sni = None - def convert_to_ssl(self, cert=None, sni=None, method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), cipher_list=None): - """ - cert: Path to a file containing both client cert and private key. - - options: A bit field consisting of OpenSSL.SSL.OP_* values - """ - context = SSL.Context(method) - if cipher_list: - try: - context.set_cipher_list(cipher_list) - except SSL.Error, v: - raise NetLibError("SSL cipher specification error: %s"%str(v)) - if options is not None: - context.set_options(options) + def create_ssl_context(self, cert=None, **sslctx_kwargs): + context = self._create_ssl_context(**sslctx_kwargs) + # Client Certs if cert: try: context.use_privatekey_file(cert) context.use_certificate_file(cert) except SSL.Error, v: raise NetLibError("SSL client certificate error: %s"%str(v)) + return context + + def convert_to_ssl(self, sni=None, **sslctx_kwargs): + """ + cert: Path to a file containing both client cert and private key. + + options: A bit field consisting of OpenSSL.SSL.OP_* values + """ + context = self.create_ssl_context(**sslctx_kwargs) self.connection = SSL.Connection(context, self.connection) if sni: self.sni = sni self.connection.set_tlsext_host_name(sni) - if log_ssl_key: - context.set_info_callback(log_ssl_key) self.connection.set_connect_state() try: self.connection.do_handshake() @@ -400,21 +433,21 @@ class BaseHandler(_Connection): self.ssl_established = False self.clientcert = None - def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=OP_NO_SSLv2, - handle_sni=None, request_client_cert=None, cipher_list=None, - dhparams=None, chain_file=None): + def create_ssl_context(self, + cert, key, + handle_sni=None, + request_client_cert=None, + chain_file=None, + dhparams=None, + **sslctx_kwargs): """ cert: A certutils.SSLCert object. - method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD - handle_sni: SNI handler, should take a connection object. Server name can be retrieved like this: connection.get_servername() - options: A bit field consisting of OpenSSL.SSL.OP_* values - And you can specify the connection keys as follows: new_context = Context(TLSv1_METHOD) @@ -431,40 +464,38 @@ class BaseHandler(_Connection): we may be able to make the proper behaviour the default again, but until then we're conservative. """ - ctx = SSL.Context(method) - if not options is None: - ctx.set_options(options) - if chain_file: - ctx.load_verify_locations(chain_file) - if cipher_list: - try: - ctx.set_cipher_list(cipher_list) - except SSL.Error, v: - raise NetLibError("SSL cipher specification error: %s"%str(v)) + context = self._create_ssl_context(**sslctx_kwargs) + + context.use_privatekey(key) + context.use_certificate(cert.x509) + if handle_sni: # SNI callback happens during do_handshake() - ctx.set_tlsext_servername_callback(handle_sni) - ctx.use_privatekey(key) - ctx.use_certificate(cert.x509) - if dhparams: - SSL._lib.SSL_CTX_set_tmp_dh(ctx._context, dhparams) + context.set_tlsext_servername_callback(handle_sni) + if request_client_cert: - def ver(*args): - self.clientcert = certutils.SSLCert(args[1]) + def save_cert(conn, cert, errno, depth, preverify_ok): + self.clientcert = certutils.SSLCert(cert) # Return true to prevent cert verification error return True - ctx.set_verify(SSL.VERIFY_PEER, ver) - if log_ssl_key: - ctx.set_info_callback(log_ssl_key) - return ctx + context.set_verify(SSL.VERIFY_PEER, save_cert) + + # Cert Verify + if chain_file: + context.load_verify_locations(chain_file) + + if dhparams: + SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams) + + return context def convert_to_ssl(self, cert, key, **sslctx_kwargs): """ Convert connection to SSL. For a list of parameters, see BaseHandler._create_ssl_context(...) """ - ctx = self._create_ssl_context(cert, key, **sslctx_kwargs) - self.connection = SSL.Connection(ctx, self.connection) + context = self.create_ssl_context(cert, key, **sslctx_kwargs) + self.connection = SSL.Connection(context, self.connection) self.connection.set_accept_state() try: self.connection.do_handshake() @@ -474,7 +505,7 @@ class BaseHandler(_Connection): self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) - def handle(self): # pragma: no cover + def handle(self): # pragma: no cover raise NotImplementedError def settimeout(self, n): @@ -483,6 +514,7 @@ class BaseHandler(_Connection): class TCPServer(object): request_queue_size = 20 + def __init__(self, address): self.address = Address.wrap(address) self.__is_shut_down = threading.Event() @@ -508,7 +540,7 @@ class TCPServer(object): while not self.__shutdown_request: try: r, w, e = select.select([self.socket], [], [], poll_interval) - except select.error, ex: # pragma: no cover + except select.error as ex: # pragma: no cover if ex[0] == EINTR: continue else: @@ -516,12 +548,12 @@ class TCPServer(object): if self.socket in r: connection, client_address = self.socket.accept() t = threading.Thread( - target = self.connection_thread, - args = (connection, client_address), - name = "ConnectionThread (%s:%s -> %s:%s)" % - (client_address[0], client_address[1], - self.address.host, self.address.port) - ) + target=self.connection_thread, + args=(connection, client_address), + name="ConnectionThread (%s:%s -> %s:%s)" % + (client_address[0], client_address[1], + self.address.host, self.address.port) + ) t.setDaemon(1) t.start() finally: diff --git a/test/tutils.py b/test/tutils.py index c8e06b96..ea30f59c 100644 --- a/test/tutils.py +++ b/test/tutils.py @@ -31,7 +31,7 @@ def raises(exc, obj, *args, **kwargs): :kwargs Arguments to be passed to the callable. """ try: - apply(obj, args, kwargs) + ret = apply(obj, args, kwargs) except Exception, v: if isinstance(exc, basestring): if exc.lower() in str(v).lower(): @@ -51,6 +51,6 @@ def raises(exc, obj, *args, **kwargs): exc.__name__, v.__class__.__name__, str(v) ) ) - raise AssertionError("No exception raised.") + raise AssertionError("No exception raised. Return value: {}".format(ret)) test_data = utils.Data(__name__) -- cgit v1.2.3 From d5eff70b6e7acb3bd60a5e6f8233cf4936a5d606 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 7 Mar 2015 01:31:31 +0100 Subject: fix tests on Windows --- netlib/tcp.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/netlib/tcp.py b/netlib/tcp.py index ba4f008c..b2f11851 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -7,6 +7,7 @@ import threading import time import traceback from OpenSSL import SSL +import OpenSSL from . import certutils @@ -301,6 +302,10 @@ class _Connection(object): self.connection.shutdown() except SSL.Error: pass + except KeyError as e: + # Workaround for https://github.com/pyca/pyopenssl/pull/183 + if OpenSSL.__version__ != "0.14": + raise e """ Creates an SSL Context. -- cgit v1.2.3 From 6fbe3006afa46c4c5f19e5c52b66e6e73a07f819 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 9 Apr 2015 00:12:41 +0200 Subject: fail gracefully if we cannot start a new thread --- netlib/tcp.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index b2f11851..45c60fd8 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -560,7 +560,11 @@ class TCPServer(object): self.address.host, self.address.port) ) t.setDaemon(1) - t.start() + try: + t.start() + except threading.ThreadError: + self.handle_error(connection, Address(client_address)) + connection.close() finally: self.__shutdown_request = False self.__is_shut_down.set() -- cgit v1.2.3 From 7f7ccd3a1865e8e73f3d1813182d01c607d6e501 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 9 Apr 2015 00:57:37 +0200 Subject: 100% test coverage --- netlib/tcp.py | 2 +- test/test_tcp.py | 23 +++++++++++------------ 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 45c60fd8..20e7d45f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -302,7 +302,7 @@ class _Connection(object): self.connection.shutdown() except SSL.Error: pass - except KeyError as e: + except KeyError as e: # pragma: no cover # Workaround for https://github.com/pyca/pyopenssl/pull/183 if OpenSSL.__version__ != "0.14": raise e diff --git a/test/test_tcp.py b/test/test_tcp.py index b93b1e0a..4dbdd780 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -1,6 +1,7 @@ import cStringIO, Queue, time, socket, random import os from netlib import tcp, certutils, test, certffi +import threading import mock import tutils from OpenSSL import SSL @@ -39,6 +40,15 @@ class TestServer(test.ServerTestBase): c.wfile.flush() assert c.rfile.readline() == testval + def test_thread_start_error(self): + with mock.patch.object(threading.Thread, "start", side_effect=threading.ThreadError("nonewthread")) as m: + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + assert not c.rfile.read(1) + assert m.called + assert "nonewthread" in self.q.get_nowait() + self.test_echo() + class TestServerBind(test.ServerTestBase): class handler(tcp.BaseHandler): @@ -72,7 +82,7 @@ class TestServerIPv6(test.ServerTestBase): assert c.rfile.readline() == testval -class TestDisconnect(test.ServerTestBase): +class TestEcho(test.ServerTestBase): handler = EchoHandler def test_echo(self): testval = "echo!\n" @@ -553,17 +563,6 @@ class TestAddress: assert repr(a) -class TestServer(test.ServerTestBase): - handler = EchoHandler - def test_echo(self): - testval = "echo!\n" - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.wfile.write(testval) - c.wfile.flush() - assert c.rfile.readline() == testval - - class TestSSLKeyLogger(test.ServerTestBase): handler = EchoHandler ssl = dict( -- cgit v1.2.3 From e58f76aec1db9cc784a3b73c3050d010bb084968 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 9 Apr 2015 02:09:33 +0200 Subject: fix code smell --- netlib/certutils.py | 4 ++-- netlib/http.py | 4 ++-- netlib/http_auth.py | 10 +++++----- netlib/odict.py | 2 +- netlib/tcp.py | 14 +++++++------- netlib/wsgi.py | 8 ++++---- 6 files changed, 21 insertions(+), 21 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 5d8a56b8..f5375c03 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -120,7 +120,7 @@ class CertStoreEntry(object): self.chain_file = chain_file -class CertStore: +class CertStore(object): """ Implements an in-memory certificate store. """ @@ -288,7 +288,7 @@ class _GeneralNames(univ.SequenceOf): sizeSpec = univ.SequenceOf.sizeSpec + constraint.ValueSizeConstraint(1, 1024) -class SSLCert: +class SSLCert(object): def __init__(self, cert): """ Returns a (common name, [subject alternative names]) tuple. diff --git a/netlib/http.py b/netlib/http.py index d2fc6343..26438863 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -333,8 +333,8 @@ def read_response(rfile, request_method, body_size_limit, include_body=True): False ) else: - # if include_body==False then a None content means the body should be - # read separately + # if include_body==False then a None content means the body should be + # read separately content = None return httpversion, code, msg, headers, content diff --git a/netlib/http_auth.py b/netlib/http_auth.py index dca6e2f3..296e094c 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -3,7 +3,7 @@ from argparse import Action, ArgumentTypeError from . import http -class NullProxyAuth(): +class NullProxyAuth(object): """ No proxy auth at all (returns empty challange headers) """ @@ -59,12 +59,12 @@ class BasicProxyAuth(NullProxyAuth): return {self.CHALLENGE_HEADER:'Basic realm="%s"'%self.realm} -class PassMan(): +class PassMan(object): def test(self, username, password_token): return False -class PassManNonAnon: +class PassManNonAnon(PassMan): """ Ensure the user specifies a username, accept any password. """ @@ -74,7 +74,7 @@ class PassManNonAnon: return False -class PassManHtpasswd: +class PassManHtpasswd(PassMan): """ Read usernames and passwords from an htpasswd file """ @@ -89,7 +89,7 @@ class PassManHtpasswd: return bool(self.htpasswd.check_password(username, password_token)) -class PassManSingleUser: +class PassManSingleUser(PassMan): def __init__(self, username, password): self.username, self.password = username, password diff --git a/netlib/odict.py b/netlib/odict.py index f97f074b..7a2f611b 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -11,7 +11,7 @@ def safe_subn(pattern, repl, target, *args, **kwargs): return re.subn(str(pattern), str(repl), target, *args, **kwargs) -class ODict: +class ODict(object): """ A dictionary-like object for managing ordered (key, value) data. """ diff --git a/netlib/tcp.py b/netlib/tcp.py index 20e7d45f..10269aa4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -64,7 +64,7 @@ class SSLKeyLogger(object): log_ssl_key = SSLKeyLogger.create_logfun(os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE")) -class _FileLike: +class _FileLike(object): BLOCKSIZE = 1024 * 32 def __init__(self, o): self.o = o @@ -134,8 +134,8 @@ class Writer(_FileLike): r = self.o.write(v) self.add_log(v[:r]) return r - except (SSL.Error, socket.error), v: - raise NetLibDisconnect(str(v)) + except (SSL.Error, socket.error) as e: + raise NetLibDisconnect(str(e)) class Reader(_FileLike): @@ -546,10 +546,10 @@ class TCPServer(object): try: r, w, e = select.select([self.socket], [], [], poll_interval) except select.error as ex: # pragma: no cover - if ex[0] == EINTR: - continue - else: - raise + if ex[0] == EINTR: + continue + else: + raise if self.socket in r: connection, client_address = self.socket.accept() t = threading.Thread( diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 568b1f9c..bac27d5a 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -3,18 +3,18 @@ import cStringIO, urllib, time, traceback from . import odict, tcp -class ClientConn: +class ClientConn(object): def __init__(self, address): self.address = tcp.Address.wrap(address) -class Flow: +class Flow(object): def __init__(self, address, request): self.client_conn = ClientConn(address) self.request = request -class Request: +class Request(object): def __init__(self, scheme, method, path, headers, content): self.scheme, self.method, self.path = scheme, method, path self.headers, self.content = headers, content @@ -35,7 +35,7 @@ def date_time_string(): return s -class WSGIAdaptor: +class WSGIAdaptor(object): def __init__(self, app, domain, port, sversion): self.app, self.domain, self.port, self.sversion = app, domain, port, sversion -- cgit v1.2.3 From e41e5cbfdd7b778e6f68e86658e95f9e413133cb Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Thu, 9 Apr 2015 19:35:40 -0700 Subject: netlib websockets --- netlib/http.py | 14 ++ netlib/utils.py | 3 + netlib/websockets/__init__.py | 1 + netlib/websockets/implementations.py | 81 ++++++++ netlib/websockets/websockets.py | 368 +++++++++++++++++++++++++++++++++++ test/test_websockets.py | 15 ++ 6 files changed, 482 insertions(+) create mode 100644 netlib/websockets/__init__.py create mode 100644 netlib/websockets/implementations.py create mode 100644 netlib/websockets/websockets.py create mode 100644 test/test_websockets.py diff --git a/netlib/http.py b/netlib/http.py index 26438863..2c72621d 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -29,6 +29,20 @@ def _is_valid_host(host): return None return True +def is_successful_upgrade(request, response): + """ + determines if a client and server successfully agreed to an HTTP protocol upgrade + + https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism + """ + http_switching_protocols_code = 101 + + if request and response: + responseUpgrade = request.headers.get("Upgrade") + requestUpgrade = response.headers.get("Upgrade") + if response.code == http_switching_protocols_code and responseUpgrade == requestUpgrade: + return requestUpgrade[0] if len(requestUpgrade) > 0 else None + return None def parse_url(url): """ diff --git a/netlib/utils.py b/netlib/utils.py index 79077ac6..03a70977 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -8,6 +8,9 @@ def isascii(s): return False return True +# best way to do it in python 2.x +def bytes_to_int(i): + return int(i.encode('hex'), 16) def cleanBin(s, fixspacing=False): """ diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py new file mode 100644 index 00000000..9b4faa33 --- /dev/null +++ b/netlib/websockets/__init__.py @@ -0,0 +1 @@ +from __future__ import (absolute_import, print_function, division) diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py new file mode 100644 index 00000000..78ae5be6 --- /dev/null +++ b/netlib/websockets/implementations.py @@ -0,0 +1,81 @@ +from netlib import tcp +from base64 import b64encode +from StringIO import StringIO +from . import websockets as ws +import struct +import SocketServer +import os + +# Simple websocket client and servers that are used to exercise the functionality in websockets.py +# These are *not* fully RFC6455 compliant + +class WebSocketsEchoHandler(tcp.BaseHandler): + def __init__(self, connection, address, server): + super(WebSocketsEchoHandler, self).__init__(connection, address, server) + self.handshake_done = False + + def handle(self): + while True: + if not self.handshake_done: + self.handshake() + else: + self.read_next_message() + + def read_next_message(self): + decoded = ws.WebSocketsFrame.from_byte_stream(self.rfile.read).decoded_payload + self.on_message(decoded) + + def send_message(self, message): + frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = False) + self.wfile.write(frame.to_bytes()) + self.wfile.flush() + + def handshake(self): + client_hs = ws.read_handshake(self.rfile.read, 1) + key = ws.server_process_handshake(client_hs) + response = ws.create_server_handshake(key) + self.wfile.write(response) + self.wfile.flush() + self.handshake_done = True + + def on_message(self, message): + if message is not None: + self.send_message(message) + + +class WebSocketsClient(tcp.TCPClient): + def __init__(self, address, source_address=None): + super(WebSocketsClient, self).__init__(address, source_address) + self.version = "13" + self.key = b64encode(os.urandom(16)).decode('utf-8') + self.resource = "/" + + def connect(self): + super(WebSocketsClient, self).connect() + + handshake = ws.create_client_handshake( + self.address.host, + self.address.port, + self.key, + self.version, + self.resource + ) + + self.wfile.write(handshake) + self.wfile.flush() + + response = ws.read_handshake(self.rfile.read, 1) + + if not response: + self.close() + + def read_next_message(self): + try: + return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload + except IndexError: + self.close() + + def send_message(self, message): + frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = True) + self.wfile.write(frame.to_bytes()) + self.wfile.flush() diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py new file mode 100644 index 00000000..b796ce39 --- /dev/null +++ b/netlib/websockets/websockets.py @@ -0,0 +1,368 @@ +from __future__ import absolute_import + +from base64 import b64encode +from hashlib import sha1 +from mimetools import Message +from netlib import tcp +from netlib import utils +from StringIO import StringIO +import os +import SocketServer +import struct +import io + +# Colleciton of utility functions that implement small portions of the RFC6455 WebSockets Protocol +# Useful for building WebSocket clients and servers. +# +# Emphassis is on readabilty, simplicity and modularity, not performance or completeness +# +# This is a work in progress and does not yet contain all the utilites need to create fully complient client/servers +# +# Spec: https://tools.ietf.org/html/rfc6455 + +# The magic sha that websocket servers must know to prove they understand RFC6455 +websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' + +class WebSocketFrameValidationException(Exception): + pass + +class WebSocketsFrame(object): + """ + Represents one websockets frame. + Constructor takes human readable forms of the frame components + from_bytes() is also avaliable. + + WebSockets Frame as defined in RFC6455 + + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-------+-+-------------+-------------------------------+ + |F|R|R|R| opcode|M| Payload len | Extended payload length | + |I|S|S|S| (4) |A| (7) | (16/64) | + |N|V|V|V| |S| | (if payload len==126/127) | + | |1|2|3| |K| | | + +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + | Extended payload length continued, if payload len == 127 | + + - - - - - - - - - - - - - - - +-------------------------------+ + | |Masking-key, if MASK set to 1 | + +-------------------------------+-------------------------------+ + | Masking-key (continued) | Payload Data | + +-------------------------------- - - - - - - - - - - - - - - - + + : Payload Data continued ... : + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + | Payload Data continued ... | + +---------------------------------------------------------------+ + """ + def __init__( + self, + fin, # decmial integer 1 or 0 + opcode, # decmial integer 1 - 4 + mask_bit, # decimal integer 1 or 0 + payload_length_code, # decimal integer 1 - 127 + decoded_payload, # bytestring + rsv1 = 0, # decimal integer 1 or 0 + rsv2 = 0, # decimal integer 1 or 0 + rsv3 = 0, # decimal integer 1 or 0 + payload = None, # bytestring + masking_key = None, # 32 bit byte string + actual_payload_length = None, # any decimal integer + use_validation = True # indicates whether or not you care if this frame adheres to the spec + ): + self.fin = fin + self.rsv1 = rsv1 + self.rsv2 = rsv2 + self.rsv3 = rsv3 + self.opcode = opcode + self.mask_bit = mask_bit + self.payload_length_code = payload_length_code + self.masking_key = masking_key + self.payload = payload + self.decoded_payload = decoded_payload + self.actual_payload_length = actual_payload_length + self.use_validation = use_validation + + if self.use_validation: + self.validate_frame() + + @classmethod + def from_bytes(cls, bytestring): + """ + Construct a websocket frame from an in-memory bytestring + to construct a frame from a stream of bytes, use read_frame() directly + """ + self.from_byte_stream(io.BytesIO(bytestring).read) + + @classmethod + def default_frame_from_message(cls, message, from_client = False): + """ + Construct a basic websocket frame from some default values. + Creates a non-fragmented text frame. + """ + length_code, actual_length = get_payload_length_pair(message) + + if from_client: + mask_bit = 1 + masking_key = random_masking_key() + payload = apply_mask(message, masking_key) + else: + mask_bit = 0 + masking_key = None + payload = message + + return cls( + fin = 1, # final frame + opcode = 1, # text + mask_bit = mask_bit, + payload_length_code = length_code, + payload = payload, + masking_key = masking_key, + decoded_payload = message, + actual_payload_length = actual_length + ) + + def validate_frame(self): + """ + Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame + has not been corrupted. + """ + try: + assert 0 <= self.fin <= 1 + assert 0 <= self.rsv1 <= 1 + assert 0 <= self.rsv2 <= 1 + assert 0 <= self.rsv3 <= 1 + assert 1 <= self.opcode <= 4 + assert 0 <= self.mask_bit <= 1 + assert 1 <= self.payload_length_code <= 127 + + if self.mask_bit == 1: + assert 1 <= len(self.masking_key) <= 4 + else: + assert self.masking_key == None + + assert self.actual_payload_length == len(self.payload) + + if self.payload is not None and self.masking_key is not None: + apply_mask(self.payload, self.masking_key) == self.decoded_payload + + except AssertionError: + raise WebSocketFrameValidationException() + + def human_readable(self): + return "\n".join([ + ("fin - " + str(self.fin)), + ("rsv1 - " + str(self.rsv1)), + ("rsv2 - " + str(self.rsv2)), + ("rsv3 - " + str(self.rsv3)), + ("opcode - " + str(self.opcode)), + ("mask_bit - " + str(self.mask_bit)), + ("payload_length_code - " + str(self.payload_length_code)), + ("masking_key - " + str(self.masking_key)), + ("payload - " + str(self.payload)), + ("decoded_payload - " + str(self.decoded_payload)), + ("actual_payload_length - " + str(self.actual_payload_length)), + ("use_validation - " + str(self.use_validation))]) + + def to_bytes(self): + """ + Serialize the frame back into the wire format, returns a bytestring + """ + # validate enforces all the assumptions made by this serializer + # in the spritit of mitmproxy, it's possible to create and serialize invalid frames + # by skipping validation. + if self.use_validation: + self.validate_frame() + + max_16_bit_int = (1 << 16) + max_64_bit_int = (1 << 63) + + # break down of the bit-math used to construct the first byte from the frame's integer values + # first shift the significant bit into the correct position + # 00000001 << 7 = 10000000 + # ... + # then combine: + # + # 10000000 fin + # 01000000 res1 + # 00100000 res2 + # 00010000 res3 + # 00000001 opcode + # -------- OR + # 11110001 = first_byte + + first_byte = (self.fin << 7) | (self.rsv1 << 6) | (self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode + + second_byte = (self.mask_bit << 7) | self.payload_length_code + + bytes = chr(first_byte) + chr(second_byte) + + if self.actual_payload_length < 126: + pass + + elif self.actual_payload_length < max_16_bit_int: + # '!H' pack as 16 bit unsigned short + bytes += struct.pack('!H', self.actual_payload_length) # add 2 byte extended payload length + + elif self.actual_payload_length < max_64_bit_int: + # '!Q' = pack as 64 bit unsigned long long + bytes += struct.pack('!Q', self.actual_payload_length) # add 8 bytes extended payload length + + if self.masking_key is not None: + bytes += self.masking_key + + bytes += self.payload # already will be encoded if neccessary + + return bytes + + + @classmethod + def from_byte_stream(cls, read_bytes): + """ + read a websockets frame sent by a server or client + + read_bytes is a function that can be backed + by sockets or by any byte reader. So this + function may be used to read frames from disk/wire/memory + """ + first_byte = utils.bytes_to_int(read_bytes(1)) + second_byte = utils.bytes_to_int(read_bytes(1)) + + fin = first_byte >> 7 # grab the left most bit + opcode = first_byte & 15 # grab right most 4 bits by and-ing with 00001111 + mask_bit = second_byte >> 7 # grab left most bit + payload_length = second_byte & 127 # grab the next 7 bits + + # payload_lengthy > 125 indicates you need to read more bytes + # to get the actual payload length + if payload_length <= 125: + actual_payload_length = payload_length + + elif payload_length == 126: + actual_payload_length = utils.bytes_to_int(read_bytes(2)) + + elif payload_length == 127: + actual_payload_length = utils.bytes_to_int(read_bytes(8)) + + # masking key only present if mask bit set + if mask_bit == 1: + masking_key = read_bytes(4) + else: + masking_key = None + + payload = read_bytes(actual_payload_length) + + if mask_bit == 1: + decoded_payload = apply_mask(payload, masking_key) + else: + decoded_payload = payload + + return cls( + fin = fin, + opcode = opcode, + mask_bit = mask_bit, + payload_length_code = payload_length, + payload = payload, + masking_key = masking_key, + decoded_payload = decoded_payload, + actual_payload_length = actual_payload_length + ) + +def apply_mask(message, masking_key): + """ + Data sent from the server must be masked to prevent malicious clients + from sending data over the wire in predictable patterns + + This method both encodes and decodes strings with the provided mask + + Servers do not have to mask data they send to the client. + https://tools.ietf.org/html/rfc6455#section-5.3 + """ + masks = [utils.bytes_to_int(byte) for byte in masking_key] + result = "" + for char in message: + result += chr(ord(char) ^ masks[len(result) % 4]) + return result + +def random_masking_key(): + return os.urandom(4) + +def masking_key_list(masking_key): + return [utils.bytes_to_int(byte) for byte in masking_key] + +def create_client_handshake(host, port, key, version, resource): + """ + WebSockets connections are intiated by the client with a valid HTTP upgrade request + """ + headers = [ + ('Host', '%s:%s' % (host, port)), + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Key', key), + ('Sec-WebSocket-Version', version) + ] + request = "GET %s HTTP/1.1" % resource + return build_handshake(headers, request) + + +def create_server_handshake(key, magic = websockets_magic): + """ + The server response is a valid HTTP 101 response. + """ + digest = b64encode(sha1(key + magic).hexdigest().decode('hex')) + headers = [ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Accept', digest) + ] + request = "HTTP/1.1 101 Switching Protocols" + return build_handshake(headers, request) + + +def build_handshake(headers, request): + handshake = [request.encode('utf-8')] + for header, value in headers: + handshake.append(("%s: %s" % (header, value)).encode('utf-8')) + handshake.append(b'\r\n') + return b'\r\n'.join(handshake) + + +def read_handshake(read_bytes, num_bytes_per_read): + """ + From provided function that reads bytes, read in a + complete HTTP request, which terminates with a CLRF + """ + response = b'' + doubleCLRF = b'\r\n\r\n' + while True: + bytes = read_bytes(num_bytes_per_read) + if not bytes: + break + response += bytes + if doubleCLRF in response: + break + return response + +def get_payload_length_pair(payload_bytestring): + """ + A websockets frame contains an initial length_code, and an optional + extended length code to represent the actual length if length code is larger + than 125 + """ + actual_length = len(payload_bytestring) + + if actual_length <= 125: + length_code = actual_length + elif actual_length >= 126 and actual_length <= 65535: + length_code = 126 + else: + length_code = 127 + return (length_code, actual_length) + +def server_process_handshake(handshake): + headers = Message(StringIO(handshake.split('\r\n', 1)[1])) + if headers.get("Upgrade", None) != "websocket": + return + key = headers['Sec-WebSocket-Key'] + return key + +def generate_client_nounce(): + return b64encode(os.urandom(16)).decode('utf-8') + diff --git a/test/test_websockets.py b/test/test_websockets.py new file mode 100644 index 00000000..d7e1627f --- /dev/null +++ b/test/test_websockets.py @@ -0,0 +1,15 @@ +from netlib import test +from netlib.websockets import implementations as ws + +class TestWebSockets(test.ServerTestBase): + handler = ws.WebSocketsEchoHandler + + def test_websockets_echo(self): + msg = "hello I'm the client" + client = ws.WebSocketsClient(("127.0.0.1", self.port)) + client.connect() + client.send_message(msg) + response = client.read_next_message() + print "Assert response: " + response + " == msg: " + msg + assert response == msg + -- cgit v1.2.3 From 0edc04814e3affa71025938ac354707b9b4c481c Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Sat, 11 Apr 2015 11:35:15 -0700 Subject: small cleanups, working on tests --- netlib/websockets/implementations.py | 10 +++++----- netlib/websockets/websockets.py | 35 +++++++++++++++++------------------ test/test_websockets.py | 24 +++++++++++++++++++----- 3 files changed, 41 insertions(+), 28 deletions(-) diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py index 78ae5be6..ff42ff65 100644 --- a/netlib/websockets/implementations.py +++ b/netlib/websockets/implementations.py @@ -26,8 +26,8 @@ class WebSocketsEchoHandler(tcp.BaseHandler): self.on_message(decoded) def send_message(self, message): - frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = False) - self.wfile.write(frame.to_bytes()) + frame = ws.WebSocketsFrame.default(message, from_client = False) + self.wfile.write(frame.safe_to_bytes()) self.wfile.flush() def handshake(self): @@ -47,7 +47,7 @@ class WebSocketsClient(tcp.TCPClient): def __init__(self, address, source_address=None): super(WebSocketsClient, self).__init__(address, source_address) self.version = "13" - self.key = b64encode(os.urandom(16)).decode('utf-8') + self.key = ws.generate_client_nounce() self.resource = "/" def connect(self): @@ -76,6 +76,6 @@ class WebSocketsClient(tcp.TCPClient): self.close() def send_message(self, message): - frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = True) - self.wfile.write(frame.to_bytes()) + frame = ws.WebSocketsFrame.default(message, from_client = True) + self.wfile.write(frame.safe_to_bytes()) self.wfile.flush() diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index b796ce39..527d55d6 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -65,7 +65,6 @@ class WebSocketsFrame(object): payload = None, # bytestring masking_key = None, # 32 bit byte string actual_payload_length = None, # any decimal integer - use_validation = True # indicates whether or not you care if this frame adheres to the spec ): self.fin = fin self.rsv1 = rsv1 @@ -78,21 +77,18 @@ class WebSocketsFrame(object): self.payload = payload self.decoded_payload = decoded_payload self.actual_payload_length = actual_payload_length - self.use_validation = use_validation - - if self.use_validation: - self.validate_frame() @classmethod def from_bytes(cls, bytestring): """ Construct a websocket frame from an in-memory bytestring - to construct a frame from a stream of bytes, use read_frame() directly + to construct a frame from a stream of bytes, use from_byte_stream() directly """ self.from_byte_stream(io.BytesIO(bytestring).read) + @classmethod - def default_frame_from_message(cls, message, from_client = False): + def default(cls, message, from_client = False): """ Construct a basic websocket frame from some default values. Creates a non-fragmented text frame. @@ -119,7 +115,7 @@ class WebSocketsFrame(object): actual_payload_length = actual_length ) - def validate_frame(self): + def frame_is_valid(self): """ Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame has not been corrupted. @@ -141,10 +137,11 @@ class WebSocketsFrame(object): assert self.actual_payload_length == len(self.payload) if self.payload is not None and self.masking_key is not None: - apply_mask(self.payload, self.masking_key) == self.decoded_payload + assert apply_mask(self.payload, self.masking_key) == self.decoded_payload + return True except AssertionError: - raise WebSocketFrameValidationException() + return False def human_readable(self): return "\n".join([ @@ -161,15 +158,19 @@ class WebSocketsFrame(object): ("actual_payload_length - " + str(self.actual_payload_length)), ("use_validation - " + str(self.use_validation))]) + def safe_to_bytes(self): + try: + assert self.frame_is_valid() + return self.to_bytes() + except: + raise WebSocketFrameValidationException() + def to_bytes(self): """ Serialize the frame back into the wire format, returns a bytestring + If you haven't checked is_valid_frame() then there's no guarentees that the + serialized bytes will be correct. see safe_to_bytes() """ - # validate enforces all the assumptions made by this serializer - # in the spritit of mitmproxy, it's possible to create and serialize invalid frames - # by skipping validation. - if self.use_validation: - self.validate_frame() max_16_bit_int = (1 << 16) max_64_bit_int = (1 << 63) @@ -198,6 +199,7 @@ class WebSocketsFrame(object): pass elif self.actual_payload_length < max_16_bit_int: + # '!H' pack as 16 bit unsigned short bytes += struct.pack('!H', self.actual_payload_length) # add 2 byte extended payload length @@ -284,9 +286,6 @@ def apply_mask(message, masking_key): def random_masking_key(): return os.urandom(4) -def masking_key_list(masking_key): - return [utils.bytes_to_int(byte) for byte in masking_key] - def create_client_handshake(host, port, key, version, resource): """ WebSockets connections are intiated by the client with a valid HTTP upgrade request diff --git a/test/test_websockets.py b/test/test_websockets.py index d7e1627f..0b2647ef 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -1,15 +1,29 @@ from netlib import test -from netlib.websockets import implementations as ws +from netlib.websockets import implementations as impl +from netlib.websockets import websockets as ws +import os class TestWebSockets(test.ServerTestBase): - handler = ws.WebSocketsEchoHandler + handler = impl.WebSocketsEchoHandler - def test_websockets_echo(self): - msg = "hello I'm the client" - client = ws.WebSocketsClient(("127.0.0.1", self.port)) + def echo(self, msg): + client = impl.WebSocketsClient(("127.0.0.1", self.port)) client.connect() client.send_message(msg) response = client.read_next_message() print "Assert response: " + response + " == msg: " + msg assert response == msg + def test_simple_echo(self): + self.echo("hello I'm the client") + + def test_frame_sizes(self): + small_string = os.urandom(100) # length can fit in the the 7 bit payload length + medium_string = os.urandom(50000) # 50kb, sligthly larger than can fit in a 7 bit int + large_string = os.urandom(150000) # 150kb, slightly larger than can fit in a 16 bit int + + self.echo(small_string) + self.echo(medium_string) + self.echo(large_string) + + -- cgit v1.2.3 From 73ce169e3d11eeabeb78143bd86edfdbc3e07fd9 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 12 Apr 2015 10:26:09 +1200 Subject: Initial outline of a cookie parsing and serialization module. --- .env | 5 ++ netlib/http_cookies.py | 133 ++++++++++++++++++++++++++++++++++++++++++++++ test/test_http_cookies.py | 106 ++++++++++++++++++++++++++++++++++++ 3 files changed, 244 insertions(+) create mode 100644 .env create mode 100644 netlib/http_cookies.py create mode 100644 test/test_http_cookies.py diff --git a/.env b/.env new file mode 100644 index 00000000..7f847e29 --- /dev/null +++ b/.env @@ -0,0 +1,5 @@ +DIR=`dirname $0` +if [ -z "$VIRTUAL_ENV" ] && [ -f $DIR/../venv.mitmproxy/bin/activate ]; then + echo "Activating mitmproxy virtualenv..." + source $DIR/../venv.mitmproxy/bin/activate +fi diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py new file mode 100644 index 00000000..e11e0f90 --- /dev/null +++ b/netlib/http_cookies.py @@ -0,0 +1,133 @@ +""" +A flexible module for cookie parsing and manipulation. + +We try to be as permissive as possible. Parsing accepts formats from RFC6265 an +RFC2109. Serialization follows RFC6265 strictly. + + http://tools.ietf.org/html/rfc6265 + http://tools.ietf.org/html/rfc2109 +""" + +import re + +import odict + + +def _read_until(s, start, term): + """ + Read until one of the characters in term is reached. + """ + if start == len(s): + return "", start+1 + for i in range(start, len(s)): + if s[i] in term: + return s[start:i], i + return s[start:i+1], i+1 + + +def _read_token(s, start): + """ + Read a token - the LHS of a token/value pair in a cookie. + """ + return _read_until(s, start, ";=") + + +def _read_quoted_string(s, start): + """ + start: offset to the first quote of the string to be read + + A sort of loose super-set of the various quoted string specifications. + + RFC6265 disallows backslashes or double quotes within quoted strings. + Prior RFCs use backslashes to escape. This leaves us free to apply + backslash escaping by default and be compatible with everything. + """ + escaping = False + ret = [] + # Skip the first quote + for i in range(start+1, len(s)): + if escaping: + ret.append(s[i]) + escaping = False + elif s[i] == '"': + break + elif s[i] == "\\": + escaping = True + pass + else: + ret.append(s[i]) + return "".join(ret), i+1 + + +def _read_value(s, start): + """ + Reads a value - the RHS of a token/value pair in a cookie. + """ + if s[start] == '"': + return _read_quoted_string(s, start) + else: + return _read_until(s, start, ";,") + + +def _read_pairs(s): + """ + Read pairs of lhs=rhs values. + """ + off = 0 + vals = [] + while 1: + lhs, off = _read_token(s, off) + rhs = None + if off < len(s): + if s[off] == "=": + rhs, off = _read_value(s, off+1) + vals.append([lhs.lstrip(), rhs]) + off += 1 + if not off < len(s): + break + return vals, off + + +ESCAPE = re.compile(r"([\"\\])") +SPECIAL = re.compile(r"^\w+$") + + +def _format_pairs(lst): + vals = [] + for k, v in lst: + if v is None: + vals.append(k) + else: + match = SPECIAL.search(v) + if match: + v = ESCAPE.sub(r"\1", v) + vals.append("%s=%s"%(k, v)) + return "; ".join(vals) + + +def parse_cookies(s): + """ + Parses a Cookie header value. + Returns an ODict object. + """ + pairs, off = _read_pairs(s) + return odict.ODict(pairs) + + +def unparse_cookies(od): + """ + Formats a Cookie header value. + """ + vals = [] + for i in od.lst: + vals.append("%s=%s"%(i[0], i[1])) + return "; ".join(vals) + + + +def parse_set_cookies(s): + start = 0 + + +def unparse_set_cookies(s): + pass diff --git a/test/test_http_cookies.py b/test/test_http_cookies.py new file mode 100644 index 00000000..b3f1f914 --- /dev/null +++ b/test/test_http_cookies.py @@ -0,0 +1,106 @@ +from netlib import http_cookies, odict +import nose.tools + + +def test_read_token(): + tokens = [ + [("foo", 0), ("foo", 3)], + [("foo", 1), ("oo", 3)], + [(" foo", 1), ("foo", 4)], + [(" foo;", 1), ("foo", 4)], + [(" foo=", 1), ("foo", 4)], + [(" foo=bar", 1), ("foo", 4)], + ] + for q, a in tokens: + nose.tools.eq_(http_cookies._read_token(*q), a) + + +def test_read_quoted_string(): + tokens = [ + [('"foo" x', 0), ("foo", 5)], + [('"f\oo" x', 0), ("foo", 6)], + [(r'"f\\o" x', 0), (r"f\o", 6)], + [(r'"f\\" x', 0), (r"f" + '\\', 5)], + [('"fo\\\"" x', 0), ("fo\"", 6)], + ] + for q, a in tokens: + nose.tools.eq_(http_cookies._read_quoted_string(*q), a) + + +def test_read_pairs(): + vals = [ + [ + "one", + [["one", None]] + ], + [ + "one=two", + [["one", "two"]] + ], + [ + 'one="two"', + [["one", "two"]] + ], + [ + 'one="two"; three=four', + [["one", "two"], ["three", "four"]] + ], + [ + 'one="two"; three=four; five', + [["one", "two"], ["three", "four"], ["five", None]] + ], + [ + 'one="\\"two"; three=four', + [["one", '"two'], ["three", "four"]] + ], + ] + for s, lst in vals: + ret, off = http_cookies._read_pairs(s) + nose.tools.eq_(ret, lst) + + +def test_pairs_roundtrips(): + pairs = [ + [ + "one=uno", + [["one", "uno"]] + ], + [ + "one", + [["one", None]] + ], + [ + "one=uno; two=due", + [["one", "uno"], ["two", "due"]] + ], + [ + 'one="uno"; two="\due"', + [["one", "uno"], ["two", "due"]] + ], + [ + 'one="un\\"o"', + [["one", 'un"o']] + ], + [ + "one=uno; two; three=tre", + [["one", "uno"], ["two", None], ["three", "tre"]] + ], + [ + "_lvs2=zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g=; " + "_rcc2=53VdltWl+Ov6ordflA==;", + [ + ["_lvs2", "zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g="], + ["_rcc2", "53VdltWl+Ov6ordflA=="] + ] + ] + ] + for s, lst in pairs: + ret, off = http_cookies._read_pairs(s) + nose.tools.eq_(ret, lst) + s2 = http_cookies._format_pairs(lst) + ret, off = http_cookies._read_pairs(s2) + nose.tools.eq_(ret, lst) + + +def test_parse_set_cookie(): + pass -- cgit v1.2.3 From 2630da7263242411d413b5e4b2c520d29848c918 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 12 Apr 2015 11:26:02 +1200 Subject: cookies: Cater for special values, fix some bugs found in real-world testing --- netlib/http_cookies.py | 48 ++++++++++++++++++++++++++++++++--------------- test/test_http_cookies.py | 8 ++++++++ 2 files changed, 41 insertions(+), 15 deletions(-) diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index e11e0f90..82675418 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -59,29 +59,39 @@ def _read_quoted_string(s, start): return "".join(ret), i+1 -def _read_value(s, start): +def _read_value(s, start, special): """ Reads a value - the RHS of a token/value pair in a cookie. + + special: If the value is special, commas are premitted. Else comma + terminates. This helps us support old and new style values. """ - if s[start] == '"': + if start >= len(s): + return "", start + elif s[start] == '"': return _read_quoted_string(s, start) + elif special: + return _read_until(s, start, ";") else: return _read_until(s, start, ";,") -def _read_pairs(s): +def _read_pairs(s, specials=()): """ Read pairs of lhs=rhs values. + + specials: A lower-cased list of keys that may contain commas. """ off = 0 vals = [] while 1: lhs, off = _read_token(s, off) + lhs = lhs.lstrip() rhs = None if off < len(s): if s[off] == "=": - rhs, off = _read_value(s, off+1) - vals.append([lhs.lstrip(), rhs]) + rhs, off = _read_value(s, off+1, lhs.lower() in specials) + vals.append([lhs, rhs]) off += 1 if not off < len(s): break @@ -89,18 +99,30 @@ def _read_pairs(s): ESCAPE = re.compile(r"([\"\\])") -SPECIAL = re.compile(r"^\w+$") -def _format_pairs(lst): +def _has_special(s): + for i in s: + if i in '",;\\': + return True + o = ord(i) + if o < 0x21 or o > 0x7e: + return True + return False + + +def _format_pairs(lst, specials=()): + """ + specials: A lower-cased list of keys that will not be quoted. + """ vals = [] for k, v in lst: if v is None: vals.append(k) else: - match = SPECIAL.search(v) - if match: - v = ESCAPE.sub(r"\1", v) + if k.lower() not in specials and _has_special(v): + v = ESCAPE.sub(r"\\\1", v) + v = '"%s"'%v vals.append("%s=%s"%(k, v)) return "; ".join(vals) @@ -118,11 +140,7 @@ def unparse_cookies(od): """ Formats a Cookie header value. """ - vals = [] - for i in od.lst: - vals.append("%s=%s"%(i[0], i[1])) - return "; ".join(vals) - + return _format_pairs(od.lst) def parse_set_cookies(s): diff --git a/test/test_http_cookies.py b/test/test_http_cookies.py index b3f1f914..31e5f0b0 100644 --- a/test/test_http_cookies.py +++ b/test/test_http_cookies.py @@ -37,6 +37,10 @@ def test_read_pairs(): "one=two", [["one", "two"]] ], + [ + "one=", + [["one", ""]] + ], [ 'one="two"', [["one", "two"]] @@ -81,6 +85,10 @@ def test_pairs_roundtrips(): 'one="un\\"o"', [["one", 'un"o']] ], + [ + 'one="uno,due"', + [["one", 'uno,due']] + ], [ "one=uno; two; three=tre", [["one", "uno"], ["two", None], ["three", "tre"]] -- cgit v1.2.3 From f131f9b855e77554072415c925ed112ec74ee48a Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Sat, 11 Apr 2015 15:40:18 -0700 Subject: handshake tests, serialization test --- netlib/websockets/implementations.py | 19 +++++++---- netlib/websockets/websockets.py | 51 +++++++++++++++++++++-------- test/test_websockets.py | 63 ++++++++++++++++++++++++++++++++---- 3 files changed, 105 insertions(+), 28 deletions(-) diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py index ff42ff65..73a84690 100644 --- a/netlib/websockets/implementations.py +++ b/netlib/websockets/implementations.py @@ -32,7 +32,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): def handshake(self): client_hs = ws.read_handshake(self.rfile.read, 1) - key = ws.server_process_handshake(client_hs) + key = ws.process_handshake_from_client(client_hs) response = ws.create_server_handshake(key) self.wfile.write(response) self.wfile.flush() @@ -46,9 +46,9 @@ class WebSocketsEchoHandler(tcp.BaseHandler): class WebSocketsClient(tcp.TCPClient): def __init__(self, address, source_address=None): super(WebSocketsClient, self).__init__(address, source_address) - self.version = "13" - self.key = ws.generate_client_nounce() - self.resource = "/" + self.version = "13" + self.client_nounce = ws.create_client_nounce() + self.resource = "/" def connect(self): super(WebSocketsClient, self).connect() @@ -56,7 +56,7 @@ class WebSocketsClient(tcp.TCPClient): handshake = ws.create_client_handshake( self.address.host, self.address.port, - self.key, + self.client_nounce, self.version, self.resource ) @@ -64,9 +64,14 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.write(handshake) self.wfile.flush() - response = ws.read_handshake(self.rfile.read, 1) + server_handshake = ws.read_handshake(self.rfile.read, 1) - if not response: + if not server_handshake: + self.close() + + server_nounce = ws.process_handshake_from_server(server_handshake, self.client_nounce) + + if not server_nounce == ws.create_server_nounce(self.client_nounce): self.close() def read_next_message(self): diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index 527d55d6..cf9a68aa 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -84,7 +84,7 @@ class WebSocketsFrame(object): Construct a websocket frame from an in-memory bytestring to construct a frame from a stream of bytes, use from_byte_stream() directly """ - self.from_byte_stream(io.BytesIO(bytestring).read) + return cls.from_byte_stream(io.BytesIO(bytestring).read) @classmethod @@ -115,7 +115,7 @@ class WebSocketsFrame(object): actual_payload_length = actual_length ) - def frame_is_valid(self): + def is_valid(self): """ Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame has not been corrupted. @@ -155,12 +155,11 @@ class WebSocketsFrame(object): ("masking_key - " + str(self.masking_key)), ("payload - " + str(self.payload)), ("decoded_payload - " + str(self.decoded_payload)), - ("actual_payload_length - " + str(self.actual_payload_length)), - ("use_validation - " + str(self.use_validation))]) + ("actual_payload_length - " + str(self.actual_payload_length))]) def safe_to_bytes(self): try: - assert self.frame_is_valid() + assert self.is_valid() return self.to_bytes() except: raise WebSocketFrameValidationException() @@ -197,7 +196,7 @@ class WebSocketsFrame(object): if self.actual_payload_length < 126: pass - + elif self.actual_payload_length < max_16_bit_int: # '!H' pack as 16 bit unsigned short @@ -267,6 +266,20 @@ class WebSocketsFrame(object): actual_payload_length = actual_payload_length ) + def __eq__(self, other): + return ( + self.fin == other.fin and + self.rsv1 == other.rsv1 and + self.rsv2 == other.rsv2 and + self.rsv3 == other.rsv3 and + self.opcode == other.opcode and + self.mask_bit == other.mask_bit and + self.payload_length_code == other.payload_length_code and + self.masking_key == other.masking_key and + self.payload == other.payload and + self.decoded_payload == other.decoded_payload and + self.actual_payload_length == other.actual_payload_length) + def apply_mask(message, masking_key): """ Data sent from the server must be masked to prevent malicious clients @@ -300,16 +313,14 @@ def create_client_handshake(host, port, key, version, resource): request = "GET %s HTTP/1.1" % resource return build_handshake(headers, request) - -def create_server_handshake(key, magic = websockets_magic): +def create_server_handshake(key): """ The server response is a valid HTTP 101 response. """ - digest = b64encode(sha1(key + magic).hexdigest().decode('hex')) headers = [ ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), - ('Sec-WebSocket-Accept', digest) + ('Sec-WebSocket-Accept', create_server_nounce(key)) ] request = "HTTP/1.1 101 Switching Protocols" return build_handshake(headers, request) @@ -322,7 +333,6 @@ def build_handshake(headers, request): handshake.append(b'\r\n') return b'\r\n'.join(handshake) - def read_handshake(read_bytes, num_bytes_per_read): """ From provided function that reads bytes, read in a @@ -355,13 +365,26 @@ def get_payload_length_pair(payload_bytestring): length_code = 127 return (length_code, actual_length) -def server_process_handshake(handshake): - headers = Message(StringIO(handshake.split('\r\n', 1)[1])) +def process_handshake_from_client(handshake): + headers = headers_from_http_message(handshake) if headers.get("Upgrade", None) != "websocket": return key = headers['Sec-WebSocket-Key'] return key -def generate_client_nounce(): +def process_handshake_from_server(handshake, client_nounce): + headers = headers_from_http_message(handshake) + if headers.get("Upgrade", None) != "websocket": + return + key = headers['Sec-WebSocket-Accept'] + return key + +def headers_from_http_message(http_message): + return Message(StringIO(http_message.split('\r\n', 1)[1])) + +def create_server_nounce(client_nounce): + return b64encode(sha1(client_nounce + websockets_magic).hexdigest().decode('hex')) + +def create_client_nounce(): return b64encode(os.urandom(16)).decode('utf-8') diff --git a/test/test_websockets.py b/test/test_websockets.py index 0b2647ef..a5ebf3d1 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -1,29 +1,78 @@ +from netlib import tcp from netlib import test from netlib.websockets import implementations as impl from netlib.websockets import websockets as ws import os +from nose.tools import raises class TestWebSockets(test.ServerTestBase): handler = impl.WebSocketsEchoHandler + def random_bytes(self, n = 100): + return os.urandom(n) + def echo(self, msg): client = impl.WebSocketsClient(("127.0.0.1", self.port)) client.connect() client.send_message(msg) response = client.read_next_message() - print "Assert response: " + response + " == msg: " + msg assert response == msg def test_simple_echo(self): self.echo("hello I'm the client") def test_frame_sizes(self): - small_string = os.urandom(100) # length can fit in the the 7 bit payload length - medium_string = os.urandom(50000) # 50kb, sligthly larger than can fit in a 7 bit int - large_string = os.urandom(150000) # 150kb, slightly larger than can fit in a 16 bit int + small_msg = self.random_bytes(100) # length can fit in the the 7 bit payload length + medium_msg = self.random_bytes(50000) # 50kb, sligthly larger than can fit in a 7 bit int + large_msg = self.random_bytes(150000) # 150kb, slightly larger than can fit in a 16 bit int + + self.echo(small_msg) + self.echo(medium_msg) + self.echo(large_msg) + + def test_default_builder(self): + """ + default builder should always generate valid frames + """ + msg = self.random_bytes() + client_frame = ws.WebSocketsFrame.default(msg, from_client = True) + assert client_frame.is_valid() + + server_frame = ws.WebSocketsFrame.default(msg, from_client = False) + assert server_frame.is_valid() + + def test_serialization_bijection(self): + for is_client in [True, False]: + for num_bytes in [100, 50000, 150000]: + frame = ws.WebSocketsFrame.default(self.random_bytes(num_bytes), is_client) + assert frame == ws.WebSocketsFrame.from_bytes(frame.to_bytes()) + + bytes = b'\x81\x11cba' + assert ws.WebSocketsFrame.from_bytes(bytes).to_bytes() == bytes + + +class BadHandshakeHandler(impl.WebSocketsEchoHandler): + def handshake(self): + client_hs = ws.read_handshake(self.rfile.read, 1) + key = ws.process_handshake_from_client(client_hs) + response = ws.create_server_handshake("malformed_key") + self.wfile.write(response) + self.wfile.flush() + self.handshake_done = True + +class TestBadHandshake(test.ServerTestBase): + """ + Ensure that the client disconnects if the server handshake is malformed + """ + handler = BadHandshakeHandler + + @raises(tcp.NetLibDisconnect) + def test(self): + client = impl.WebSocketsClient(("127.0.0.1", self.port)) + client.connect() + client.send_message("hello") + + - self.echo(small_string) - self.echo(medium_string) - self.echo(large_string) -- cgit v1.2.3 From 0ed2a290639833d772b89cf333577820e84f8204 Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Sat, 11 Apr 2015 17:28:52 -0700 Subject: whitespace --- test/test_websockets.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/test/test_websockets.py b/test/test_websockets.py index a5ebf3d1..0c23e355 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -70,9 +70,4 @@ class TestBadHandshake(test.ServerTestBase): def test(self): client = impl.WebSocketsClient(("127.0.0.1", self.port)) client.connect() - client.send_message("hello") - - - - - + client.send_message("hello") \ No newline at end of file -- cgit v1.2.3 From 2d72a1b6b56f1643cd1d8be59eee55aa7ca2f17f Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Mon, 13 Apr 2015 13:36:09 -0700 Subject: 100% test coverage, though still need plenty more --- netlib/http.py | 14 -------------- netlib/websockets/implementations.py | 10 ++-------- netlib/websockets/websockets.py | 9 ++++----- test/test_websockets.py | 14 ++++++++++++-- 4 files changed, 18 insertions(+), 29 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index 2c72621d..26438863 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -29,20 +29,6 @@ def _is_valid_host(host): return None return True -def is_successful_upgrade(request, response): - """ - determines if a client and server successfully agreed to an HTTP protocol upgrade - - https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism - """ - http_switching_protocols_code = 101 - - if request and response: - responseUpgrade = request.headers.get("Upgrade") - requestUpgrade = response.headers.get("Upgrade") - if response.code == http_switching_protocols_code and responseUpgrade == requestUpgrade: - return requestUpgrade[0] if len(requestUpgrade) > 0 else None - return None def parse_url(url): """ diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py index 73a84690..1ded3b85 100644 --- a/netlib/websockets/implementations.py +++ b/netlib/websockets/implementations.py @@ -65,9 +65,6 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.flush() server_handshake = ws.read_handshake(self.rfile.read, 1) - - if not server_handshake: - self.close() server_nounce = ws.process_handshake_from_server(server_handshake, self.client_nounce) @@ -75,11 +72,8 @@ class WebSocketsClient(tcp.TCPClient): self.close() def read_next_message(self): - try: - return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload - except IndexError: - self.close() - + return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload + def send_message(self, message): frame = ws.WebSocketsFrame.default(message, from_client = True) self.wfile.write(frame.safe_to_bytes()) diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index cf9a68aa..ea3db21d 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -158,11 +158,10 @@ class WebSocketsFrame(object): ("actual_payload_length - " + str(self.actual_payload_length))]) def safe_to_bytes(self): - try: - assert self.is_valid() - return self.to_bytes() - except: - raise WebSocketFrameValidationException() + if self.is_valid(): + return self.to_bytes() + else: + raise WebSocketFrameValidationException() def to_bytes(self): """ diff --git a/test/test_websockets.py b/test/test_websockets.py index 0c23e355..951aa41f 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -22,8 +22,8 @@ class TestWebSockets(test.ServerTestBase): self.echo("hello I'm the client") def test_frame_sizes(self): - small_msg = self.random_bytes(100) # length can fit in the the 7 bit payload length - medium_msg = self.random_bytes(50000) # 50kb, sligthly larger than can fit in a 7 bit int + small_msg = self.random_bytes(100) # length can fit in the the 7 bit payload length + medium_msg = self.random_bytes(50000) # 50kb, sligthly larger than can fit in a 7 bit int large_msg = self.random_bytes(150000) # 150kb, slightly larger than can fit in a 16 bit int self.echo(small_msg) @@ -42,6 +42,10 @@ class TestWebSockets(test.ServerTestBase): assert server_frame.is_valid() def test_serialization_bijection(self): + """ + Ensure that various frame types can be serialized/deserialized back and forth + between to_bytes() and from_bytes() + """ for is_client in [True, False]: for num_bytes in [100, 50000, 150000]: frame = ws.WebSocketsFrame.default(self.random_bytes(num_bytes), is_client) @@ -50,6 +54,12 @@ class TestWebSockets(test.ServerTestBase): bytes = b'\x81\x11cba' assert ws.WebSocketsFrame.from_bytes(bytes).to_bytes() == bytes + @raises(ws.WebSocketFrameValidationException) + def test_safe_to_bytes(self): + frame = ws.WebSocketsFrame.default(self.random_bytes(8)) + frame.actual_payload_length = 1 #corrupt the frame + frame.safe_to_bytes() + class BadHandshakeHandler(impl.WebSocketsEchoHandler): def handshake(self): -- cgit v1.2.3 From de9e7411253c4f67ea4d0b96f6f9e952024c5fa3 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 14 Apr 2015 10:02:10 +1200 Subject: Firm up cookie parsing and formatting API Make a tough call: we won't support old-style comma-separated set-cookie headers. Real world testing has shown that the latest rfc (6265) is often violated in ways that make the parsing problem indeterminate. Since this is much more common than the old style deprecated set-cookie variant, we focus on the most useful case. --- netlib/http_cookies.py | 112 ++++++++++++++++++++++++++++++++------------ test/test_http_cookies.py | 115 ++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 195 insertions(+), 32 deletions(-) diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index 82675418..a1f240f5 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -1,13 +1,27 @@ """ A flexible module for cookie parsing and manipulation. -We try to be as permissive as possible. Parsing accepts formats from RFC6265 an -RFC2109. Serialization follows RFC6265 strictly. +This module differs from usual standards-compliant cookie modules in a number of +ways. We try to be as permissive as possible, and to retain even mal-formed +information. Duplicate cookies are preserved in parsing, and can be set in +formatting. We do attempt to escape and quote values where needed, but will not +reject data that violate the specs. + +Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We do +not parse the comma-separated variant of Set-Cookie that allows multiple cookies +to be set in a single header. Technically this should be feasible, but it turns +out that violations of RFC6265 that makes the parsing problem indeterminate are +much more common than genuine occurences of the multi-cookie variants. +Serialization follows RFC6265. http://tools.ietf.org/html/rfc6265 http://tools.ietf.org/html/rfc2109 + http://tools.ietf.org/html/rfc2965 """ +# TODO +# - Disallow LHS-only Cookie values + import re import odict @@ -59,7 +73,7 @@ def _read_quoted_string(s, start): return "".join(ret), i+1 -def _read_value(s, start, special): +def _read_value(s, start, delims): """ Reads a value - the RHS of a token/value pair in a cookie. @@ -70,37 +84,41 @@ def _read_value(s, start, special): return "", start elif s[start] == '"': return _read_quoted_string(s, start) - elif special: - return _read_until(s, start, ";") else: - return _read_until(s, start, ";,") + return _read_until(s, start, delims) -def _read_pairs(s, specials=()): +def _read_pairs(s, off=0, term=None, specials=()): """ Read pairs of lhs=rhs values. - specials: A lower-cased list of keys that may contain commas. + off: start offset + term: if True, treat a comma as a terminator for the pairs lists + specials: a lower-cased list of keys that may contain commas if term is + True """ - off = 0 vals = [] while 1: lhs, off = _read_token(s, off) lhs = lhs.lstrip() - rhs = None - if off < len(s): - if s[off] == "=": - rhs, off = _read_value(s, off+1, lhs.lower() in specials) - vals.append([lhs, rhs]) + if lhs: + rhs = None + if off < len(s): + if s[off] == "=": + if term and lhs.lower() not in specials: + delims = ";," + else: + delims = ";" + rhs, off = _read_value(s, off+1, delims) + vals.append([lhs, rhs]) off += 1 if not off < len(s): break + if term and s[off-1] == ",": + break return vals, off -ESCAPE = re.compile(r"([\"\\])") - - def _has_special(s): for i in s: if i in '",;\\': @@ -111,6 +129,9 @@ def _has_special(s): return False +ESCAPE = re.compile(r"([\"\\])") + + def _format_pairs(lst, specials=()): """ specials: A lower-cased list of keys that will not be quoted. @@ -127,25 +148,58 @@ def _format_pairs(lst, specials=()): return "; ".join(vals) -def parse_cookies(s): +def _format_set_cookie_pairs(lst): + return _format_pairs( + lst, + specials = ("expires", "path") + ) + + +def _parse_set_cookie_pairs(s): """ - Parses a Cookie header value. - Returns an ODict object. + For Set-Cookie, we support multiple cookies as described in RFC2109. + This function therefore returns a list of lists. """ - pairs, off = _read_pairs(s) - return odict.ODict(pairs) + pairs, off = _read_pairs( + s, + specials = ("expires", "path") + ) + return pairs -def unparse_cookies(od): +def parse_set_cookie_header(str): """ - Formats a Cookie header value. + Parse a Set-Cookie header value + + Returns a (name, value, attrs) tuple, or None, where attrs is an + ODictCaseless set of attributes. No attempt is made to parse attribute + values - they are treated purely as strings. """ - return _format_pairs(od.lst) + pairs = _parse_set_cookie_pairs(str) + if pairs: + return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) + + +def format_set_cookie_header(name, value, attrs): + """ + Formats a Set-Cookie header value. + """ + pairs = [[name, value]] + pairs.extend(attrs.lst) + return _format_set_cookie_pairs(pairs) -def parse_set_cookies(s): - start = 0 +def parse_cookie_header(str): + """ + Parse a Cookie header value. + Returns a (possibly empty) ODict object. + """ + pairs, off = _read_pairs(str) + return odict.ODict(pairs) -def unparse_set_cookies(s): - pass +def format_cookie_header(od): + """ + Formats a Cookie header value. + """ + return _format_pairs(od.lst) diff --git a/test/test_http_cookies.py b/test/test_http_cookies.py index 31e5f0b0..c0e5a5b7 100644 --- a/test/test_http_cookies.py +++ b/test/test_http_cookies.py @@ -1,6 +1,8 @@ -from netlib import http_cookies, odict +import pprint import nose.tools +from netlib import http_cookies, odict + def test_read_token(): tokens = [ @@ -65,6 +67,10 @@ def test_read_pairs(): def test_pairs_roundtrips(): pairs = [ + [ + "", + [] + ], [ "one=uno", [["one", "uno"]] @@ -110,5 +116,108 @@ def test_pairs_roundtrips(): nose.tools.eq_(ret, lst) -def test_parse_set_cookie(): - pass +def test_cookie_roundtrips(): + pairs = [ + [ + "one=uno", + [["one", "uno"]] + ], + [ + "one=uno; two=due", + [["one", "uno"], ["two", "due"]] + ], + ] + for s, lst in pairs: + ret = http_cookies.parse_cookie_header(s) + nose.tools.eq_(ret.lst, lst) + s2 = http_cookies.format_cookie_header(ret) + ret = http_cookies.parse_cookie_header(s2) + nose.tools.eq_(ret.lst, lst) + + +# TODO +# I've seen the following pathological cookie in the wild: +# +# cid=09,0,0,0,0; expires=Wed, 10-Jun-2015 21:54:53 GMT; path=/ +# +# It's not compliant under any RFC - the latest RFC prohibits commas in cookie +# values completely, earlier RFCs require them to be within a quoted string. +# +# If we ditch support for earlier RFCs, we can handle this correctly. This +# leaves us with the question: what's more common, multiple-value Set-Cookie +# headers, or Set-Cookie headers that violate the standards? + +def test_parse_set_cookie_pairs(): + pairs = [ + [ + "one=uno", + [ + ["one", "uno"] + ] + ], + [ + "one=uno; foo", + [ + ["one", "uno"], + ["foo", None] + ] + ], + [ + "mun=1.390.f60; " + "expires=sun, 11-oct-2015 12:38:31 gmt; path=/; " + "domain=b.aol.com", + [ + ["mun", "1.390.f60"], + ["expires", "sun, 11-oct-2015 12:38:31 gmt"], + ["path", "/"], + ["domain", "b.aol.com"] + ] + ], + [ + r'rpb=190%3d1%2616726%3d1%2634832%3d1%2634874%3d1; ' + 'domain=.rubiconproject.com; ' + 'expires=mon, 11-may-2015 21:54:57 gmt; ' + 'path=/', + [ + ['rpb', r'190%3d1%2616726%3d1%2634832%3d1%2634874%3d1'], + ['domain', '.rubiconproject.com'], + ['expires', 'mon, 11-may-2015 21:54:57 gmt'], + ['path', '/'] + ] + ], + ] + for s, lst in pairs: + ret = http_cookies._parse_set_cookie_pairs(s) + nose.tools.eq_(ret, lst) + s2 = http_cookies._format_set_cookie_pairs(ret) + ret2 = http_cookies._parse_set_cookie_pairs(s2) + nose.tools.eq_(ret2, lst) + + +def test_parse_set_cookie_header(): + vals = [ + [ + "", None + ], + [ + "one=uno", + ("one", "uno", []) + ], + [ + "one=uno; foo=bar", + ("one", "uno", [["foo", "bar"]]) + ] + ] + for s, expected in vals: + ret = http_cookies.parse_set_cookie_header(s) + if expected: + assert ret[0] == expected[0] + assert ret[1] == expected[1] + nose.tools.eq_(ret[2].lst, expected[2]) + s2 = http_cookies.format_set_cookie_header(*ret) + ret2 = http_cookies.parse_set_cookie_header(s2) + assert ret2[0] == expected[0] + assert ret2[1] == expected[1] + nose.tools.eq_(ret2[2].lst, expected[2]) + else: + assert ret is None -- cgit v1.2.3 From 6db5e0a4a133e6e6150f9cab87cd56b40d6db0b2 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 14 Apr 2015 10:13:03 +1200 Subject: Remove old-style set-cookie cruft, unit tests to 100% --- netlib/http_cookies.py | 14 +++----------- test/test_http_cookies.py | 6 ++++++ 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index a1f240f5..297efb80 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -88,14 +88,12 @@ def _read_value(s, start, delims): return _read_until(s, start, delims) -def _read_pairs(s, off=0, term=None, specials=()): +def _read_pairs(s, off=0, specials=()): """ Read pairs of lhs=rhs values. off: start offset - term: if True, treat a comma as a terminator for the pairs lists - specials: a lower-cased list of keys that may contain commas if term is - True + specials: a lower-cased list of keys that may contain commas """ vals = [] while 1: @@ -105,17 +103,11 @@ def _read_pairs(s, off=0, term=None, specials=()): rhs = None if off < len(s): if s[off] == "=": - if term and lhs.lower() not in specials: - delims = ";," - else: - delims = ";" - rhs, off = _read_value(s, off+1, delims) + rhs, off = _read_value(s, off+1, ";") vals.append([lhs, rhs]) off += 1 if not off < len(s): break - if term and s[off-1] == ",": - break return vals, off diff --git a/test/test_http_cookies.py b/test/test_http_cookies.py index c0e5a5b7..ad509254 100644 --- a/test/test_http_cookies.py +++ b/test/test_http_cookies.py @@ -155,6 +155,12 @@ def test_parse_set_cookie_pairs(): ["one", "uno"] ] ], + [ + "one=un\x20", + [ + ["one", "un\x20"] + ] + ], [ "one=uno; foo", [ -- cgit v1.2.3 From d739882bf2dc65925c001c5bf848f5664640d299 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 14 Apr 2015 13:50:57 +1200 Subject: Add an .extend method for ODicts --- netlib/odict.py | 6 ++++++ test/test_odict.py | 7 ++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/netlib/odict.py b/netlib/odict.py index 7a2f611b..7a54f282 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -108,6 +108,12 @@ class ODict(object): lst = copy.deepcopy(self.lst) return self.__class__(lst) + def extend(self, other): + """ + Add the contents of other, preserving any duplicates. + """ + self.lst.extend(other.lst) + def __repr__(self): elements = [] for itm in self.lst: diff --git a/test/test_odict.py b/test/test_odict.py index d90bc6e5..c2415b6d 100644 --- a/test/test_odict.py +++ b/test/test_odict.py @@ -109,6 +109,12 @@ class TestODict: assert self.od.get_first("one") == "two" assert self.od.get_first("two") == None + def test_extend(self): + a = odict.ODict([["a", "b"], ["c", "d"]]) + b = odict.ODict([["a", "b"], ["e", "f"]]) + a.extend(b) + assert len(a) == 4 + assert a["a"] == ["b", "b"] class TestODictCaseless: def setUp(self): @@ -144,4 +150,3 @@ class TestODictCaseless: assert self.od.keys() == ["foo"] self.od.add("bar", 2) assert len(self.od.keys()) == 2 - -- cgit v1.2.3 From aeebf31927eb3ff74824525005c7b146024de6d5 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 14 Apr 2015 16:20:02 +1200 Subject: odict: don't convert values to strings when added --- netlib/odict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netlib/odict.py b/netlib/odict.py index 7a54f282..a0ea9e53 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -84,7 +84,7 @@ class ODict(object): return False def add(self, key, value): - self.lst.append([key, str(value)]) + self.lst.append([key, value]) def get(self, k, d=None): if k in self: @@ -117,7 +117,7 @@ class ODict(object): def __repr__(self): elements = [] for itm in self.lst: - elements.append(itm[0] + ": " + itm[1]) + elements.append(itm[0] + ": " + str(itm[1])) elements.append("") return "\r\n".join(elements) -- cgit v1.2.3 From 0c85c72dc43d0d017e2bf5af9c2def46968d0499 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 15 Apr 2015 10:28:17 +1200 Subject: ODict improvements - Setting values now tries to preserve the existing order, rather than just appending to the end. - __repr__ now returns a repr of the tuple list. The old repr becomes a .format() method. This is clearer, makes troubleshooting easier, and doesn't assume all data in ODicts are header-like --- netlib/odict.py | 25 +++++++++++++++++++------ netlib/wsgi.py | 29 ++++++++++++++++++----------- test/test_http.py | 11 +++++++++-- test/test_http_cookies.py | 15 +++------------ test/test_odict.py | 25 +++++++++++++++++++++++-- test/test_wsgi.py | 1 - 6 files changed, 72 insertions(+), 34 deletions(-) diff --git a/netlib/odict.py b/netlib/odict.py index a0ea9e53..dd738c55 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -13,7 +13,8 @@ def safe_subn(pattern, repl, target, *args, **kwargs): class ODict(object): """ - A dictionary-like object for managing ordered (key, value) data. + A dictionary-like object for managing ordered (key, value) data. Think + about it as a convenient interface to a list of (key, value) tuples. """ def __init__(self, lst=None): self.lst = lst or [] @@ -64,11 +65,20 @@ class ODict(object): key, they are cleared. """ if isinstance(valuelist, basestring): - raise ValueError("Expected list of values instead of string. Example: odict['Host'] = ['www.example.com']") - - new = self._filter_lst(k, self.lst) - for i in valuelist: - new.append([k, i]) + raise ValueError( + "Expected list of values instead of string. " + "Example: odict['Host'] = ['www.example.com']" + ) + kc = self._kconv(k) + new = [] + for i in self.lst: + if self._kconv(i[0]) == kc: + if valuelist: + new.append([k, valuelist.pop(0)]) + else: + new.append(i) + while valuelist: + new.append([k, valuelist.pop(0)]) self.lst = new def __delitem__(self, k): @@ -115,6 +125,9 @@ class ODict(object): self.lst.extend(other.lst) def __repr__(self): + return repr(self.lst) + + def format(self): elements = [] for itm in self.lst: elements.append(itm[0] + ": " + str(itm[1])) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index bac27d5a..1b979608 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,5 +1,8 @@ from __future__ import (absolute_import, print_function, division) -import cStringIO, urllib, time, traceback +import cStringIO +import urllib +import time +import traceback from . import odict, tcp @@ -23,15 +26,18 @@ class Request(object): def date_time_string(): """Return the current date and time formatted for a message header.""" WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] - MONTHS = [None, - 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', - 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] + MONTHS = [ + None, + 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', + 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec' + ] now = time.time() year, month, day, hh, mm, ss, wd, y, z = time.gmtime(now) s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( - WEEKS[wd], - day, MONTHS[month], year, - hh, mm, ss) + WEEKS[wd], + day, MONTHS[month], year, + hh, mm, ss + ) return s @@ -100,6 +106,7 @@ class WSGIAdaptor(object): status = None, headers = None ) + def write(data): if not state["headers_sent"]: soc.write("HTTP/1.1 %s\r\n"%state["status"]) @@ -108,7 +115,7 @@ class WSGIAdaptor(object): h["Server"] = [self.sversion] if 'date' not in h: h["Date"] = [date_time_string()] - soc.write(str(h)) + soc.write(h.format()) soc.write("\r\n") state["headers_sent"] = True if data: @@ -130,7 +137,9 @@ class WSGIAdaptor(object): errs = cStringIO.StringIO() try: - dataiter = self.app(self.make_environ(request, errs, **env), start_response) + dataiter = self.app( + self.make_environ(request, errs, **env), start_response + ) for i in dataiter: write(i) if not state["headers_sent"]: @@ -143,5 +152,3 @@ class WSGIAdaptor(object): except Exception: # pragma: no cover pass return errs.getvalue() - - diff --git a/test/test_http.py b/test/test_http.py index fed60946..b1c62458 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -53,6 +53,7 @@ def test_connection_close(): h["connection"] = ["close"] assert http.connection_close((1, 1), h) + def test_get_header_tokens(): h = odict.ODictCaseless() assert http.get_header_tokens(h, "foo") == [] @@ -69,11 +70,13 @@ def test_read_http_body_request(): r = cStringIO.StringIO("testing") assert http.read_http_body(r, h, None, "GET", None, True) == "" + def test_read_http_body_response(): h = odict.ODictCaseless() s = cStringIO.StringIO("testing") assert http.read_http_body(s, h, None, "GET", 200, False) == "testing" + def test_read_http_body(): # test default case h = odict.ODictCaseless() @@ -115,6 +118,7 @@ def test_read_http_body(): s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa" + def test_expected_http_body_size(): # gibber in the content-length field h = odict.ODictCaseless() @@ -135,6 +139,7 @@ def test_expected_http_body_size(): h = odict.ODictCaseless() assert http.expected_http_body_size(h, True, "GET", None) == 0 + def test_parse_http_protocol(): assert http.parse_http_protocol("HTTP/1.1") == (1, 1) assert http.parse_http_protocol("HTTP/0.0") == (0, 0) @@ -189,6 +194,7 @@ def test_parse_init_http(): assert not http.parse_init_http("GET /test foo/1.1") assert not http.parse_init_http("GET /test\xc0 HTTP/1.1") + class TestReadHeaders: def _read(self, data, verbatim=False): if not verbatim: @@ -251,11 +257,12 @@ class TestReadResponseNoContentLength(test.ServerTestBase): httpversion, code, msg, headers, content = http.read_response(c.rfile, "GET", None) assert content == "bar\r\n\r\n" + def test_read_response(): def tst(data, method, limit, include_body=True): data = textwrap.dedent(data) r = cStringIO.StringIO(data) - return http.read_response(r, method, limit, include_body=include_body) + return http.read_response(r, method, limit, include_body = include_body) tutils.raises("server disconnect", tst, "", "GET", None) tutils.raises("invalid server response", tst, "foo", "GET", None) @@ -351,6 +358,7 @@ def test_parse_url(): # Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt assert not http.parse_url('http://lo[calhost') + def test_parse_http_basic_auth(): vals = ("basic", "foo", "bar") assert http.parse_http_basic_auth(http.assemble_http_basic_auth(*vals)) == vals @@ -358,4 +366,3 @@ def test_parse_http_basic_auth(): assert not http.parse_http_basic_auth("foo bar") v = "basic " + binascii.b2a_base64("foo") assert not http.parse_http_basic_auth(v) - diff --git a/test/test_http_cookies.py b/test/test_http_cookies.py index ad509254..7438af7c 100644 --- a/test/test_http_cookies.py +++ b/test/test_http_cookies.py @@ -135,18 +135,6 @@ def test_cookie_roundtrips(): nose.tools.eq_(ret.lst, lst) -# TODO -# I've seen the following pathological cookie in the wild: -# -# cid=09,0,0,0,0; expires=Wed, 10-Jun-2015 21:54:53 GMT; path=/ -# -# It's not compliant under any RFC - the latest RFC prohibits commas in cookie -# values completely, earlier RFCs require them to be within a quoted string. -# -# If we ditch support for earlier RFCs, we can handle this correctly. This -# leaves us with the question: what's more common, multiple-value Set-Cookie -# headers, or Set-Cookie headers that violate the standards? - def test_parse_set_cookie_pairs(): pairs = [ [ @@ -205,6 +193,9 @@ def test_parse_set_cookie_header(): [ "", None ], + [ + ";", None + ], [ "one=uno", ("one", "uno", []) diff --git a/test/test_odict.py b/test/test_odict.py index c2415b6d..c01c4dbe 100644 --- a/test/test_odict.py +++ b/test/test_odict.py @@ -6,6 +6,11 @@ class TestODict: def setUp(self): self.od = odict.ODict() + def test_repr(self): + h = odict.ODict() + h["one"] = ["two"] + assert repr(h) + def test_str_err(self): h = odict.ODict() tutils.raises(ValueError, h.__setitem__, "key", "foo") @@ -20,7 +25,7 @@ class TestODict: "two: tre\r\n", "\r\n" ] - out = repr(self.od) + out = self.od.format() for i in expected: assert out.find(i) >= 0 @@ -39,7 +44,7 @@ class TestODict: self.od["one"] = ["uno"] expected1 = "one: uno\r\n" expected2 = "\r\n" - out = repr(self.od) + out = self.od.format() assert out.find(expected1) >= 0 assert out.find(expected2) >= 0 @@ -150,3 +155,19 @@ class TestODictCaseless: assert self.od.keys() == ["foo"] self.od.add("bar", 2) assert len(self.od.keys()) == 2 + + def test_add_order(self): + od = odict.ODict( + [ + ["one", "uno"], + ["two", "due"], + ["three", "tre"], + ] + ) + od["two"] = ["foo", "bar"] + assert od.lst == [ + ["one", "uno"], + ["two", "foo"], + ["three", "tre"], + ["two", "bar"], + ] diff --git a/test/test_wsgi.py b/test/test_wsgi.py index 6e1fb146..1c8c5263 100644 --- a/test/test_wsgi.py +++ b/test/test_wsgi.py @@ -100,4 +100,3 @@ class TestWSGI: start_response(status, response_headers, ei) yield "bbb" assert "Internal Server Error" in self._serve(app) - -- cgit v1.2.3 From c53d89fd7fad6c46458ab3d0140528e344de605f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 16 Apr 2015 08:30:54 +1200 Subject: Improve flexibility of http_cookies._format_pairs --- netlib/http_cookies.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index 297efb80..dab95ed0 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -124,7 +124,7 @@ def _has_special(s): ESCAPE = re.compile(r"([\"\\])") -def _format_pairs(lst, specials=()): +def _format_pairs(lst, specials=(), sep="; "): """ specials: A lower-cased list of keys that will not be quoted. """ @@ -137,7 +137,7 @@ def _format_pairs(lst, specials=()): v = ESCAPE.sub(r"\\\1", v) v = '"%s"'%v vals.append("%s=%s"%(k, v)) - return "; ".join(vals) + return sep.join(vals) def _format_set_cookie_pairs(lst): -- cgit v1.2.3 From 488c25d812a321f5a03253b62ab33b61ecc13de1 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 17 Apr 2015 13:57:39 +1200 Subject: websockets: whitespace, PEP8 --- netlib/websockets/websockets.py | 169 +++++++++++++++++++++++----------------- 1 file changed, 96 insertions(+), 73 deletions(-) diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index ea3db21d..8782ea49 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -1,31 +1,34 @@ from __future__ import absolute_import -from base64 import b64encode -from hashlib import sha1 -from mimetools import Message -from netlib import tcp -from netlib import utils -from StringIO import StringIO +import base64 +import hashlib +import mimetools +import StringIO import os -import SocketServer import struct import io -# Colleciton of utility functions that implement small portions of the RFC6455 WebSockets Protocol -# Useful for building WebSocket clients and servers. -# -# Emphassis is on readabilty, simplicity and modularity, not performance or completeness +from .. import utils + +# Colleciton of utility functions that implement small portions of the RFC6455 +# WebSockets Protocol Useful for building WebSocket clients and servers. # -# This is a work in progress and does not yet contain all the utilites need to create fully complient client/servers +# Emphassis is on readabilty, simplicity and modularity, not performance or +# completeness # +# This is a work in progress and does not yet contain all the utilites need to +# create fully complient client/servers # # Spec: https://tools.ietf.org/html/rfc6455 -# The magic sha that websocket servers must know to prove they understand RFC6455 +# The magic sha that websocket servers must know to prove they understand +# RFC6455 websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' + class WebSocketFrameValidationException(Exception): pass + class WebSocketsFrame(object): """ Represents one websockets frame. @@ -33,7 +36,7 @@ class WebSocketsFrame(object): from_bytes() is also avaliable. WebSockets Frame as defined in RFC6455 - + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-------+-+-------------+-------------------------------+ |F|R|R|R| opcode|M| Payload len | Extended payload length | @@ -62,7 +65,7 @@ class WebSocketsFrame(object): rsv1 = 0, # decimal integer 1 or 0 rsv2 = 0, # decimal integer 1 or 0 rsv3 = 0, # decimal integer 1 or 0 - payload = None, # bytestring + payload = None, # bytestring masking_key = None, # 32 bit byte string actual_payload_length = None, # any decimal integer ): @@ -81,18 +84,17 @@ class WebSocketsFrame(object): @classmethod def from_bytes(cls, bytestring): """ - Construct a websocket frame from an in-memory bytestring - to construct a frame from a stream of bytes, use from_byte_stream() directly - """ + Construct a websocket frame from an in-memory bytestring to construct + a frame from a stream of bytes, use from_byte_stream() directly + """ return cls.from_byte_stream(io.BytesIO(bytestring).read) - @classmethod def default(cls, message, from_client = False): """ - Construct a basic websocket frame from some default values. + Construct a basic websocket frame from some default values. Creates a non-fragmented text frame. - """ + """ length_code, actual_length = get_payload_length_pair(message) if from_client: @@ -103,7 +105,7 @@ class WebSocketsFrame(object): mask_bit = 0 masking_key = None payload = message - + return cls( fin = 1, # final frame opcode = 1, # text @@ -117,10 +119,10 @@ class WebSocketsFrame(object): def is_valid(self): """ - Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame - has not been corrupted. - """ - try: + Validate websocket frame invariants, call at anytime to ensure the + WebSocketsFrame has not been corrupted. + """ + try: assert 0 <= self.fin <= 1 assert 0 <= self.rsv1 <= 1 assert 0 <= self.rsv2 <= 1 @@ -128,18 +130,18 @@ class WebSocketsFrame(object): assert 1 <= self.opcode <= 4 assert 0 <= self.mask_bit <= 1 assert 1 <= self.payload_length_code <= 127 - + if self.mask_bit == 1: assert 1 <= len(self.masking_key) <= 4 else: - assert self.masking_key == None - + assert self.masking_key is None + assert self.actual_payload_length == len(self.payload) if self.payload is not None and self.masking_key is not None: assert apply_mask(self.payload, self.masking_key) == self.decoded_payload - return True + return True except AssertionError: return False @@ -165,30 +167,32 @@ class WebSocketsFrame(object): def to_bytes(self): """ - Serialize the frame back into the wire format, returns a bytestring - If you haven't checked is_valid_frame() then there's no guarentees that the - serialized bytes will be correct. see safe_to_bytes() - """ + Serialize the frame back into the wire format, returns a bytestring If + you haven't checked is_valid_frame() then there's no guarentees that + the serialized bytes will be correct. see safe_to_bytes() + """ max_16_bit_int = (1 << 16) max_64_bit_int = (1 << 63) - # break down of the bit-math used to construct the first byte from the frame's integer values - # first shift the significant bit into the correct position + # break down of the bit-math used to construct the first byte from the + # frame's integer values first shift the significant bit into the + # correct position # 00000001 << 7 = 10000000 # ... # then combine: - # + # # 10000000 fin # 01000000 res1 # 00100000 res2 # 00010000 res3 # 00000001 opcode - # -------- OR + # -------- OR # 11110001 = first_byte - first_byte = (self.fin << 7) | (self.rsv1 << 6) | (self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode - + first_byte = (self.fin << 7) | (self.rsv1 << 6) |\ + (self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode + second_byte = (self.mask_bit << 7) | self.payload_length_code bytes = chr(first_byte) + chr(second_byte) @@ -199,11 +203,13 @@ class WebSocketsFrame(object): elif self.actual_payload_length < max_16_bit_int: # '!H' pack as 16 bit unsigned short - bytes += struct.pack('!H', self.actual_payload_length) # add 2 byte extended payload length - + # add 2 byte extended payload length + bytes += struct.pack('!H', self.actual_payload_length) + elif self.actual_payload_length < max_64_bit_int: # '!Q' = pack as 64 bit unsigned long long - bytes += struct.pack('!Q', self.actual_payload_length) # add 8 bytes extended payload length + # add 8 bytes extended payload length + bytes += struct.pack('!Q', self.actual_payload_length) if self.masking_key is not None: bytes += self.masking_key @@ -212,43 +218,46 @@ class WebSocketsFrame(object): return bytes - @classmethod def from_byte_stream(cls, read_bytes): """ read a websockets frame sent by a server or client - + read_bytes is a function that can be backed - by sockets or by any byte reader. So this + by sockets or by any byte reader. So this function may be used to read frames from disk/wire/memory - """ - first_byte = utils.bytes_to_int(read_bytes(1)) + """ + first_byte = utils.bytes_to_int(read_bytes(1)) second_byte = utils.bytes_to_int(read_bytes(1)) - - fin = first_byte >> 7 # grab the left most bit - opcode = first_byte & 15 # grab right most 4 bits by and-ing with 00001111 - mask_bit = second_byte >> 7 # grab left most bit - payload_length = second_byte & 127 # grab the next 7 bits + + # grab the left most bit + fin = first_byte >> 7 + # grab right most 4 bits by and-ing with 00001111 + opcode = first_byte & 15 + # grab left most bit + mask_bit = second_byte >> 7 + # grab the next 7 bits + payload_length = second_byte & 127 # payload_lengthy > 125 indicates you need to read more bytes # to get the actual payload length if payload_length <= 125: - actual_payload_length = payload_length + actual_payload_length = payload_length elif payload_length == 126: - actual_payload_length = utils.bytes_to_int(read_bytes(2)) + actual_payload_length = utils.bytes_to_int(read_bytes(2)) - elif payload_length == 127: - actual_payload_length = utils.bytes_to_int(read_bytes(8)) + elif payload_length == 127: + actual_payload_length = utils.bytes_to_int(read_bytes(8)) # masking key only present if mask bit set if mask_bit == 1: masking_key = read_bytes(4) else: masking_key = None - + payload = read_bytes(actual_payload_length) - + if mask_bit == 1: decoded_payload = apply_mask(payload, masking_key) else: @@ -295,12 +304,15 @@ def apply_mask(message, masking_key): result += chr(ord(char) ^ masks[len(result) % 4]) return result + def random_masking_key(): return os.urandom(4) + def create_client_handshake(host, port, key, version, resource): """ - WebSockets connections are intiated by the client with a valid HTTP upgrade request + WebSockets connections are intiated by the client with a valid HTTP + upgrade request """ headers = [ ('Host', '%s:%s' % (host, port)), @@ -312,10 +324,11 @@ def create_client_handshake(host, port, key, version, resource): request = "GET %s HTTP/1.1" % resource return build_handshake(headers, request) + def create_server_handshake(key): """ - The server response is a valid HTTP 101 response. - """ + The server response is a valid HTTP 101 response. + """ headers = [ ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), @@ -332,12 +345,13 @@ def build_handshake(headers, request): handshake.append(b'\r\n') return b'\r\n'.join(handshake) + def read_handshake(read_bytes, num_bytes_per_read): """ - From provided function that reads bytes, read in a + From provided function that reads bytes, read in a complete HTTP request, which terminates with a CLRF - """ - response = b'' + """ + response = b'' doubleCLRF = b'\r\n\r\n' while True: bytes = read_bytes(num_bytes_per_read) @@ -348,14 +362,15 @@ def read_handshake(read_bytes, num_bytes_per_read): break return response + def get_payload_length_pair(payload_bytestring): """ A websockets frame contains an initial length_code, and an optional - extended length code to represent the actual length if length code is larger - than 125 - """ + extended length code to represent the actual length if length code is + larger than 125 + """ actual_length = len(payload_bytestring) - + if actual_length <= 125: length_code = actual_length elif actual_length >= 126 and actual_length <= 65535: @@ -364,6 +379,7 @@ def get_payload_length_pair(payload_bytestring): length_code = 127 return (length_code, actual_length) + def process_handshake_from_client(handshake): headers = headers_from_http_message(handshake) if headers.get("Upgrade", None) != "websocket": @@ -371,6 +387,7 @@ def process_handshake_from_client(handshake): key = headers['Sec-WebSocket-Key'] return key + def process_handshake_from_server(handshake, client_nounce): headers = headers_from_http_message(handshake) if headers.get("Upgrade", None) != "websocket": @@ -378,12 +395,18 @@ def process_handshake_from_server(handshake, client_nounce): key = headers['Sec-WebSocket-Accept'] return key + def headers_from_http_message(http_message): - return Message(StringIO(http_message.split('\r\n', 1)[1])) + return mimetools.Message( + StringIO.StringIO(http_message.split('\r\n', 1)[1]) + ) + def create_server_nounce(client_nounce): - return b64encode(sha1(client_nounce + websockets_magic).hexdigest().decode('hex')) + return base64.b64encode( + hashlib.sha1(client_nounce + websockets_magic).hexdigest().decode('hex') + ) -def create_client_nounce(): - return b64encode(os.urandom(16)).decode('utf-8') +def create_client_nounce(): + return base64.b64encode(os.urandom(16)).decode('utf-8') -- cgit v1.2.3 From 7defb5be862a4251da9d7c530593f7e9be3e739e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 17 Apr 2015 14:29:20 +1200 Subject: websockets: more whitespace, WebSocketFrame -> Frame --- netlib/websockets/implementations.py | 12 ++--- netlib/websockets/websockets.py | 100 +++++++++++++++++------------------ test/test_websockets.py | 45 +++++++++------- 3 files changed, 81 insertions(+), 76 deletions(-) diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py index 1ded3b85..337c5496 100644 --- a/netlib/websockets/implementations.py +++ b/netlib/websockets/implementations.py @@ -9,7 +9,7 @@ import os # Simple websocket client and servers that are used to exercise the functionality in websockets.py # These are *not* fully RFC6455 compliant -class WebSocketsEchoHandler(tcp.BaseHandler): +class WebSocketsEchoHandler(tcp.BaseHandler): def __init__(self, connection, address, server): super(WebSocketsEchoHandler, self).__init__(connection, address, server) self.handshake_done = False @@ -22,14 +22,14 @@ class WebSocketsEchoHandler(tcp.BaseHandler): self.read_next_message() def read_next_message(self): - decoded = ws.WebSocketsFrame.from_byte_stream(self.rfile.read).decoded_payload + decoded = ws.Frame.from_byte_stream(self.rfile.read).decoded_payload self.on_message(decoded) def send_message(self, message): - frame = ws.WebSocketsFrame.default(message, from_client = False) + frame = ws.Frame.default(message, from_client = False) self.wfile.write(frame.safe_to_bytes()) self.wfile.flush() - + def handshake(self): client_hs = ws.read_handshake(self.rfile.read, 1) key = ws.process_handshake_from_client(client_hs) @@ -72,9 +72,9 @@ class WebSocketsClient(tcp.TCPClient): self.close() def read_next_message(self): - return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload + return ws.Frame.from_byte_stream(self.rfile.read).payload def send_message(self, message): - frame = ws.WebSocketsFrame.default(message, from_client = True) + frame = ws.Frame.default(message, from_client = True) self.wfile.write(frame.safe_to_bytes()) self.wfile.flush() diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index 8782ea49..86d98caf 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -29,7 +29,7 @@ class WebSocketFrameValidationException(Exception): pass -class WebSocketsFrame(object): +class Frame(object): """ Represents one websockets frame. Constructor takes human readable forms of the frame components @@ -98,29 +98,29 @@ class WebSocketsFrame(object): length_code, actual_length = get_payload_length_pair(message) if from_client: - mask_bit = 1 + mask_bit = 1 masking_key = random_masking_key() - payload = apply_mask(message, masking_key) + payload = apply_mask(message, masking_key) else: - mask_bit = 0 + mask_bit = 0 masking_key = None - payload = message + payload = message return cls( - fin = 1, # final frame - opcode = 1, # text - mask_bit = mask_bit, - payload_length_code = length_code, - payload = payload, - masking_key = masking_key, - decoded_payload = message, + fin = 1, # final frame + opcode = 1, # text + mask_bit = mask_bit, + payload_length_code = length_code, + payload = payload, + masking_key = masking_key, + decoded_payload = message, actual_payload_length = actual_length ) def is_valid(self): """ - Validate websocket frame invariants, call at anytime to ensure the - WebSocketsFrame has not been corrupted. + Validate websocket frame invariants, call at anytime to ensure the + Frame has not been corrupted. """ try: assert 0 <= self.fin <= 1 @@ -147,17 +147,18 @@ class WebSocketsFrame(object): def human_readable(self): return "\n".join([ - ("fin - " + str(self.fin)), - ("rsv1 - " + str(self.rsv1)), - ("rsv2 - " + str(self.rsv2)), - ("rsv3 - " + str(self.rsv3)), - ("opcode - " + str(self.opcode)), - ("mask_bit - " + str(self.mask_bit)), - ("payload_length_code - " + str(self.payload_length_code)), - ("masking_key - " + str(self.masking_key)), - ("payload - " + str(self.payload)), - ("decoded_payload - " + str(self.decoded_payload)), - ("actual_payload_length - " + str(self.actual_payload_length))]) + ("fin - " + str(self.fin)), + ("rsv1 - " + str(self.rsv1)), + ("rsv2 - " + str(self.rsv2)), + ("rsv3 - " + str(self.rsv3)), + ("opcode - " + str(self.opcode)), + ("mask_bit - " + str(self.mask_bit)), + ("payload_length_code - " + str(self.payload_length_code)), + ("masking_key - " + str(self.masking_key)), + ("payload - " + str(self.payload)), + ("decoded_payload - " + str(self.decoded_payload)), + ("actual_payload_length - " + str(self.actual_payload_length)) + ]) def safe_to_bytes(self): if self.is_valid(): @@ -167,11 +168,10 @@ class WebSocketsFrame(object): def to_bytes(self): """ - Serialize the frame back into the wire format, returns a bytestring If - you haven't checked is_valid_frame() then there's no guarentees that - the serialized bytes will be correct. see safe_to_bytes() + Serialize the frame back into the wire format, returns a bytestring + If you haven't checked is_valid_frame() then there's no guarentees + that the serialized bytes will be correct. see safe_to_bytes() """ - max_16_bit_int = (1 << 16) max_64_bit_int = (1 << 63) @@ -199,13 +199,10 @@ class WebSocketsFrame(object): if self.actual_payload_length < 126: pass - elif self.actual_payload_length < max_16_bit_int: - # '!H' pack as 16 bit unsigned short # add 2 byte extended payload length bytes += struct.pack('!H', self.actual_payload_length) - elif self.actual_payload_length < max_64_bit_int: # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length @@ -215,7 +212,6 @@ class WebSocketsFrame(object): bytes += self.masking_key bytes += self.payload # already will be encoded if neccessary - return bytes @classmethod @@ -264,29 +260,31 @@ class WebSocketsFrame(object): decoded_payload = payload return cls( - fin = fin, - opcode = opcode, - mask_bit = mask_bit, - payload_length_code = payload_length, - payload = payload, - masking_key = masking_key, - decoded_payload = decoded_payload, + fin = fin, + opcode = opcode, + mask_bit = mask_bit, + payload_length_code = payload_length, + payload = payload, + masking_key = masking_key, + decoded_payload = decoded_payload, actual_payload_length = actual_payload_length ) def __eq__(self, other): return ( - self.fin == other.fin and - self.rsv1 == other.rsv1 and - self.rsv2 == other.rsv2 and - self.rsv3 == other.rsv3 and - self.opcode == other.opcode and - self.mask_bit == other.mask_bit and - self.payload_length_code == other.payload_length_code and - self.masking_key == other.masking_key and - self.payload == other.payload and - self.decoded_payload == other.decoded_payload and - self.actual_payload_length == other.actual_payload_length) + self.fin == other.fin and + self.rsv1 == other.rsv1 and + self.rsv2 == other.rsv2 and + self.rsv3 == other.rsv3 and + self.opcode == other.opcode and + self.mask_bit == other.mask_bit and + self.payload_length_code == other.payload_length_code and + self.masking_key == other.masking_key and + self.payload == other.payload and + self.decoded_payload == other.decoded_payload and + self.actual_payload_length == other.actual_payload_length + ) + def apply_mask(message, masking_key): """ diff --git a/test/test_websockets.py b/test/test_websockets.py index 951aa41f..d1753638 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -5,6 +5,7 @@ from netlib.websockets import websockets as ws import os from nose.tools import raises + class TestWebSockets(test.ServerTestBase): handler = impl.WebSocketsEchoHandler @@ -22,9 +23,12 @@ class TestWebSockets(test.ServerTestBase): self.echo("hello I'm the client") def test_frame_sizes(self): - small_msg = self.random_bytes(100) # length can fit in the the 7 bit payload length - medium_msg = self.random_bytes(50000) # 50kb, sligthly larger than can fit in a 7 bit int - large_msg = self.random_bytes(150000) # 150kb, slightly larger than can fit in a 16 bit int + # length can fit in the the 7 bit payload length + small_msg = self.random_bytes(100) + # 50kb, sligthly larger than can fit in a 7 bit int + medium_msg = self.random_bytes(50000) + # 150kb, slightly larger than can fit in a 16 bit int + large_msg = self.random_bytes(150000) self.echo(small_msg) self.echo(medium_msg) @@ -33,51 +37,54 @@ class TestWebSockets(test.ServerTestBase): def test_default_builder(self): """ default builder should always generate valid frames - """ + """ msg = self.random_bytes() - client_frame = ws.WebSocketsFrame.default(msg, from_client = True) + client_frame = ws.Frame.default(msg, from_client = True) assert client_frame.is_valid() - server_frame = ws.WebSocketsFrame.default(msg, from_client = False) + server_frame = ws.Frame.default(msg, from_client = False) assert server_frame.is_valid() def test_serialization_bijection(self): """ - Ensure that various frame types can be serialized/deserialized back and forth - between to_bytes() and from_bytes() - """ + Ensure that various frame types can be serialized/deserialized back + and forth between to_bytes() and from_bytes() + """ for is_client in [True, False]: - for num_bytes in [100, 50000, 150000]: - frame = ws.WebSocketsFrame.default(self.random_bytes(num_bytes), is_client) - assert frame == ws.WebSocketsFrame.from_bytes(frame.to_bytes()) + for num_bytes in [100, 50000, 150000]: + frame = ws.Frame.default( + self.random_bytes(num_bytes), is_client + ) + assert frame == ws.Frame.from_bytes(frame.to_bytes()) bytes = b'\x81\x11cba' - assert ws.WebSocketsFrame.from_bytes(bytes).to_bytes() == bytes + assert ws.Frame.from_bytes(bytes).to_bytes() == bytes @raises(ws.WebSocketFrameValidationException) def test_safe_to_bytes(self): - frame = ws.WebSocketsFrame.default(self.random_bytes(8)) - frame.actual_payload_length = 1 #corrupt the frame + frame = ws.Frame.default(self.random_bytes(8)) + frame.actual_payload_length = 1 # corrupt the frame frame.safe_to_bytes() class BadHandshakeHandler(impl.WebSocketsEchoHandler): def handshake(self): client_hs = ws.read_handshake(self.rfile.read, 1) - key = ws.process_handshake_from_client(client_hs) - response = ws.create_server_handshake("malformed_key") + ws.process_handshake_from_client(client_hs) + response = ws.create_server_handshake("malformed_key") self.wfile.write(response) self.wfile.flush() self.handshake_done = True + class TestBadHandshake(test.ServerTestBase): """ Ensure that the client disconnects if the server handshake is malformed - """ + """ handler = BadHandshakeHandler @raises(tcp.NetLibDisconnect) def test(self): client = impl.WebSocketsClient(("127.0.0.1", self.port)) client.connect() - client.send_message("hello") \ No newline at end of file + client.send_message("hello") -- cgit v1.2.3 From 0c2ad1edb1af013576f4ac31e05b308ffb440116 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 17 Apr 2015 16:29:09 +0200 Subject: fix socket_close on Windows, refs mitmproxy/mitmproxy#527 --- netlib/tcp.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 10269aa4..84008e2c 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -247,24 +247,32 @@ def close_socket(sock): """ try: # We already indicate that we close our end. - # If we close RD, any further received bytes would result in a RST being set, which we want to avoid - # for our purposes sock.shutdown(socket.SHUT_WR) # may raise "Transport endpoint is not connected" on Linux # Section 4.2.2.13 of RFC 1122 tells us that a close() with any # pending readable data could lead to an immediate RST being sent (which is the case on Windows). # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html # - # However, we cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above: - # Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that - # recv() would block infinitely. - # As a workaround, we set a timeout here even if we are in blocking mode. - # Please let us know if you have a better solution to this problem. - - sock.settimeout(sock.gettimeout() or 20) - # may raise a timeout/disconnect exception. - while sock.recv(4096): # pragma: no cover - pass + # This in turn results in the following issue: If we send an error page to the client and then close the socket, + # the RST may be received by the client before the error page and the users sees a connection error rather than + # the error page. Thus, we try to empty the read buffer on Windows first. + # (see https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988) + # + if os.name == "nt": # pragma: no cover + # We cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above: + # Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that + # recv() would block infinitely. + # As a workaround, we set a timeout here even if we are in blocking mode. + sock.settimeout(sock.gettimeout() or 20) + + # limit at a megabyte so that we don't read infinitely + for _ in xrange(1024 ** 3 // 4096): + # may raise a timeout/disconnect exception. + if not sock.recv(4096): + break + + # Now we can close the other half as well. + sock.shutdown(socket.SHUT_RD) except socket.error: pass -- cgit v1.2.3 From 74389ef04a3fdda4d388acb6d655adde78fccd7d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 20 Apr 2015 09:38:09 +1200 Subject: Websockets: reorganise - websockets.py to top-level - implementations into test suite --- netlib/websockets.py | 410 +++++++++++++++++++++++++++++++++++ netlib/websockets/__init__.py | 1 - netlib/websockets/implementations.py | 80 ------- netlib/websockets/websockets.py | 410 ----------------------------------- test/test_websockets.py | 105 +++++++-- 5 files changed, 499 insertions(+), 507 deletions(-) create mode 100644 netlib/websockets.py delete mode 100644 netlib/websockets/__init__.py delete mode 100644 netlib/websockets/implementations.py delete mode 100644 netlib/websockets/websockets.py diff --git a/netlib/websockets.py b/netlib/websockets.py new file mode 100644 index 00000000..83e90238 --- /dev/null +++ b/netlib/websockets.py @@ -0,0 +1,410 @@ +from __future__ import absolute_import + +import base64 +import hashlib +import mimetools +import StringIO +import os +import struct +import io + +from . import utils + +# Colleciton of utility functions that implement small portions of the RFC6455 +# WebSockets Protocol Useful for building WebSocket clients and servers. +# +# Emphassis is on readabilty, simplicity and modularity, not performance or +# completeness +# +# This is a work in progress and does not yet contain all the utilites need to +# create fully complient client/servers # +# Spec: https://tools.ietf.org/html/rfc6455 + +# The magic sha that websocket servers must know to prove they understand +# RFC6455 +websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' + + +class WebSocketFrameValidationException(Exception): + pass + + +class Frame(object): + """ + Represents one websockets frame. + Constructor takes human readable forms of the frame components + from_bytes() is also avaliable. + + WebSockets Frame as defined in RFC6455 + + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-------+-+-------------+-------------------------------+ + |F|R|R|R| opcode|M| Payload len | Extended payload length | + |I|S|S|S| (4) |A| (7) | (16/64) | + |N|V|V|V| |S| | (if payload len==126/127) | + | |1|2|3| |K| | | + +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + | Extended payload length continued, if payload len == 127 | + + - - - - - - - - - - - - - - - +-------------------------------+ + | |Masking-key, if MASK set to 1 | + +-------------------------------+-------------------------------+ + | Masking-key (continued) | Payload Data | + +-------------------------------- - - - - - - - - - - - - - - - + + : Payload Data continued ... : + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + | Payload Data continued ... | + +---------------------------------------------------------------+ + """ + def __init__( + self, + fin, # decmial integer 1 or 0 + opcode, # decmial integer 1 - 4 + mask_bit, # decimal integer 1 or 0 + payload_length_code, # decimal integer 1 - 127 + decoded_payload, # bytestring + rsv1 = 0, # decimal integer 1 or 0 + rsv2 = 0, # decimal integer 1 or 0 + rsv3 = 0, # decimal integer 1 or 0 + payload = None, # bytestring + masking_key = None, # 32 bit byte string + actual_payload_length = None, # any decimal integer + ): + self.fin = fin + self.rsv1 = rsv1 + self.rsv2 = rsv2 + self.rsv3 = rsv3 + self.opcode = opcode + self.mask_bit = mask_bit + self.payload_length_code = payload_length_code + self.masking_key = masking_key + self.payload = payload + self.decoded_payload = decoded_payload + self.actual_payload_length = actual_payload_length + + @classmethod + def from_bytes(cls, bytestring): + """ + Construct a websocket frame from an in-memory bytestring to construct + a frame from a stream of bytes, use from_byte_stream() directly + """ + return cls.from_byte_stream(io.BytesIO(bytestring).read) + + @classmethod + def default(cls, message, from_client = False): + """ + Construct a basic websocket frame from some default values. + Creates a non-fragmented text frame. + """ + length_code, actual_length = get_payload_length_pair(message) + + if from_client: + mask_bit = 1 + masking_key = random_masking_key() + payload = apply_mask(message, masking_key) + else: + mask_bit = 0 + masking_key = None + payload = message + + return cls( + fin = 1, # final frame + opcode = 1, # text + mask_bit = mask_bit, + payload_length_code = length_code, + payload = payload, + masking_key = masking_key, + decoded_payload = message, + actual_payload_length = actual_length + ) + + def is_valid(self): + """ + Validate websocket frame invariants, call at anytime to ensure the + Frame has not been corrupted. + """ + try: + assert 0 <= self.fin <= 1 + assert 0 <= self.rsv1 <= 1 + assert 0 <= self.rsv2 <= 1 + assert 0 <= self.rsv3 <= 1 + assert 1 <= self.opcode <= 4 + assert 0 <= self.mask_bit <= 1 + assert 1 <= self.payload_length_code <= 127 + + if self.mask_bit == 1: + assert 1 <= len(self.masking_key) <= 4 + else: + assert self.masking_key is None + + assert self.actual_payload_length == len(self.payload) + + if self.payload is not None and self.masking_key is not None: + assert apply_mask(self.payload, self.masking_key) == self.decoded_payload + + return True + except AssertionError: + return False + + def human_readable(self): + return "\n".join([ + ("fin - " + str(self.fin)), + ("rsv1 - " + str(self.rsv1)), + ("rsv2 - " + str(self.rsv2)), + ("rsv3 - " + str(self.rsv3)), + ("opcode - " + str(self.opcode)), + ("mask_bit - " + str(self.mask_bit)), + ("payload_length_code - " + str(self.payload_length_code)), + ("masking_key - " + str(self.masking_key)), + ("payload - " + str(self.payload)), + ("decoded_payload - " + str(self.decoded_payload)), + ("actual_payload_length - " + str(self.actual_payload_length)) + ]) + + def safe_to_bytes(self): + if self.is_valid(): + return self.to_bytes() + else: + raise WebSocketFrameValidationException() + + def to_bytes(self): + """ + Serialize the frame back into the wire format, returns a bytestring + If you haven't checked is_valid_frame() then there's no guarentees + that the serialized bytes will be correct. see safe_to_bytes() + """ + max_16_bit_int = (1 << 16) + max_64_bit_int = (1 << 63) + + # break down of the bit-math used to construct the first byte from the + # frame's integer values first shift the significant bit into the + # correct position + # 00000001 << 7 = 10000000 + # ... + # then combine: + # + # 10000000 fin + # 01000000 res1 + # 00100000 res2 + # 00010000 res3 + # 00000001 opcode + # -------- OR + # 11110001 = first_byte + + first_byte = (self.fin << 7) | (self.rsv1 << 6) |\ + (self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode + + second_byte = (self.mask_bit << 7) | self.payload_length_code + + bytes = chr(first_byte) + chr(second_byte) + + if self.actual_payload_length < 126: + pass + elif self.actual_payload_length < max_16_bit_int: + # '!H' pack as 16 bit unsigned short + # add 2 byte extended payload length + bytes += struct.pack('!H', self.actual_payload_length) + elif self.actual_payload_length < max_64_bit_int: + # '!Q' = pack as 64 bit unsigned long long + # add 8 bytes extended payload length + bytes += struct.pack('!Q', self.actual_payload_length) + + if self.masking_key is not None: + bytes += self.masking_key + + bytes += self.payload # already will be encoded if neccessary + return bytes + + @classmethod + def from_byte_stream(cls, read_bytes): + """ + read a websockets frame sent by a server or client + + read_bytes is a function that can be backed + by sockets or by any byte reader. So this + function may be used to read frames from disk/wire/memory + """ + first_byte = utils.bytes_to_int(read_bytes(1)) + second_byte = utils.bytes_to_int(read_bytes(1)) + + # grab the left most bit + fin = first_byte >> 7 + # grab right most 4 bits by and-ing with 00001111 + opcode = first_byte & 15 + # grab left most bit + mask_bit = second_byte >> 7 + # grab the next 7 bits + payload_length = second_byte & 127 + + # payload_lengthy > 125 indicates you need to read more bytes + # to get the actual payload length + if payload_length <= 125: + actual_payload_length = payload_length + + elif payload_length == 126: + actual_payload_length = utils.bytes_to_int(read_bytes(2)) + + elif payload_length == 127: + actual_payload_length = utils.bytes_to_int(read_bytes(8)) + + # masking key only present if mask bit set + if mask_bit == 1: + masking_key = read_bytes(4) + else: + masking_key = None + + payload = read_bytes(actual_payload_length) + + if mask_bit == 1: + decoded_payload = apply_mask(payload, masking_key) + else: + decoded_payload = payload + + return cls( + fin = fin, + opcode = opcode, + mask_bit = mask_bit, + payload_length_code = payload_length, + payload = payload, + masking_key = masking_key, + decoded_payload = decoded_payload, + actual_payload_length = actual_payload_length + ) + + def __eq__(self, other): + return ( + self.fin == other.fin and + self.rsv1 == other.rsv1 and + self.rsv2 == other.rsv2 and + self.rsv3 == other.rsv3 and + self.opcode == other.opcode and + self.mask_bit == other.mask_bit and + self.payload_length_code == other.payload_length_code and + self.masking_key == other.masking_key and + self.payload == other.payload and + self.decoded_payload == other.decoded_payload and + self.actual_payload_length == other.actual_payload_length + ) + + +def apply_mask(message, masking_key): + """ + Data sent from the server must be masked to prevent malicious clients + from sending data over the wire in predictable patterns + + This method both encodes and decodes strings with the provided mask + + Servers do not have to mask data they send to the client. + https://tools.ietf.org/html/rfc6455#section-5.3 + """ + masks = [utils.bytes_to_int(byte) for byte in masking_key] + result = "" + for char in message: + result += chr(ord(char) ^ masks[len(result) % 4]) + return result + + +def random_masking_key(): + return os.urandom(4) + + +def create_client_handshake(host, port, key, version, resource): + """ + WebSockets connections are intiated by the client with a valid HTTP + upgrade request + """ + headers = [ + ('Host', '%s:%s' % (host, port)), + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Key', key), + ('Sec-WebSocket-Version', version) + ] + request = "GET %s HTTP/1.1" % resource + return build_handshake(headers, request) + + +def create_server_handshake(key): + """ + The server response is a valid HTTP 101 response. + """ + headers = [ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Accept', create_server_nounce(key)) + ] + request = "HTTP/1.1 101 Switching Protocols" + return build_handshake(headers, request) + + +def build_handshake(headers, request): + handshake = [request.encode('utf-8')] + for header, value in headers: + handshake.append(("%s: %s" % (header, value)).encode('utf-8')) + handshake.append(b'\r\n') + return b'\r\n'.join(handshake) + + +def read_handshake(read_bytes, num_bytes_per_read): + """ + From provided function that reads bytes, read in a + complete HTTP request, which terminates with a CLRF + """ + response = b'' + doubleCLRF = b'\r\n\r\n' + while True: + bytes = read_bytes(num_bytes_per_read) + if not bytes: + break + response += bytes + if doubleCLRF in response: + break + return response + + +def get_payload_length_pair(payload_bytestring): + """ + A websockets frame contains an initial length_code, and an optional + extended length code to represent the actual length if length code is + larger than 125 + """ + actual_length = len(payload_bytestring) + + if actual_length <= 125: + length_code = actual_length + elif actual_length >= 126 and actual_length <= 65535: + length_code = 126 + else: + length_code = 127 + return (length_code, actual_length) + + +def process_handshake_from_client(handshake): + headers = headers_from_http_message(handshake) + if headers.get("Upgrade", None) != "websocket": + return + key = headers['Sec-WebSocket-Key'] + return key + + +def process_handshake_from_server(handshake, client_nounce): + headers = headers_from_http_message(handshake) + if headers.get("Upgrade", None) != "websocket": + return + key = headers['Sec-WebSocket-Accept'] + return key + + +def headers_from_http_message(http_message): + return mimetools.Message( + StringIO.StringIO(http_message.split('\r\n', 1)[1]) + ) + + +def create_server_nounce(client_nounce): + return base64.b64encode( + hashlib.sha1(client_nounce + websockets_magic).hexdigest().decode('hex') + ) + + +def create_client_nounce(): + return base64.b64encode(os.urandom(16)).decode('utf-8') diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py deleted file mode 100644 index 9b4faa33..00000000 --- a/netlib/websockets/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from __future__ import (absolute_import, print_function, division) diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py deleted file mode 100644 index 337c5496..00000000 --- a/netlib/websockets/implementations.py +++ /dev/null @@ -1,80 +0,0 @@ -from netlib import tcp -from base64 import b64encode -from StringIO import StringIO -from . import websockets as ws -import struct -import SocketServer -import os - -# Simple websocket client and servers that are used to exercise the functionality in websockets.py -# These are *not* fully RFC6455 compliant - -class WebSocketsEchoHandler(tcp.BaseHandler): - def __init__(self, connection, address, server): - super(WebSocketsEchoHandler, self).__init__(connection, address, server) - self.handshake_done = False - - def handle(self): - while True: - if not self.handshake_done: - self.handshake() - else: - self.read_next_message() - - def read_next_message(self): - decoded = ws.Frame.from_byte_stream(self.rfile.read).decoded_payload - self.on_message(decoded) - - def send_message(self, message): - frame = ws.Frame.default(message, from_client = False) - self.wfile.write(frame.safe_to_bytes()) - self.wfile.flush() - - def handshake(self): - client_hs = ws.read_handshake(self.rfile.read, 1) - key = ws.process_handshake_from_client(client_hs) - response = ws.create_server_handshake(key) - self.wfile.write(response) - self.wfile.flush() - self.handshake_done = True - - def on_message(self, message): - if message is not None: - self.send_message(message) - - -class WebSocketsClient(tcp.TCPClient): - def __init__(self, address, source_address=None): - super(WebSocketsClient, self).__init__(address, source_address) - self.version = "13" - self.client_nounce = ws.create_client_nounce() - self.resource = "/" - - def connect(self): - super(WebSocketsClient, self).connect() - - handshake = ws.create_client_handshake( - self.address.host, - self.address.port, - self.client_nounce, - self.version, - self.resource - ) - - self.wfile.write(handshake) - self.wfile.flush() - - server_handshake = ws.read_handshake(self.rfile.read, 1) - - server_nounce = ws.process_handshake_from_server(server_handshake, self.client_nounce) - - if not server_nounce == ws.create_server_nounce(self.client_nounce): - self.close() - - def read_next_message(self): - return ws.Frame.from_byte_stream(self.rfile.read).payload - - def send_message(self, message): - frame = ws.Frame.default(message, from_client = True) - self.wfile.write(frame.safe_to_bytes()) - self.wfile.flush() diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py deleted file mode 100644 index 86d98caf..00000000 --- a/netlib/websockets/websockets.py +++ /dev/null @@ -1,410 +0,0 @@ -from __future__ import absolute_import - -import base64 -import hashlib -import mimetools -import StringIO -import os -import struct -import io - -from .. import utils - -# Colleciton of utility functions that implement small portions of the RFC6455 -# WebSockets Protocol Useful for building WebSocket clients and servers. -# -# Emphassis is on readabilty, simplicity and modularity, not performance or -# completeness -# -# This is a work in progress and does not yet contain all the utilites need to -# create fully complient client/servers # -# Spec: https://tools.ietf.org/html/rfc6455 - -# The magic sha that websocket servers must know to prove they understand -# RFC6455 -websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' - - -class WebSocketFrameValidationException(Exception): - pass - - -class Frame(object): - """ - Represents one websockets frame. - Constructor takes human readable forms of the frame components - from_bytes() is also avaliable. - - WebSockets Frame as defined in RFC6455 - - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-------+-+-------------+-------------------------------+ - |F|R|R|R| opcode|M| Payload len | Extended payload length | - |I|S|S|S| (4) |A| (7) | (16/64) | - |N|V|V|V| |S| | (if payload len==126/127) | - | |1|2|3| |K| | | - +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + - | Extended payload length continued, if payload len == 127 | - + - - - - - - - - - - - - - - - +-------------------------------+ - | |Masking-key, if MASK set to 1 | - +-------------------------------+-------------------------------+ - | Masking-key (continued) | Payload Data | - +-------------------------------- - - - - - - - - - - - - - - - + - : Payload Data continued ... : - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + - | Payload Data continued ... | - +---------------------------------------------------------------+ - """ - def __init__( - self, - fin, # decmial integer 1 or 0 - opcode, # decmial integer 1 - 4 - mask_bit, # decimal integer 1 or 0 - payload_length_code, # decimal integer 1 - 127 - decoded_payload, # bytestring - rsv1 = 0, # decimal integer 1 or 0 - rsv2 = 0, # decimal integer 1 or 0 - rsv3 = 0, # decimal integer 1 or 0 - payload = None, # bytestring - masking_key = None, # 32 bit byte string - actual_payload_length = None, # any decimal integer - ): - self.fin = fin - self.rsv1 = rsv1 - self.rsv2 = rsv2 - self.rsv3 = rsv3 - self.opcode = opcode - self.mask_bit = mask_bit - self.payload_length_code = payload_length_code - self.masking_key = masking_key - self.payload = payload - self.decoded_payload = decoded_payload - self.actual_payload_length = actual_payload_length - - @classmethod - def from_bytes(cls, bytestring): - """ - Construct a websocket frame from an in-memory bytestring to construct - a frame from a stream of bytes, use from_byte_stream() directly - """ - return cls.from_byte_stream(io.BytesIO(bytestring).read) - - @classmethod - def default(cls, message, from_client = False): - """ - Construct a basic websocket frame from some default values. - Creates a non-fragmented text frame. - """ - length_code, actual_length = get_payload_length_pair(message) - - if from_client: - mask_bit = 1 - masking_key = random_masking_key() - payload = apply_mask(message, masking_key) - else: - mask_bit = 0 - masking_key = None - payload = message - - return cls( - fin = 1, # final frame - opcode = 1, # text - mask_bit = mask_bit, - payload_length_code = length_code, - payload = payload, - masking_key = masking_key, - decoded_payload = message, - actual_payload_length = actual_length - ) - - def is_valid(self): - """ - Validate websocket frame invariants, call at anytime to ensure the - Frame has not been corrupted. - """ - try: - assert 0 <= self.fin <= 1 - assert 0 <= self.rsv1 <= 1 - assert 0 <= self.rsv2 <= 1 - assert 0 <= self.rsv3 <= 1 - assert 1 <= self.opcode <= 4 - assert 0 <= self.mask_bit <= 1 - assert 1 <= self.payload_length_code <= 127 - - if self.mask_bit == 1: - assert 1 <= len(self.masking_key) <= 4 - else: - assert self.masking_key is None - - assert self.actual_payload_length == len(self.payload) - - if self.payload is not None and self.masking_key is not None: - assert apply_mask(self.payload, self.masking_key) == self.decoded_payload - - return True - except AssertionError: - return False - - def human_readable(self): - return "\n".join([ - ("fin - " + str(self.fin)), - ("rsv1 - " + str(self.rsv1)), - ("rsv2 - " + str(self.rsv2)), - ("rsv3 - " + str(self.rsv3)), - ("opcode - " + str(self.opcode)), - ("mask_bit - " + str(self.mask_bit)), - ("payload_length_code - " + str(self.payload_length_code)), - ("masking_key - " + str(self.masking_key)), - ("payload - " + str(self.payload)), - ("decoded_payload - " + str(self.decoded_payload)), - ("actual_payload_length - " + str(self.actual_payload_length)) - ]) - - def safe_to_bytes(self): - if self.is_valid(): - return self.to_bytes() - else: - raise WebSocketFrameValidationException() - - def to_bytes(self): - """ - Serialize the frame back into the wire format, returns a bytestring - If you haven't checked is_valid_frame() then there's no guarentees - that the serialized bytes will be correct. see safe_to_bytes() - """ - max_16_bit_int = (1 << 16) - max_64_bit_int = (1 << 63) - - # break down of the bit-math used to construct the first byte from the - # frame's integer values first shift the significant bit into the - # correct position - # 00000001 << 7 = 10000000 - # ... - # then combine: - # - # 10000000 fin - # 01000000 res1 - # 00100000 res2 - # 00010000 res3 - # 00000001 opcode - # -------- OR - # 11110001 = first_byte - - first_byte = (self.fin << 7) | (self.rsv1 << 6) |\ - (self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode - - second_byte = (self.mask_bit << 7) | self.payload_length_code - - bytes = chr(first_byte) + chr(second_byte) - - if self.actual_payload_length < 126: - pass - elif self.actual_payload_length < max_16_bit_int: - # '!H' pack as 16 bit unsigned short - # add 2 byte extended payload length - bytes += struct.pack('!H', self.actual_payload_length) - elif self.actual_payload_length < max_64_bit_int: - # '!Q' = pack as 64 bit unsigned long long - # add 8 bytes extended payload length - bytes += struct.pack('!Q', self.actual_payload_length) - - if self.masking_key is not None: - bytes += self.masking_key - - bytes += self.payload # already will be encoded if neccessary - return bytes - - @classmethod - def from_byte_stream(cls, read_bytes): - """ - read a websockets frame sent by a server or client - - read_bytes is a function that can be backed - by sockets or by any byte reader. So this - function may be used to read frames from disk/wire/memory - """ - first_byte = utils.bytes_to_int(read_bytes(1)) - second_byte = utils.bytes_to_int(read_bytes(1)) - - # grab the left most bit - fin = first_byte >> 7 - # grab right most 4 bits by and-ing with 00001111 - opcode = first_byte & 15 - # grab left most bit - mask_bit = second_byte >> 7 - # grab the next 7 bits - payload_length = second_byte & 127 - - # payload_lengthy > 125 indicates you need to read more bytes - # to get the actual payload length - if payload_length <= 125: - actual_payload_length = payload_length - - elif payload_length == 126: - actual_payload_length = utils.bytes_to_int(read_bytes(2)) - - elif payload_length == 127: - actual_payload_length = utils.bytes_to_int(read_bytes(8)) - - # masking key only present if mask bit set - if mask_bit == 1: - masking_key = read_bytes(4) - else: - masking_key = None - - payload = read_bytes(actual_payload_length) - - if mask_bit == 1: - decoded_payload = apply_mask(payload, masking_key) - else: - decoded_payload = payload - - return cls( - fin = fin, - opcode = opcode, - mask_bit = mask_bit, - payload_length_code = payload_length, - payload = payload, - masking_key = masking_key, - decoded_payload = decoded_payload, - actual_payload_length = actual_payload_length - ) - - def __eq__(self, other): - return ( - self.fin == other.fin and - self.rsv1 == other.rsv1 and - self.rsv2 == other.rsv2 and - self.rsv3 == other.rsv3 and - self.opcode == other.opcode and - self.mask_bit == other.mask_bit and - self.payload_length_code == other.payload_length_code and - self.masking_key == other.masking_key and - self.payload == other.payload and - self.decoded_payload == other.decoded_payload and - self.actual_payload_length == other.actual_payload_length - ) - - -def apply_mask(message, masking_key): - """ - Data sent from the server must be masked to prevent malicious clients - from sending data over the wire in predictable patterns - - This method both encodes and decodes strings with the provided mask - - Servers do not have to mask data they send to the client. - https://tools.ietf.org/html/rfc6455#section-5.3 - """ - masks = [utils.bytes_to_int(byte) for byte in masking_key] - result = "" - for char in message: - result += chr(ord(char) ^ masks[len(result) % 4]) - return result - - -def random_masking_key(): - return os.urandom(4) - - -def create_client_handshake(host, port, key, version, resource): - """ - WebSockets connections are intiated by the client with a valid HTTP - upgrade request - """ - headers = [ - ('Host', '%s:%s' % (host, port)), - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - ('Sec-WebSocket-Key', key), - ('Sec-WebSocket-Version', version) - ] - request = "GET %s HTTP/1.1" % resource - return build_handshake(headers, request) - - -def create_server_handshake(key): - """ - The server response is a valid HTTP 101 response. - """ - headers = [ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - ('Sec-WebSocket-Accept', create_server_nounce(key)) - ] - request = "HTTP/1.1 101 Switching Protocols" - return build_handshake(headers, request) - - -def build_handshake(headers, request): - handshake = [request.encode('utf-8')] - for header, value in headers: - handshake.append(("%s: %s" % (header, value)).encode('utf-8')) - handshake.append(b'\r\n') - return b'\r\n'.join(handshake) - - -def read_handshake(read_bytes, num_bytes_per_read): - """ - From provided function that reads bytes, read in a - complete HTTP request, which terminates with a CLRF - """ - response = b'' - doubleCLRF = b'\r\n\r\n' - while True: - bytes = read_bytes(num_bytes_per_read) - if not bytes: - break - response += bytes - if doubleCLRF in response: - break - return response - - -def get_payload_length_pair(payload_bytestring): - """ - A websockets frame contains an initial length_code, and an optional - extended length code to represent the actual length if length code is - larger than 125 - """ - actual_length = len(payload_bytestring) - - if actual_length <= 125: - length_code = actual_length - elif actual_length >= 126 and actual_length <= 65535: - length_code = 126 - else: - length_code = 127 - return (length_code, actual_length) - - -def process_handshake_from_client(handshake): - headers = headers_from_http_message(handshake) - if headers.get("Upgrade", None) != "websocket": - return - key = headers['Sec-WebSocket-Key'] - return key - - -def process_handshake_from_server(handshake, client_nounce): - headers = headers_from_http_message(handshake) - if headers.get("Upgrade", None) != "websocket": - return - key = headers['Sec-WebSocket-Accept'] - return key - - -def headers_from_http_message(http_message): - return mimetools.Message( - StringIO.StringIO(http_message.split('\r\n', 1)[1]) - ) - - -def create_server_nounce(client_nounce): - return base64.b64encode( - hashlib.sha1(client_nounce + websockets_magic).hexdigest().decode('hex') - ) - - -def create_client_nounce(): - return base64.b64encode(os.urandom(16)).decode('utf-8') diff --git a/test/test_websockets.py b/test/test_websockets.py index d1753638..62268423 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -1,19 +1,92 @@ from netlib import tcp from netlib import test -from netlib.websockets import implementations as impl -from netlib.websockets import websockets as ws +from netlib import websockets import os from nose.tools import raises +class WebSocketsEchoHandler(tcp.BaseHandler): + def __init__(self, connection, address, server): + super(WebSocketsEchoHandler, self).__init__( + connection, address, server + ) + self.handshake_done = False + + def handle(self): + while True: + if not self.handshake_done: + self.handshake() + else: + self.read_next_message() + + def read_next_message(self): + decoded = websockets.Frame.from_byte_stream(self.rfile.read).decoded_payload + self.on_message(decoded) + + def send_message(self, message): + frame = websockets.Frame.default(message, from_client = False) + self.wfile.write(frame.safe_to_bytes()) + self.wfile.flush() + + def handshake(self): + client_hs = websockets.read_handshake(self.rfile.read, 1) + key = websockets.process_handshake_from_client(client_hs) + response = websockets.create_server_handshake(key) + self.wfile.write(response) + self.wfile.flush() + self.handshake_done = True + + def on_message(self, message): + if message is not None: + self.send_message(message) + + +class WebSocketsClient(tcp.TCPClient): + def __init__(self, address, source_address=None): + super(WebSocketsClient, self).__init__(address, source_address) + self.version = "13" + self.client_nounce = websockets.create_client_nounce() + self.resource = "/" + + def connect(self): + super(WebSocketsClient, self).connect() + + handshake = websockets.create_client_handshake( + self.address.host, + self.address.port, + self.client_nounce, + self.version, + self.resource + ) + + self.wfile.write(handshake) + self.wfile.flush() + + server_handshake = websockets.read_handshake(self.rfile.read, 1) + server_nounce = websockets.process_handshake_from_server( + server_handshake, self.client_nounce + ) + + if not server_nounce == websockets.create_server_nounce(self.client_nounce): + self.close() + + def read_next_message(self): + return websockets.Frame.from_byte_stream(self.rfile.read).payload + + def send_message(self, message): + frame = websockets.Frame.default(message, from_client = True) + self.wfile.write(frame.safe_to_bytes()) + self.wfile.flush() + + class TestWebSockets(test.ServerTestBase): - handler = impl.WebSocketsEchoHandler + handler = WebSocketsEchoHandler def random_bytes(self, n = 100): return os.urandom(n) def echo(self, msg): - client = impl.WebSocketsClient(("127.0.0.1", self.port)) + client = WebSocketsClient(("127.0.0.1", self.port)) client.connect() client.send_message(msg) response = client.read_next_message() @@ -39,10 +112,10 @@ class TestWebSockets(test.ServerTestBase): default builder should always generate valid frames """ msg = self.random_bytes() - client_frame = ws.Frame.default(msg, from_client = True) + client_frame = websockets.Frame.default(msg, from_client = True) assert client_frame.is_valid() - server_frame = ws.Frame.default(msg, from_client = False) + server_frame = websockets.Frame.default(msg, from_client = False) assert server_frame.is_valid() def test_serialization_bijection(self): @@ -52,26 +125,26 @@ class TestWebSockets(test.ServerTestBase): """ for is_client in [True, False]: for num_bytes in [100, 50000, 150000]: - frame = ws.Frame.default( + frame = websockets.Frame.default( self.random_bytes(num_bytes), is_client ) - assert frame == ws.Frame.from_bytes(frame.to_bytes()) + assert frame == websockets.Frame.from_bytes(frame.to_bytes()) bytes = b'\x81\x11cba' - assert ws.Frame.from_bytes(bytes).to_bytes() == bytes + assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes - @raises(ws.WebSocketFrameValidationException) + @raises(websockets.WebSocketFrameValidationException) def test_safe_to_bytes(self): - frame = ws.Frame.default(self.random_bytes(8)) + frame = websockets.Frame.default(self.random_bytes(8)) frame.actual_payload_length = 1 # corrupt the frame frame.safe_to_bytes() -class BadHandshakeHandler(impl.WebSocketsEchoHandler): +class BadHandshakeHandler(WebSocketsEchoHandler): def handshake(self): - client_hs = ws.read_handshake(self.rfile.read, 1) - ws.process_handshake_from_client(client_hs) - response = ws.create_server_handshake("malformed_key") + client_hs = websockets.read_handshake(self.rfile.read, 1) + websockets.process_handshake_from_client(client_hs) + response = websockets.create_server_handshake("malformed_key") self.wfile.write(response) self.wfile.flush() self.handshake_done = True @@ -85,6 +158,6 @@ class TestBadHandshake(test.ServerTestBase): @raises(tcp.NetLibDisconnect) def test(self): - client = impl.WebSocketsClient(("127.0.0.1", self.port)) + client = WebSocketsClient(("127.0.0.1", self.port)) client.connect() client.send_message("hello") -- cgit v1.2.3 From 4ea1ccb638366fbdac2d294c23ce8052dcf250c2 Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Sun, 19 Apr 2015 22:18:30 -0700 Subject: fixing test coverage, adding to_file/from_file reader writes to match socks.py --- netlib/websockets.py | 62 +++++++++++++++++++++++++++---------------------- pathod | 1 + test/test_websockets.py | 61 ++++++++++++++++++++++++++++++++++-------------- 3 files changed, 79 insertions(+), 45 deletions(-) create mode 160000 pathod diff --git a/netlib/websockets.py b/netlib/websockets.py index 83e90238..5b9d8fbd 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -25,6 +25,11 @@ from . import utils websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +class CONST(object): + MAX_16_BIT_INT = (1 << 16) + MAX_64_BIT_INT = (1 << 64) + + class WebSocketFrameValidationException(Exception): pass @@ -81,14 +86,6 @@ class Frame(object): self.decoded_payload = decoded_payload self.actual_payload_length = actual_payload_length - @classmethod - def from_bytes(cls, bytestring): - """ - Construct a websocket frame from an in-memory bytestring to construct - a frame from a stream of bytes, use from_byte_stream() directly - """ - return cls.from_byte_stream(io.BytesIO(bytestring).read) - @classmethod def default(cls, message, from_client = False): """ @@ -145,7 +142,7 @@ class Frame(object): except AssertionError: return False - def human_readable(self): + def human_readable(self): # pragma: nocover return "\n".join([ ("fin - " + str(self.fin)), ("rsv1 - " + str(self.rsv1)), @@ -160,6 +157,14 @@ class Frame(object): ("actual_payload_length - " + str(self.actual_payload_length)) ]) + @classmethod + def from_bytes(cls, bytestring): + """ + Construct a websocket frame from an in-memory bytestring + to construct a frame from a stream of bytes, use from_file() directly + """ + return cls.from_file(io.BytesIO(bytestring)) + def safe_to_bytes(self): if self.is_valid(): return self.to_bytes() @@ -172,8 +177,6 @@ class Frame(object): If you haven't checked is_valid_frame() then there's no guarentees that the serialized bytes will be correct. see safe_to_bytes() """ - max_16_bit_int = (1 << 16) - max_64_bit_int = (1 << 63) # break down of the bit-math used to construct the first byte from the # frame's integer values first shift the significant bit into the @@ -199,11 +202,11 @@ class Frame(object): if self.actual_payload_length < 126: pass - elif self.actual_payload_length < max_16_bit_int: + elif self.actual_payload_length < CONST.MAX_16_BIT_INT: # '!H' pack as 16 bit unsigned short # add 2 byte extended payload length bytes += struct.pack('!H', self.actual_payload_length) - elif self.actual_payload_length < max_64_bit_int: + elif self.actual_payload_length < CONST.MAX_64_BIT_INT: # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length bytes += struct.pack('!Q', self.actual_payload_length) @@ -214,17 +217,20 @@ class Frame(object): bytes += self.payload # already will be encoded if neccessary return bytes + def to_file(self, writer): + writer.write(self.to_bytes()) + writer.flush() + @classmethod - def from_byte_stream(cls, read_bytes): + def from_file(cls, reader): """ read a websockets frame sent by a server or client - - read_bytes is a function that can be backed - by sockets or by any byte reader. So this - function may be used to read frames from disk/wire/memory - """ - first_byte = utils.bytes_to_int(read_bytes(1)) - second_byte = utils.bytes_to_int(read_bytes(1)) + + reader is a "file like" object that could be backed by a network stream or a disk + or an in memory stream reader + """ + first_byte = utils.bytes_to_int(reader.read(1)) + second_byte = utils.bytes_to_int(reader.read(1)) # grab the left most bit fin = first_byte >> 7 @@ -241,18 +247,18 @@ class Frame(object): actual_payload_length = payload_length elif payload_length == 126: - actual_payload_length = utils.bytes_to_int(read_bytes(2)) + actual_payload_length = utils.bytes_to_int(reader.read(2)) elif payload_length == 127: - actual_payload_length = utils.bytes_to_int(read_bytes(8)) + actual_payload_length = utils.bytes_to_int(reader.read(8)) # masking key only present if mask bit set if mask_bit == 1: - masking_key = read_bytes(4) + masking_key = reader.read(4) else: masking_key = None - payload = read_bytes(actual_payload_length) + payload = reader.read(actual_payload_length) if mask_bit == 1: decoded_payload = apply_mask(payload, masking_key) @@ -344,7 +350,7 @@ def build_handshake(headers, request): return b'\r\n'.join(handshake) -def read_handshake(read_bytes, num_bytes_per_read): +def read_handshake(reader, num_bytes_per_read): """ From provided function that reads bytes, read in a complete HTTP request, which terminates with a CLRF @@ -352,7 +358,7 @@ def read_handshake(read_bytes, num_bytes_per_read): response = b'' doubleCLRF = b'\r\n\r\n' while True: - bytes = read_bytes(num_bytes_per_read) + bytes = reader.read(num_bytes_per_read) if not bytes: break response += bytes @@ -386,7 +392,7 @@ def process_handshake_from_client(handshake): return key -def process_handshake_from_server(handshake, client_nounce): +def process_handshake_from_server(handshake): headers = headers_from_http_message(handshake) if headers.get("Upgrade", None) != "websocket": return diff --git a/pathod b/pathod new file mode 160000 index 00000000..be450cf9 --- /dev/null +++ b/pathod @@ -0,0 +1 @@ +Subproject commit be450cf9db1d819b1023029c8d403f401e010c98 diff --git a/test/test_websockets.py b/test/test_websockets.py index 62268423..34692183 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -1,6 +1,7 @@ from netlib import tcp from netlib import test from netlib import websockets +import io import os from nose.tools import raises @@ -20,16 +21,15 @@ class WebSocketsEchoHandler(tcp.BaseHandler): self.read_next_message() def read_next_message(self): - decoded = websockets.Frame.from_byte_stream(self.rfile.read).decoded_payload + decoded = websockets.Frame.from_file(self.rfile).decoded_payload self.on_message(decoded) def send_message(self, message): frame = websockets.Frame.default(message, from_client = False) - self.wfile.write(frame.safe_to_bytes()) - self.wfile.flush() - + frame.to_file(self.wfile) + def handshake(self): - client_hs = websockets.read_handshake(self.rfile.read, 1) + client_hs = websockets.read_handshake(self.rfile, 1) key = websockets.process_handshake_from_client(client_hs) response = websockets.create_server_handshake(key) self.wfile.write(response) @@ -62,22 +62,18 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.write(handshake) self.wfile.flush() - server_handshake = websockets.read_handshake(self.rfile.read, 1) - server_nounce = websockets.process_handshake_from_server( - server_handshake, self.client_nounce - ) + server_handshake = websockets.read_handshake(self.rfile, 1) + server_nounce = websockets.process_handshake_from_server(server_handshake) if not server_nounce == websockets.create_server_nounce(self.client_nounce): self.close() def read_next_message(self): - return websockets.Frame.from_byte_stream(self.rfile.read).payload + return websockets.Frame.from_file(self.rfile).payload def send_message(self, message): frame = websockets.Frame.default(message, from_client = True) - self.wfile.write(frame.safe_to_bytes()) - self.wfile.flush() - + frame.to_file(self.wfile) class TestWebSockets(test.ServerTestBase): handler = WebSocketsEchoHandler @@ -128,10 +124,10 @@ class TestWebSockets(test.ServerTestBase): frame = websockets.Frame.default( self.random_bytes(num_bytes), is_client ) - assert frame == websockets.Frame.from_bytes(frame.to_bytes()) + assert frame == websockets.Frame.from_bytes(frame.safe_to_bytes()) - bytes = b'\x81\x11cba' - assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes + bytes = b'\x81\x03cba' + assert websockets.Frame.from_bytes(bytes).safe_to_bytes() == bytes @raises(websockets.WebSocketFrameValidationException) def test_safe_to_bytes(self): @@ -139,10 +135,41 @@ class TestWebSockets(test.ServerTestBase): frame.actual_payload_length = 1 # corrupt the frame frame.safe_to_bytes() + def test_handshake(self): + bad_upgrade = "not_websockets" + bad_header_handshake = websockets.build_handshake([ + ('Host', '%s:%s' % ("a", "b")), + ('Connection', "c"), + ('Upgrade', bad_upgrade), + ('Sec-WebSocket-Key', "d"), + ('Sec-WebSocket-Version', "e") + ], "f") + + # check behavior when required header values are missing + assert None == websockets.process_handshake_from_server(bad_header_handshake) + assert None == websockets.process_handshake_from_client(bad_header_handshake) + + key = "test_key" + + client_handshake = websockets.create_client_handshake("a","b",key,"d","e") + assert key == websockets.process_handshake_from_client(client_handshake) + + server_handshake = websockets.create_server_handshake(key) + assert websockets.create_server_nounce(key) == websockets.process_handshake_from_server(server_handshake) + + handshake = websockets.create_client_handshake("a","b","c","d","e") + stream = io.BytesIO(handshake) + assert handshake == websockets.read_handshake(stream, 1) + + # ensure readhandshake doesn't loop forever on empty stream + empty_stream = io.BytesIO("") + assert "" == websockets.read_handshake(empty_stream, 1) + + class BadHandshakeHandler(WebSocketsEchoHandler): def handshake(self): - client_hs = websockets.read_handshake(self.rfile.read, 1) + client_hs = websockets.read_handshake(self.rfile, 1) websockets.process_handshake_from_client(client_hs) response = websockets.create_server_handshake("malformed_key") self.wfile.write(response) -- cgit v1.2.3 From fae964d3157eeaa471392d5ba53615925729411a Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Sun, 19 Apr 2015 22:20:53 -0700 Subject: remove subproject commit --- pathod | 1 - 1 file changed, 1 deletion(-) delete mode 160000 pathod diff --git a/pathod b/pathod deleted file mode 160000 index be450cf9..00000000 --- a/pathod +++ /dev/null @@ -1 +0,0 @@ -Subproject commit be450cf9db1d819b1023029c8d403f401e010c98 -- cgit v1.2.3 From 2c9079b518ccb453dc3670cb358281df5ceb7362 Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Sun, 19 Apr 2015 22:22:15 -0700 Subject: whitespace --- test/test_websockets.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_websockets.py b/test/test_websockets.py index 34692183..035f9e17 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -165,8 +165,7 @@ class TestWebSockets(test.ServerTestBase): empty_stream = io.BytesIO("") assert "" == websockets.read_handshake(empty_stream, 1) - - + class BadHandshakeHandler(WebSocketsEchoHandler): def handshake(self): client_hs = websockets.read_handshake(self.rfile, 1) -- cgit v1.2.3 From 2c660d76337b11eb438a2978ec3bda3ac10babd5 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 21 Apr 2015 11:05:12 +1200 Subject: Migrate requeset reading from mitmproxy to netlib --- netlib/http.py | 124 +++++++++++++++++++++++++++++++++++++++++++++++++++++- netlib/utils.py | 2 +- test/test_http.py | 72 +++++++++++++++++++++++++++++++ 3 files changed, 195 insertions(+), 3 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index 26438863..aacdd1d4 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -1,7 +1,10 @@ from __future__ import (absolute_import, print_function, division) -import string, urlparse, binascii +import collections +import string +import urlparse +import binascii import sys -from . import odict, utils +from . import odict, utils, tcp class HttpError(Exception): @@ -30,6 +33,19 @@ def _is_valid_host(host): return True +def get_line(fp): + """ + Get a line, possibly preceded by a blank. + """ + line = fp.readline() + if line == "\r\n" or line == "\n": + # Possible leftover from previous message + line = fp.readline() + if line == "": + raise tcp.NetLibDisconnect() + return line + + def parse_url(url): """ Returns a (scheme, host, port, path) tuple, or None on error. @@ -436,3 +452,107 @@ def expected_http_body_size(headers, is_request, request_method, response_code): if is_request: return 0 return -1 + + +Request = collections.namedtuple( + "Request", + [ + "form_in", + "method", + "scheme", + "host", + "port", + "path", + "httpversion", + "headers", + "content" + ] +) + + +def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): + """ + Parse an HTTP request from a file stream + + Args: + rfile (file): Input file to read from + include_body (bool): Read response body as well + body_size_limit (bool): Maximum body size + wfile (file): If specified, HTTP Expect headers are handled + automatically, by writing a HTTP 100 CONTINUE response to the stream. + + Returns: + Request: The HTTP request + + Raises: + HttpError: If the input is invalid. + """ + httpversion, host, port, scheme, method, path, headers, content = ( + None, None, None, None, None, None, None, None) + + request_line = get_line(rfile) + + request_line_parts = parse_init(request_line) + if not request_line_parts: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + method, path, httpversion = request_line_parts + + if path == '*' or path.startswith("/"): + form_in = "relative" + if not utils.isascii(path): + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + elif method.upper() == 'CONNECT': + form_in = "authority" + r = parse_init_connect(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + host, port, _ = r + path = None + else: + form_in = "absolute" + r = parse_init_proxy(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + _, scheme, host, port, path, _ = r + + headers = read_headers(rfile) + if headers is None: + raise HttpError(400, "Invalid headers") + + expect_header = headers.get_first("expect") + if expect_header and expect_header.lower() == "100-continue" and httpversion >= (1, 1): + wfile.write( + 'HTTP/1.1 100 Continue\r\n' + '\r\n' + ) + wfile.flush() + del headers['expect'] + + if include_body: + content = read_http_body( + rfile, headers, body_size_limit, method, None, True + ) + + return Request( + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + content + ) diff --git a/netlib/utils.py b/netlib/utils.py index 03a70977..57532453 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -46,4 +46,4 @@ def hexdump(s): parts.append( (o, x, cleanBin(part, True)) ) - return parts \ No newline at end of file + return parts diff --git a/test/test_http.py b/test/test_http.py index b1c62458..5bd7cab2 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -366,3 +366,75 @@ def test_parse_http_basic_auth(): assert not http.parse_http_basic_auth("foo bar") v = "basic " + binascii.b2a_base64("foo") assert not http.parse_http_basic_auth(v) + + +def test_get_line(): + r = cStringIO.StringIO("\nfoo") + assert http.get_line(r) == "foo" + tutils.raises(tcp.NetLibDisconnect, http.get_line, r) + + +class TestReadRequest(): + + def tst(self, data, **kwargs): + r = cStringIO.StringIO(data) + return http.read_request(r, **kwargs) + + def test_invalid(self): + tutils.raises( + "bad http request", + self.tst, + "xxx" + ) + tutils.raises( + "bad http request line", + self.tst, + "get /\xff HTTP/1.1" + ) + tutils.raises( + "invalid headers", + self.tst, + "get / HTTP/1.1\r\nfoo" + ) + + def test_asterisk_form_in(self): + v = self.tst("OPTIONS * HTTP/1.1") + assert v.form_in == "relative" + assert v.method == "OPTIONS" + + def test_absolute_form_in(self): + tutils.raises( + "Bad HTTP request line", + self.tst, + "GET oops-no-protocol.com HTTP/1.1" + ) + v = self.tst("GET http://address:22/ HTTP/1.1") + assert v.form_in == "absolute" + assert v.port == 22 + assert v.host == "address" + assert v.scheme == "http" + + def test_connect(self): + tutils.raises( + "Bad HTTP request line", + self.tst, + "CONNECT oops-no-port.com HTTP/1.1" + ) + v = self.tst("CONNECT foo.com:443 HTTP/1.1") + assert v.form_in == "authority" + assert v.method == "CONNECT" + assert v.port == 443 + assert v.host == "foo.com" + + def test_expect(self): + w = cStringIO.StringIO() + r = cStringIO.StringIO( + "GET / HTTP/1.1\r\n" + "Content-Length: 3\r\n" + "Expect: 100-continue\r\n\r\n" + "foobar", + ) + v = http.read_request(r, wfile=w) + assert w.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" + assert v.content == "foo" + assert r.read(3) == "bar" -- cgit v1.2.3 From dd7ea896f24514bb2534b3762255e99f0aabc055 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 21 Apr 2015 11:11:16 +1200 Subject: Return a named tuple from read_response --- netlib/http.py | 18 +++++++++++++++--- test/test_http.py | 8 +++++--- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index aacdd1d4..5501ce73 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -314,6 +314,18 @@ def parse_response_line(line): return (proto, code, msg) +Response = collections.namedtuple( + "Response", + [ + "httpversion", + "code", + "msg", + "headers", + "content" + ] +) + + def read_response(rfile, request_method, body_size_limit, include_body=True): """ Return an (httpversion, code, msg, headers, content) tuple. @@ -352,7 +364,7 @@ def read_response(rfile, request_method, body_size_limit, include_body=True): # if include_body==False then a None content means the body should be # read separately content = None - return httpversion, code, msg, headers, content + return Response(httpversion, code, msg, headers, content) def read_http_body(*args, **kwargs): @@ -531,8 +543,8 @@ def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): if headers is None: raise HttpError(400, "Invalid headers") - expect_header = headers.get_first("expect") - if expect_header and expect_header.lower() == "100-continue" and httpversion >= (1, 1): + expect_header = headers.get_first("expect", "").lower() + if expect_header == "100-continue" and httpversion >= (1, 1): wfile.write( 'HTTP/1.1 100 Continue\r\n' '\r\n' diff --git a/test/test_http.py b/test/test_http.py index 5bd7cab2..4f8ef2c5 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -254,15 +254,17 @@ class TestReadResponseNoContentLength(test.ServerTestBase): def test_no_content_length(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - httpversion, code, msg, headers, content = http.read_response(c.rfile, "GET", None) - assert content == "bar\r\n\r\n" + resp = http.read_response(c.rfile, "GET", None) + assert resp.content == "bar\r\n\r\n" def test_read_response(): def tst(data, method, limit, include_body=True): data = textwrap.dedent(data) r = cStringIO.StringIO(data) - return http.read_response(r, method, limit, include_body = include_body) + return http.read_response( + r, method, limit, include_body = include_body + ) tutils.raises("server disconnect", tst, "", "GET", None) tutils.raises("invalid server response", tst, "foo", "GET", None) -- cgit v1.2.3 From 7d83e388aa78bb3637f71a4afb60af1baecb0314 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 21 Apr 2015 11:19:00 +1200 Subject: Whitespace, pep8, mixed indentation --- netlib/http.py | 19 +++++++++++---- netlib/utils.py | 4 +++- test/test_http.py | 70 ++++++++++++++++++++++++++++++++++++++++++++----------- 3 files changed, 74 insertions(+), 19 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index 5501ce73..b925fe87 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -331,12 +331,15 @@ def read_response(rfile, request_method, body_size_limit, include_body=True): Return an (httpversion, code, msg, headers, content) tuple. By default, both response header and body are read. - If include_body=False is specified, content may be one of the following: + If include_body=False is specified, content may be one of the + following: - None, if the response is technically allowed to have a response body - - "", if the response must not have a response body (e.g. it's a response to a HEAD request) + - "", if the response must not have a response body (e.g. it's a + response to a HEAD request) """ line = rfile.readline() - if line == "\r\n" or line == "\n": # Possible leftover from previous message + # Possible leftover from previous message + if line == "\r\n" or line == "\n": line = rfile.readline() if not line: raise HttpErrorConnClosed(502, "Server disconnect.") @@ -373,7 +376,15 @@ def read_http_body(*args, **kwargs): ) -def read_http_body_chunked(rfile, headers, limit, request_method, response_code, is_request, max_chunk_size=None): +def read_http_body_chunked( + rfile, + headers, + limit, + request_method, + response_code, + is_request, + max_chunk_size=None +): """ Read an HTTP message body: diff --git a/netlib/utils.py b/netlib/utils.py index 57532453..66bbdb5e 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -8,9 +8,11 @@ def isascii(s): return False return True + # best way to do it in python 2.x def bytes_to_int(i): - return int(i.encode('hex'), 16) + return int(i.encode('hex'), 16) + def cleanBin(s, fixspacing=False): """ diff --git a/test/test_http.py b/test/test_http.py index 4f8ef2c5..8b99c769 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -1,4 +1,6 @@ -import cStringIO, textwrap, binascii +import cStringIO +import textwrap +import binascii from netlib import http, odict, tcp, test import tutils @@ -21,7 +23,11 @@ def test_read_chunked(): h["transfer-encoding"] = ["chunked"] s = cStringIO.StringIO("1\r\na\r\n0\r\n") - tutils.raises("malformed chunked body", http.read_http_body, s, h, None, "GET", None, True) + tutils.raises( + "malformed chunked body", + http.read_http_body, + s, h, None, "GET", None, True + ) s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") assert http.read_http_body(s, h, None, "GET", None, True) == "a" @@ -30,13 +36,25 @@ def test_read_chunked(): assert http.read_http_body(s, h, None, "GET", None, True) == "a" s = cStringIO.StringIO("\r\n") - tutils.raises("closed prematurely", http.read_http_body, s, h, None, "GET", None, True) + tutils.raises( + "closed prematurely", + http.read_http_body, + s, h, None, "GET", None, True + ) s = cStringIO.StringIO("1\r\nfoo") - tutils.raises("malformed chunked body", http.read_http_body, s, h, None, "GET", None, True) + tutils.raises( + "malformed chunked body", + http.read_http_body, + s, h, None, "GET", None, True + ) s = cStringIO.StringIO("foo\r\nfoo") - tutils.raises(http.HttpError, http.read_http_body, s, h, None, "GET", None, True) + tutils.raises( + http.HttpError, + http.read_http_body, + s, h, None, "GET", None, True + ) s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") tutils.raises("too large", http.read_http_body, s, h, 2, "GET", None, True) @@ -87,17 +105,29 @@ def test_read_http_body(): # test content length: invalid header h["content-length"] = ["foo"] s = cStringIO.StringIO("testing") - tutils.raises(http.HttpError, http.read_http_body, s, h, None, "GET", 200, False) + tutils.raises( + http.HttpError, + http.read_http_body, + s, h, None, "GET", 200, False + ) # test content length: invalid header #2 h["content-length"] = [-1] s = cStringIO.StringIO("testing") - tutils.raises(http.HttpError, http.read_http_body, s, h, None, "GET", 200, False) + tutils.raises( + http.HttpError, + http.read_http_body, + s, h, None, "GET", 200, False + ) # test content length: content length > actual content h["content-length"] = [5] s = cStringIO.StringIO("testing") - tutils.raises(http.HttpError, http.read_http_body, s, h, 4, "GET", 200, False) + tutils.raises( + http.HttpError, + http.read_http_body, + s, h, 4, "GET", 200, False + ) # test content length: content length < actual content s = cStringIO.StringIO("testing") @@ -110,7 +140,11 @@ def test_read_http_body(): # test no content length: limit < actual content s = cStringIO.StringIO("testing") - tutils.raises(http.HttpError, http.read_http_body, s, h, 4, "GET", 200, False) + tutils.raises( + http.HttpError, + http.read_http_body, + s, h, 4, "GET", 200, False + ) # test chunked h = odict.ODictCaseless() @@ -271,11 +305,15 @@ def test_read_response(): data = """ HTTP/1.1 200 OK """ - assert tst(data, "GET", None) == ((1, 1), 200, 'OK', odict.ODictCaseless(), '') + assert tst(data, "GET", None) == ( + (1, 1), 200, 'OK', odict.ODictCaseless(), '' + ) data = """ HTTP/1.1 200 """ - assert tst(data, "GET", None) == ((1, 1), 200, '', odict.ODictCaseless(), '') + assert tst(data, "GET", None) == ( + (1, 1), 200, '', odict.ODictCaseless(), '' + ) data = """ HTTP/x 200 OK """ @@ -290,7 +328,9 @@ def test_read_response(): HTTP/1.1 200 OK """ - assert tst(data, "GET", None) == ((1, 1), 100, 'CONTINUE', odict.ODictCaseless(), '') + assert tst(data, "GET", None) == ( + (1, 1), 100, 'CONTINUE', odict.ODictCaseless(), '' + ) data = """ HTTP/1.1 200 OK @@ -315,7 +355,7 @@ def test_read_response(): foo """ - assert tst(data, "GET", None, include_body=False)[4] == None + assert tst(data, "GET", None, include_body=False)[4] is None def test_parse_url(): @@ -363,7 +403,9 @@ def test_parse_url(): def test_parse_http_basic_auth(): vals = ("basic", "foo", "bar") - assert http.parse_http_basic_auth(http.assemble_http_basic_auth(*vals)) == vals + assert http.parse_http_basic_auth( + http.assemble_http_basic_auth(*vals) + ) == vals assert not http.parse_http_basic_auth("") assert not http.parse_http_basic_auth("foo bar") v = "basic " + binascii.b2a_base64("foo") -- cgit v1.2.3 From e5f12648380cb4401f77e3cae51189ef97b603dc Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 21 Apr 2015 13:39:00 +1200 Subject: Whitespace, indentation, nounce -> nonce --- netlib/http_cookies.py | 24 ++++++++--------- netlib/websockets.py | 50 +++++++++++++++++----------------- test/test_websockets.py | 71 +++++++++++++++++++++++++++++-------------------- 3 files changed, 79 insertions(+), 66 deletions(-) diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index dab95ed0..8e245891 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -1,18 +1,18 @@ """ A flexible module for cookie parsing and manipulation. -This module differs from usual standards-compliant cookie modules in a number of -ways. We try to be as permissive as possible, and to retain even mal-formed +This module differs from usual standards-compliant cookie modules in a number +of ways. We try to be as permissive as possible, and to retain even mal-formed information. Duplicate cookies are preserved in parsing, and can be set in formatting. We do attempt to escape and quote values where needed, but will not reject data that violate the specs. Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We do -not parse the comma-separated variant of Set-Cookie that allows multiple cookies -to be set in a single header. Technically this should be feasible, but it turns -out that violations of RFC6265 that makes the parsing problem indeterminate are -much more common than genuine occurences of the multi-cookie variants. -Serialization follows RFC6265. +not parse the comma-separated variant of Set-Cookie that allows multiple +cookies to be set in a single header. Technically this should be feasible, but +it turns out that violations of RFC6265 that makes the parsing problem +indeterminate are much more common than genuine occurences of the multi-cookie +variants. Serialization follows RFC6265. http://tools.ietf.org/html/rfc6265 http://tools.ietf.org/html/rfc2109 @@ -32,11 +32,11 @@ def _read_until(s, start, term): Read until one of the characters in term is reached. """ if start == len(s): - return "", start+1 + return "", start + 1 for i in range(start, len(s)): if s[i] in term: return s[start:i], i - return s[start:i+1], i+1 + return s[start:i + 1], i + 1 def _read_token(s, start): @@ -59,7 +59,7 @@ def _read_quoted_string(s, start): escaping = False ret = [] # Skip the first quote - for i in range(start+1, len(s)): + for i in range(start + 1, len(s)): if escaping: ret.append(s[i]) escaping = False @@ -70,7 +70,7 @@ def _read_quoted_string(s, start): pass else: ret.append(s[i]) - return "".join(ret), i+1 + return "".join(ret), i + 1 def _read_value(s, start, delims): @@ -103,7 +103,7 @@ def _read_pairs(s, off=0, specials=()): rhs = None if off < len(s): if s[off] == "=": - rhs, off = _read_value(s, off+1, ";") + rhs, off = _read_value(s, off + 1, ";") vals.append([lhs, rhs]) off += 1 if not off < len(s): diff --git a/netlib/websockets.py b/netlib/websockets.py index 5b9d8fbd..f2d467a5 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -67,23 +67,23 @@ class Frame(object): mask_bit, # decimal integer 1 or 0 payload_length_code, # decimal integer 1 - 127 decoded_payload, # bytestring - rsv1 = 0, # decimal integer 1 or 0 - rsv2 = 0, # decimal integer 1 or 0 - rsv3 = 0, # decimal integer 1 or 0 - payload = None, # bytestring - masking_key = None, # 32 bit byte string + rsv1 = 0, # decimal integer 1 or 0 + rsv2 = 0, # decimal integer 1 or 0 + rsv3 = 0, # decimal integer 1 or 0 + payload = None, # bytestring + masking_key = None, # 32 bit byte string actual_payload_length = None, # any decimal integer ): - self.fin = fin - self.rsv1 = rsv1 - self.rsv2 = rsv2 - self.rsv3 = rsv3 - self.opcode = opcode - self.mask_bit = mask_bit - self.payload_length_code = payload_length_code - self.masking_key = masking_key - self.payload = payload - self.decoded_payload = decoded_payload + self.fin = fin + self.rsv1 = rsv1 + self.rsv2 = rsv2 + self.rsv3 = rsv3 + self.opcode = opcode + self.mask_bit = mask_bit + self.payload_length_code = payload_length_code + self.masking_key = masking_key + self.payload = payload + self.decoded_payload = decoded_payload self.actual_payload_length = actual_payload_length @classmethod @@ -162,7 +162,7 @@ class Frame(object): """ Construct a websocket frame from an in-memory bytestring to construct a frame from a stream of bytes, use from_file() directly - """ + """ return cls.from_file(io.BytesIO(bytestring)) def safe_to_bytes(self): @@ -206,7 +206,7 @@ class Frame(object): # '!H' pack as 16 bit unsigned short # add 2 byte extended payload length bytes += struct.pack('!H', self.actual_payload_length) - elif self.actual_payload_length < CONST.MAX_64_BIT_INT: + elif self.actual_payload_length < CONST.MAX_64_BIT_INT: # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length bytes += struct.pack('!Q', self.actual_payload_length) @@ -225,10 +225,10 @@ class Frame(object): def from_file(cls, reader): """ read a websockets frame sent by a server or client - - reader is a "file like" object that could be backed by a network stream or a disk - or an in memory stream reader - """ + + reader is a "file like" object that could be backed by a network + stream or a disk or an in memory stream reader + """ first_byte = utils.bytes_to_int(reader.read(1)) second_byte = utils.bytes_to_int(reader.read(1)) @@ -336,7 +336,7 @@ def create_server_handshake(key): headers = [ ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), - ('Sec-WebSocket-Accept', create_server_nounce(key)) + ('Sec-WebSocket-Accept', create_server_nonce(key)) ] request = "HTTP/1.1 101 Switching Protocols" return build_handshake(headers, request) @@ -406,11 +406,11 @@ def headers_from_http_message(http_message): ) -def create_server_nounce(client_nounce): +def create_server_nonce(client_nonce): return base64.b64encode( - hashlib.sha1(client_nounce + websockets_magic).hexdigest().decode('hex') + hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') ) -def create_client_nounce(): +def create_client_nonce(): return base64.b64encode(os.urandom(16)).decode('utf-8') diff --git a/test/test_websockets.py b/test/test_websockets.py index 035f9e17..1f2025bf 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -27,7 +27,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): def send_message(self, message): frame = websockets.Frame.default(message, from_client = False) frame.to_file(self.wfile) - + def handshake(self): client_hs = websockets.read_handshake(self.rfile, 1) key = websockets.process_handshake_from_client(client_hs) @@ -45,7 +45,7 @@ class WebSocketsClient(tcp.TCPClient): def __init__(self, address, source_address=None): super(WebSocketsClient, self).__init__(address, source_address) self.version = "13" - self.client_nounce = websockets.create_client_nounce() + self.client_nonce = websockets.create_client_nonce() self.resource = "/" def connect(self): @@ -54,7 +54,7 @@ class WebSocketsClient(tcp.TCPClient): handshake = websockets.create_client_handshake( self.address.host, self.address.port, - self.client_nounce, + self.client_nonce, self.version, self.resource ) @@ -63,9 +63,11 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.flush() server_handshake = websockets.read_handshake(self.rfile, 1) - server_nounce = websockets.process_handshake_from_server(server_handshake) + server_nonce = websockets.process_handshake_from_server( + server_handshake + ) - if not server_nounce == websockets.create_server_nounce(self.client_nounce): + if not server_nonce == websockets.create_server_nonce(self.client_nonce): self.close() def read_next_message(self): @@ -75,6 +77,7 @@ class WebSocketsClient(tcp.TCPClient): frame = websockets.Frame.default(message, from_client = True) frame.to_file(self.wfile) + class TestWebSockets(test.ServerTestBase): handler = WebSocketsEchoHandler @@ -124,7 +127,9 @@ class TestWebSockets(test.ServerTestBase): frame = websockets.Frame.default( self.random_bytes(num_bytes), is_client ) - assert frame == websockets.Frame.from_bytes(frame.safe_to_bytes()) + assert frame == websockets.Frame.from_bytes( + frame.safe_to_bytes() + ) bytes = b'\x81\x03cba' assert websockets.Frame.from_bytes(bytes).safe_to_bytes() == bytes @@ -136,36 +141,44 @@ class TestWebSockets(test.ServerTestBase): frame.safe_to_bytes() def test_handshake(self): - bad_upgrade = "not_websockets" - bad_header_handshake = websockets.build_handshake([ - ('Host', '%s:%s' % ("a", "b")), - ('Connection', "c"), - ('Upgrade', bad_upgrade), - ('Sec-WebSocket-Key', "d"), - ('Sec-WebSocket-Version', "e") - ], "f") + bad_upgrade = "not_websockets" + bad_header_handshake = websockets.build_handshake([ + ('Host', '%s:%s' % ("a", "b")), + ('Connection', "c"), + ('Upgrade', bad_upgrade), + ('Sec-WebSocket-Key', "d"), + ('Sec-WebSocket-Version', "e") + ], "f") + + # check behavior when required header values are missing + assert None is websockets.process_handshake_from_server( + bad_header_handshake + ) + assert None is websockets.process_handshake_from_client( + bad_header_handshake + ) - # check behavior when required header values are missing - assert None == websockets.process_handshake_from_server(bad_header_handshake) - assert None == websockets.process_handshake_from_client(bad_header_handshake) + key = "test_key" - key = "test_key" + client_handshake = websockets.create_client_handshake( + "a", "b", key, "d", "e" + ) + assert key == websockets.process_handshake_from_client( + client_handshake + ) - client_handshake = websockets.create_client_handshake("a","b",key,"d","e") - assert key == websockets.process_handshake_from_client(client_handshake) + server_handshake = websockets.create_server_handshake(key) + assert websockets.create_server_nonce(key) == websockets.process_handshake_from_server(server_handshake) - server_handshake = websockets.create_server_handshake(key) - assert websockets.create_server_nounce(key) == websockets.process_handshake_from_server(server_handshake) + handshake = websockets.create_client_handshake("a", "b", "c", "d", "e") + stream = io.BytesIO(handshake) + assert handshake == websockets.read_handshake(stream, 1) - handshake = websockets.create_client_handshake("a","b","c","d","e") - stream = io.BytesIO(handshake) - assert handshake == websockets.read_handshake(stream, 1) + # ensure readhandshake doesn't loop forever on empty stream + empty_stream = io.BytesIO("") + assert "" == websockets.read_handshake(empty_stream, 1) - # ensure readhandshake doesn't loop forever on empty stream - empty_stream = io.BytesIO("") - assert "" == websockets.read_handshake(empty_stream, 1) - class BadHandshakeHandler(WebSocketsEchoHandler): def handshake(self): client_hs = websockets.read_handshake(self.rfile, 1) -- cgit v1.2.3 From 3e0a71ea345131a5f2dcc9581a7d93b8ebe09b13 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 21 Apr 2015 22:39:45 +1200 Subject: websockets: refactor to use http and header functions in http.py --- netlib/http.py | 126 ++++++++++++++++++++++++++---------------------- netlib/websockets.py | 108 ++++++++++++++--------------------------- test/test_websockets.py | 112 ++++++++++++++++++------------------------ 3 files changed, 152 insertions(+), 194 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index b925fe87..fe27240a 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -4,7 +4,7 @@ import string import urlparse import binascii import sys -from . import odict, utils, tcp +from . import odict, utils, tcp, http_status class HttpError(Exception): @@ -314,62 +314,6 @@ def parse_response_line(line): return (proto, code, msg) -Response = collections.namedtuple( - "Response", - [ - "httpversion", - "code", - "msg", - "headers", - "content" - ] -) - - -def read_response(rfile, request_method, body_size_limit, include_body=True): - """ - Return an (httpversion, code, msg, headers, content) tuple. - - By default, both response header and body are read. - If include_body=False is specified, content may be one of the - following: - - None, if the response is technically allowed to have a response body - - "", if the response must not have a response body (e.g. it's a - response to a HEAD request) - """ - line = rfile.readline() - # Possible leftover from previous message - if line == "\r\n" or line == "\n": - line = rfile.readline() - if not line: - raise HttpErrorConnClosed(502, "Server disconnect.") - parts = parse_response_line(line) - if not parts: - raise HttpError(502, "Invalid server response: %s" % repr(line)) - proto, code, msg = parts - httpversion = parse_http_protocol(proto) - if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) - headers = read_headers(rfile) - if headers is None: - raise HttpError(502, "Invalid headers.") - - if include_body: - content = read_http_body( - rfile, - headers, - body_size_limit, - request_method, - code, - False - ) - else: - # if include_body==False then a None content means the body should be - # read separately - content = None - return Response(httpversion, code, msg, headers, content) - - def read_http_body(*args, **kwargs): return "".join( content for _, content, _ in read_http_body_chunked(*args, **kwargs) @@ -579,3 +523,71 @@ def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): headers, content ) + + +Response = collections.namedtuple( + "Response", + [ + "httpversion", + "code", + "msg", + "headers", + "content" + ] +) + + +def read_response(rfile, request_method, body_size_limit, include_body=True): + """ + Return an (httpversion, code, msg, headers, content) tuple. + + By default, both response header and body are read. + If include_body=False is specified, content may be one of the + following: + - None, if the response is technically allowed to have a response body + - "", if the response must not have a response body (e.g. it's a + response to a HEAD request) + """ + line = rfile.readline() + # Possible leftover from previous message + if line == "\r\n" or line == "\n": + line = rfile.readline() + if not line: + raise HttpErrorConnClosed(502, "Server disconnect.") + parts = parse_response_line(line) + if not parts: + raise HttpError(502, "Invalid server response: %s" % repr(line)) + proto, code, msg = parts + httpversion = parse_http_protocol(proto) + if httpversion is None: + raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) + headers = read_headers(rfile) + if headers is None: + raise HttpError(502, "Invalid headers.") + + if include_body: + content = read_http_body( + rfile, + headers, + body_size_limit, + request_method, + code, + False + ) + else: + # if include_body==False then a None content means the body should be + # read separately + content = None + return Response(httpversion, code, msg, headers, content) + + +def request_preamble(method, resource, http_major="1", http_minor="1"): + return '%s %s HTTP/%s.%s' % ( + method, resource, http_major, http_minor + ) + + +def response_preamble(code, message=None, http_major="1", http_minor="1"): + if message is None: + message = http_status.RESPONSES.get(code) + return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) diff --git a/netlib/websockets.py b/netlib/websockets.py index f2d467a5..a03185fa 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -2,13 +2,11 @@ from __future__ import absolute_import import base64 import hashlib -import mimetools -import StringIO import os import struct import io -from . import utils +from . import utils, odict # Colleciton of utility functions that implement small portions of the RFC6455 # WebSockets Protocol Useful for building WebSocket clients and servers. @@ -23,6 +21,7 @@ from . import utils # The magic sha that websocket servers must know to prove they understand # RFC6455 websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +VERSION = "13" class CONST(object): @@ -151,9 +150,9 @@ class Frame(object): ("opcode - " + str(self.opcode)), ("mask_bit - " + str(self.mask_bit)), ("payload_length_code - " + str(self.payload_length_code)), - ("masking_key - " + str(self.masking_key)), - ("payload - " + str(self.payload)), - ("decoded_payload - " + str(self.decoded_payload)), + ("masking_key - " + repr(str(self.masking_key))), + ("payload - " + repr(str(self.payload))), + ("decoded_payload - " + repr(str(self.decoded_payload))), ("actual_payload_length - " + str(self.actual_payload_length)) ]) @@ -198,24 +197,24 @@ class Frame(object): second_byte = (self.mask_bit << 7) | self.payload_length_code - bytes = chr(first_byte) + chr(second_byte) + b = chr(first_byte) + chr(second_byte) if self.actual_payload_length < 126: pass elif self.actual_payload_length < CONST.MAX_16_BIT_INT: # '!H' pack as 16 bit unsigned short # add 2 byte extended payload length - bytes += struct.pack('!H', self.actual_payload_length) + b += struct.pack('!H', self.actual_payload_length) elif self.actual_payload_length < CONST.MAX_64_BIT_INT: # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length - bytes += struct.pack('!Q', self.actual_payload_length) + b += struct.pack('!Q', self.actual_payload_length) if self.masking_key is not None: - bytes += self.masking_key + b += self.masking_key - bytes += self.payload # already will be encoded if neccessary - return bytes + b += self.payload # already will be encoded if neccessary + return b def to_file(self, writer): writer.write(self.to_bytes()) @@ -313,58 +312,35 @@ def random_masking_key(): return os.urandom(4) -def create_client_handshake(host, port, key, version, resource): +def client_handshake_headers(key=None, version=VERSION): """ - WebSockets connections are intiated by the client with a valid HTTP - upgrade request + Create the headers for a valid HTTP upgrade request. If Key is not + specified, it is generated, and can be found in sec-websocket-key in + the returned header set. + + Returns an instance of ODictCaseless """ - headers = [ - ('Host', '%s:%s' % (host, port)), + if not key: + key = base64.b64encode(os.urandom(16)).decode('utf-8') + return odict.ODictCaseless([ ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), ('Sec-WebSocket-Key', key), ('Sec-WebSocket-Version', version) - ] - request = "GET %s HTTP/1.1" % resource - return build_handshake(headers, request) + ]) -def create_server_handshake(key): +def server_handshake_headers(key): """ The server response is a valid HTTP 101 response. """ - headers = [ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - ('Sec-WebSocket-Accept', create_server_nonce(key)) - ] - request = "HTTP/1.1 101 Switching Protocols" - return build_handshake(headers, request) - - -def build_handshake(headers, request): - handshake = [request.encode('utf-8')] - for header, value in headers: - handshake.append(("%s: %s" % (header, value)).encode('utf-8')) - handshake.append(b'\r\n') - return b'\r\n'.join(handshake) - - -def read_handshake(reader, num_bytes_per_read): - """ - From provided function that reads bytes, read in a - complete HTTP request, which terminates with a CLRF - """ - response = b'' - doubleCLRF = b'\r\n\r\n' - while True: - bytes = reader.read(num_bytes_per_read) - if not bytes: - break - response += bytes - if doubleCLRF in response: - break - return response + return odict.ODictCaseless( + [ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Accept', create_server_nonce(key)) + ] + ) def get_payload_length_pair(payload_bytestring): @@ -384,33 +360,19 @@ def get_payload_length_pair(payload_bytestring): return (length_code, actual_length) -def process_handshake_from_client(handshake): - headers = headers_from_http_message(handshake) - if headers.get("Upgrade", None) != "websocket": +def check_client_handshake(req): + if req.headers.get_first("upgrade", None) != "websocket": return - key = headers['Sec-WebSocket-Key'] - return key + return req.headers.get_first('sec-websocket-key') -def process_handshake_from_server(handshake): - headers = headers_from_http_message(handshake) - if headers.get("Upgrade", None) != "websocket": +def check_server_handshake(resp): + if resp.headers.get_first("upgrade", None) != "websocket": return - key = headers['Sec-WebSocket-Accept'] - return key - - -def headers_from_http_message(http_message): - return mimetools.Message( - StringIO.StringIO(http_message.split('\r\n', 1)[1]) - ) + return resp.headers.get_first('sec-websocket-accept') def create_server_nonce(client_nonce): return base64.b64encode( hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') ) - - -def create_client_nonce(): - return base64.b64encode(os.urandom(16)).decode('utf-8') diff --git a/test/test_websockets.py b/test/test_websockets.py index 1f2025bf..9b27e810 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -1,6 +1,4 @@ -from netlib import tcp -from netlib import test -from netlib import websockets +from netlib import tcp, test, websockets, http, odict import io import os from nose.tools import raises @@ -21,18 +19,20 @@ class WebSocketsEchoHandler(tcp.BaseHandler): self.read_next_message() def read_next_message(self): - decoded = websockets.Frame.from_file(self.rfile).decoded_payload - self.on_message(decoded) + frame = websockets.Frame.from_file(self.rfile) + self.on_message(frame.decoded_payload) def send_message(self, message): frame = websockets.Frame.default(message, from_client = False) frame.to_file(self.wfile) def handshake(self): - client_hs = websockets.read_handshake(self.rfile, 1) - key = websockets.process_handshake_from_client(client_hs) - response = websockets.create_server_handshake(key) - self.wfile.write(response) + req = http.read_request(self.rfile) + key = websockets.check_client_handshake(req) + + self.wfile.write(http.response_preamble(101) + "\r\n") + headers = websockets.server_handshake_headers(key) + self.wfile.write(headers.format() + "\r\n") self.wfile.flush() self.handshake_done = True @@ -44,28 +44,20 @@ class WebSocketsEchoHandler(tcp.BaseHandler): class WebSocketsClient(tcp.TCPClient): def __init__(self, address, source_address=None): super(WebSocketsClient, self).__init__(address, source_address) - self.version = "13" - self.client_nonce = websockets.create_client_nonce() - self.resource = "/" + self.client_nonce = None def connect(self): super(WebSocketsClient, self).connect() - handshake = websockets.create_client_handshake( - self.address.host, - self.address.port, - self.client_nonce, - self.version, - self.resource - ) - - self.wfile.write(handshake) + preamble = http.request_preamble("GET", "/") + self.wfile.write(preamble + "\r\n") + headers = websockets.client_handshake_headers() + self.client_nonce = headers.get_first("sec-websocket-key") + self.wfile.write(headers.format() + "\r\n") self.wfile.flush() - server_handshake = websockets.read_handshake(self.rfile, 1) - server_nonce = websockets.process_handshake_from_server( - server_handshake - ) + resp = http.read_response(self.rfile, "get", None) + server_nonce = websockets.check_server_handshake(resp) if not server_nonce == websockets.create_server_nonce(self.client_nonce): self.close() @@ -140,51 +132,43 @@ class TestWebSockets(test.ServerTestBase): frame.actual_payload_length = 1 # corrupt the frame frame.safe_to_bytes() - def test_handshake(self): - bad_upgrade = "not_websockets" - bad_header_handshake = websockets.build_handshake([ - ('Host', '%s:%s' % ("a", "b")), - ('Connection', "c"), - ('Upgrade', bad_upgrade), - ('Sec-WebSocket-Key', "d"), - ('Sec-WebSocket-Version', "e") - ], "f") - - # check behavior when required header values are missing - assert None is websockets.process_handshake_from_server( - bad_header_handshake - ) - assert None is websockets.process_handshake_from_client( - bad_header_handshake - ) - - key = "test_key" - - client_handshake = websockets.create_client_handshake( - "a", "b", key, "d", "e" + def test_check_server_handshake(self): + resp = http.Response( + (1, 1), + 101, + "Switching Protocols", + websockets.server_handshake_headers("key"), + "" ) - assert key == websockets.process_handshake_from_client( - client_handshake + assert websockets.check_server_handshake(resp) + resp.headers["Upgrade"] = ["not_websocket"] + assert not websockets.check_server_handshake(resp) + + def test_check_client_handshake(self): + resp = http.Request( + "relative", + "get", + "http", + "host", + 22, + "/", + (1, 1), + websockets.client_handshake_headers("key"), + "" ) - - server_handshake = websockets.create_server_handshake(key) - assert websockets.create_server_nonce(key) == websockets.process_handshake_from_server(server_handshake) - - handshake = websockets.create_client_handshake("a", "b", "c", "d", "e") - stream = io.BytesIO(handshake) - assert handshake == websockets.read_handshake(stream, 1) - - # ensure readhandshake doesn't loop forever on empty stream - empty_stream = io.BytesIO("") - assert "" == websockets.read_handshake(empty_stream, 1) + assert websockets.check_client_handshake(resp) == "key" + resp.headers["Upgrade"] = ["not_websocket"] + assert not websockets.check_client_handshake(resp) class BadHandshakeHandler(WebSocketsEchoHandler): def handshake(self): - client_hs = websockets.read_handshake(self.rfile, 1) - websockets.process_handshake_from_client(client_hs) - response = websockets.create_server_handshake("malformed_key") - self.wfile.write(response) + client_hs = http.read_request(self.rfile) + websockets.check_client_handshake(client_hs) + + self.wfile.write(http.response_preamble(101) + "\r\n") + headers = websockets.server_handshake_headers("malformed key") + self.wfile.write(headers.format() + "\r\n") self.wfile.flush() self.handshake_done = True -- cgit v1.2.3 From 1b509d5aea31a636b6c8ce854e0dd685e34d03de Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 21 Apr 2015 22:51:01 +1200 Subject: Whitespace, interface simplification - safe_tobytes doesn't buy us much - move masking key generation inline --- netlib/websockets.py | 17 ++--------------- test/test_websockets.py | 13 +++---------- 2 files changed, 5 insertions(+), 25 deletions(-) diff --git a/netlib/websockets.py b/netlib/websockets.py index a03185fa..0cd4dba1 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -29,10 +29,6 @@ class CONST(object): MAX_64_BIT_INT = (1 << 64) -class WebSocketFrameValidationException(Exception): - pass - - class Frame(object): """ Represents one websockets frame. @@ -95,7 +91,8 @@ class Frame(object): if from_client: mask_bit = 1 - masking_key = random_masking_key() + # Random masking key + masking_key = os.urandom(4) payload = apply_mask(message, masking_key) else: mask_bit = 0 @@ -164,12 +161,6 @@ class Frame(object): """ return cls.from_file(io.BytesIO(bytestring)) - def safe_to_bytes(self): - if self.is_valid(): - return self.to_bytes() - else: - raise WebSocketFrameValidationException() - def to_bytes(self): """ Serialize the frame back into the wire format, returns a bytestring @@ -308,10 +299,6 @@ def apply_mask(message, masking_key): return result -def random_masking_key(): - return os.urandom(4) - - def client_handshake_headers(key=None, version=VERSION): """ Create the headers for a valid HTTP upgrade request. If Key is not diff --git a/test/test_websockets.py b/test/test_websockets.py index 9b27e810..3fc67dfe 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -1,5 +1,4 @@ -from netlib import tcp, test, websockets, http, odict -import io +from netlib import tcp, test, websockets, http import os from nose.tools import raises @@ -120,17 +119,11 @@ class TestWebSockets(test.ServerTestBase): self.random_bytes(num_bytes), is_client ) assert frame == websockets.Frame.from_bytes( - frame.safe_to_bytes() + frame.to_bytes() ) bytes = b'\x81\x03cba' - assert websockets.Frame.from_bytes(bytes).safe_to_bytes() == bytes - - @raises(websockets.WebSocketFrameValidationException) - def test_safe_to_bytes(self): - frame = websockets.Frame.default(self.random_bytes(8)) - frame.actual_payload_length = 1 # corrupt the frame - frame.safe_to_bytes() + assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes def test_check_server_handshake(self): resp = http.Response( -- cgit v1.2.3 From 176e29fc094119b036ba76d6e5cc1f2d7fb838e0 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 21 Apr 2015 23:13:42 +1200 Subject: websockets: constants, variable names, refactoring --- netlib/websockets.py | 75 ++++++++++++++++++++++++++----------------------- test/test_websockets.py | 27 ++++++++++++++++++ 2 files changed, 67 insertions(+), 35 deletions(-) diff --git a/netlib/websockets.py b/netlib/websockets.py index 0cd4dba1..1e9c96cc 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -22,11 +22,17 @@ from . import utils, odict # RFC6455 websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' VERSION = "13" +MAX_16_BIT_INT = (1 << 16) +MAX_64_BIT_INT = (1 << 64) -class CONST(object): - MAX_16_BIT_INT = (1 << 16) - MAX_64_BIT_INT = (1 << 64) +class OPCODE: + CONTINUE = 0x00 + TEXT = 0x01 + BINARY = 0x02 + CLOSE = 0x08 + PING = 0x09 + PONG = 0x0a class Frame(object): @@ -101,7 +107,7 @@ class Frame(object): return cls( fin = 1, # final frame - opcode = 1, # text + opcode = OPCODE.TEXT, # text mask_bit = mask_bit, payload_length_code = length_code, payload = payload, @@ -115,28 +121,27 @@ class Frame(object): Validate websocket frame invariants, call at anytime to ensure the Frame has not been corrupted. """ - try: - assert 0 <= self.fin <= 1 - assert 0 <= self.rsv1 <= 1 - assert 0 <= self.rsv2 <= 1 - assert 0 <= self.rsv3 <= 1 - assert 1 <= self.opcode <= 4 - assert 0 <= self.mask_bit <= 1 - assert 1 <= self.payload_length_code <= 127 - - if self.mask_bit == 1: - assert 1 <= len(self.masking_key) <= 4 - else: - assert self.masking_key is None - - assert self.actual_payload_length == len(self.payload) - - if self.payload is not None and self.masking_key is not None: - assert apply_mask(self.payload, self.masking_key) == self.decoded_payload - - return True - except AssertionError: + constraints = [ + 0 <= self.fin <= 1, + 0 <= self.rsv1 <= 1, + 0 <= self.rsv2 <= 1, + 0 <= self.rsv3 <= 1, + 1 <= self.opcode <= 4, + 0 <= self.mask_bit <= 1, + 1 <= self.payload_length_code <= 127, + self.actual_payload_length == len(self.payload) + ] + if not all(constraints): + return False + elif self.mask_bit == 1 and not 1 <= len(self.masking_key) <= 4: + return False + elif self.mask_bit == 0 and self.masking_key is not None: return False + elif self.payload and self.masking_key: + decoded = apply_mask(self.payload, self.masking_key) + if decoded != self.decoded_payload: + return False + return True def human_readable(self): # pragma: nocover return "\n".join([ @@ -192,11 +197,11 @@ class Frame(object): if self.actual_payload_length < 126: pass - elif self.actual_payload_length < CONST.MAX_16_BIT_INT: + elif self.actual_payload_length < MAX_16_BIT_INT: # '!H' pack as 16 bit unsigned short # add 2 byte extended payload length b += struct.pack('!H', self.actual_payload_length) - elif self.actual_payload_length < CONST.MAX_64_BIT_INT: + elif self.actual_payload_length < MAX_64_BIT_INT: # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length b += struct.pack('!Q', self.actual_payload_length) @@ -212,15 +217,15 @@ class Frame(object): writer.flush() @classmethod - def from_file(cls, reader): + def from_file(cls, fp): """ read a websockets frame sent by a server or client - reader is a "file like" object that could be backed by a network + fp is a "file like" object that could be backed by a network stream or a disk or an in memory stream reader """ - first_byte = utils.bytes_to_int(reader.read(1)) - second_byte = utils.bytes_to_int(reader.read(1)) + first_byte = utils.bytes_to_int(fp.read(1)) + second_byte = utils.bytes_to_int(fp.read(1)) # grab the left most bit fin = first_byte >> 7 @@ -237,18 +242,18 @@ class Frame(object): actual_payload_length = payload_length elif payload_length == 126: - actual_payload_length = utils.bytes_to_int(reader.read(2)) + actual_payload_length = utils.bytes_to_int(fp.read(2)) elif payload_length == 127: - actual_payload_length = utils.bytes_to_int(reader.read(8)) + actual_payload_length = utils.bytes_to_int(fp.read(8)) # masking key only present if mask bit set if mask_bit == 1: - masking_key = reader.read(4) + masking_key = fp.read(4) else: masking_key = None - payload = reader.read(actual_payload_length) + payload = fp.read(actual_payload_length) if mask_bit == 1: decoded_payload = apply_mask(payload, masking_key) diff --git a/test/test_websockets.py b/test/test_websockets.py index 3fc67dfe..9e205e70 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -108,6 +108,33 @@ class TestWebSockets(test.ServerTestBase): server_frame = websockets.Frame.default(msg, from_client = False) assert server_frame.is_valid() + def test_is_valid(self): + def f(): + return websockets.Frame.default(self.random_bytes(10), True) + + frame = f() + assert frame.is_valid() + + frame = f() + frame.fin = 2 + assert not frame.is_valid() + + frame = f() + frame.mask_bit = 1 + frame.masking_key = "foobbarboo" + assert not frame.is_valid() + + frame = f() + frame.mask_bit = 0 + frame.masking_key = "foob" + assert not frame.is_valid() + + frame = f() + frame.masking_key = "foob" + frame.decoded_payload = "xxxx" + assert not frame.is_valid() + + def test_serialization_bijection(self): """ Ensure that various frame types can be serialized/deserialized back -- cgit v1.2.3 From 4fb49c8e55cc3c64ac0d5cf8fb913518f1973162 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 21 Apr 2015 23:49:27 +1200 Subject: websockets: (very) slightly nicer is_valid constraints --- netlib/websockets.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/netlib/websockets.py b/netlib/websockets.py index 1e9c96cc..d5c5c2fe 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -129,14 +129,12 @@ class Frame(object): 1 <= self.opcode <= 4, 0 <= self.mask_bit <= 1, 1 <= self.payload_length_code <= 127, - self.actual_payload_length == len(self.payload) + self.actual_payload_length == len(self.payload), + 1 <= len(self.masking_key) <= 4 if self.mask_bit else True, + self.masking_key is not None if self.mask_bit else True ] if not all(constraints): return False - elif self.mask_bit == 1 and not 1 <= len(self.masking_key) <= 4: - return False - elif self.mask_bit == 0 and self.masking_key is not None: - return False elif self.payload and self.masking_key: decoded = apply_mask(self.payload, self.masking_key) if decoded != self.decoded_payload: -- cgit v1.2.3 From 42a87a1d8b3eeccfdd8e5e504f1cd4d90ae1dbfb Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 23 Apr 2015 08:23:51 +1200 Subject: websockets: handshake checks only take headers --- netlib/http.py | 8 ++++---- netlib/websockets.py | 12 ++++++------ test/test_http.py | 6 +++--- test/test_websockets.py | 38 +++++++++++--------------------------- 4 files changed, 24 insertions(+), 40 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index fe27240a..43155486 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -33,7 +33,7 @@ def _is_valid_host(host): return True -def get_line(fp): +def get_request_line(fp): """ Get a line, possibly preceded by a blank. """ @@ -41,8 +41,6 @@ def get_line(fp): if line == "\r\n" or line == "\n": # Possible leftover from previous message line = fp.readline() - if line == "": - raise tcp.NetLibDisconnect() return line @@ -457,7 +455,9 @@ def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): httpversion, host, port, scheme, method, path, headers, content = ( None, None, None, None, None, None, None, None) - request_line = get_line(rfile) + request_line = get_request_line(rfile) + if not request_line: + raise tcp.NetLibDisconnect() request_line_parts = parse_init(request_line) if not request_line_parts: diff --git a/netlib/websockets.py b/netlib/websockets.py index d5c5c2fe..da03768d 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -350,16 +350,16 @@ def get_payload_length_pair(payload_bytestring): return (length_code, actual_length) -def check_client_handshake(req): - if req.headers.get_first("upgrade", None) != "websocket": +def check_client_handshake(headers): + if headers.get_first("upgrade", None) != "websocket": return - return req.headers.get_first('sec-websocket-key') + return headers.get_first('sec-websocket-key') -def check_server_handshake(resp): - if resp.headers.get_first("upgrade", None) != "websocket": +def check_server_handshake(headers): + if headers.get_first("upgrade", None) != "websocket": return - return resp.headers.get_first('sec-websocket-accept') + return headers.get_first('sec-websocket-accept') def create_server_nonce(client_nonce): diff --git a/test/test_http.py b/test/test_http.py index 8b99c769..962eb9cb 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -412,10 +412,10 @@ def test_parse_http_basic_auth(): assert not http.parse_http_basic_auth(v) -def test_get_line(): +def test_get_request_line(): r = cStringIO.StringIO("\nfoo") - assert http.get_line(r) == "foo" - tutils.raises(tcp.NetLibDisconnect, http.get_line, r) + assert http.get_request_line(r) == "foo" + assert not http.get_request_line(r) class TestReadRequest(): diff --git a/test/test_websockets.py b/test/test_websockets.py index 9e205e70..6f3b429d 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -27,7 +27,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): def handshake(self): req = http.read_request(self.rfile) - key = websockets.check_client_handshake(req) + key = websockets.check_client_handshake(req.headers) self.wfile.write(http.response_preamble(101) + "\r\n") headers = websockets.server_handshake_headers(key) @@ -56,7 +56,7 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.flush() resp = http.read_response(self.rfile, "get", None) - server_nonce = websockets.check_server_handshake(resp) + server_nonce = websockets.check_server_handshake(resp.headers) if not server_nonce == websockets.create_server_nonce(self.client_nonce): self.close() @@ -153,38 +153,22 @@ class TestWebSockets(test.ServerTestBase): assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes def test_check_server_handshake(self): - resp = http.Response( - (1, 1), - 101, - "Switching Protocols", - websockets.server_handshake_headers("key"), - "" - ) - assert websockets.check_server_handshake(resp) - resp.headers["Upgrade"] = ["not_websocket"] - assert not websockets.check_server_handshake(resp) + headers = websockets.server_handshake_headers("key") + assert websockets.check_server_handshake(headers) + headers["Upgrade"] = ["not_websocket"] + assert not websockets.check_server_handshake(headers) def test_check_client_handshake(self): - resp = http.Request( - "relative", - "get", - "http", - "host", - 22, - "/", - (1, 1), - websockets.client_handshake_headers("key"), - "" - ) - assert websockets.check_client_handshake(resp) == "key" - resp.headers["Upgrade"] = ["not_websocket"] - assert not websockets.check_client_handshake(resp) + headers = websockets.client_handshake_headers("key") + assert websockets.check_client_handshake(headers) == "key" + headers["Upgrade"] = ["not_websocket"] + assert not websockets.check_client_handshake(headers) class BadHandshakeHandler(WebSocketsEchoHandler): def handshake(self): client_hs = http.read_request(self.rfile) - websockets.check_client_handshake(client_hs) + websockets.check_client_handshake(client_hs.headers) self.wfile.write(http.response_preamble(101) + "\r\n") headers = websockets.server_handshake_headers("malformed key") -- cgit v1.2.3 From bdd52fead339e634022a2251bb2bd85a924ca8d2 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 24 Apr 2015 08:47:09 +1200 Subject: websockets: extract frame header creation into a function --- netlib/websockets.py | 263 ++++++++++++++++++++++++++---------------------- test/test_websockets.py | 4 + 2 files changed, 147 insertions(+), 120 deletions(-) diff --git a/netlib/websockets.py b/netlib/websockets.py index da03768d..abf86262 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -35,6 +35,139 @@ class OPCODE: PONG = 0x0a +def apply_mask(message, masking_key): + """ + Data sent from the server must be masked to prevent malicious clients + from sending data over the wire in predictable patterns + + This method both encodes and decodes strings with the provided mask + + Servers do not have to mask data they send to the client. + https://tools.ietf.org/html/rfc6455#section-5.3 + """ + masks = [utils.bytes_to_int(byte) for byte in masking_key] + result = "" + for char in message: + result += chr(ord(char) ^ masks[len(result) % 4]) + return result + + +def client_handshake_headers(key=None, version=VERSION): + """ + Create the headers for a valid HTTP upgrade request. If Key is not + specified, it is generated, and can be found in sec-websocket-key in + the returned header set. + + Returns an instance of ODictCaseless + """ + if not key: + key = base64.b64encode(os.urandom(16)).decode('utf-8') + return odict.ODictCaseless([ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Key', key), + ('Sec-WebSocket-Version', version) + ]) + + +def server_handshake_headers(key): + """ + The server response is a valid HTTP 101 response. + """ + return odict.ODictCaseless( + [ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Accept', create_server_nonce(key)) + ] + ) + + +def get_payload_length_pair(payload_bytestring): + """ + A websockets frame contains an initial length_code, and an optional + extended length code to represent the actual length if length code is + larger than 125 + """ + actual_length = len(payload_bytestring) + + if actual_length <= 125: + length_code = actual_length + elif actual_length >= 126 and actual_length <= 65535: + length_code = 126 + else: + length_code = 127 + return (length_code, actual_length) + + +def make_length_code(len): + """ + A websockets frame contains an initial length_code, and an optional + extended length code to represent the actual length if length code is + larger than 125 + """ + if len <= 125: + return len + elif len >= 126 and len <= 65535: + return 126 + else: + return 127 + + +def check_client_handshake(headers): + if headers.get_first("upgrade", None) != "websocket": + return + return headers.get_first('sec-websocket-key') + + +def check_server_handshake(headers): + if headers.get_first("upgrade", None) != "websocket": + return + return headers.get_first('sec-websocket-accept') + + +def create_server_nonce(client_nonce): + return base64.b64encode( + hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') + ) + + +def frame_header_bytes( + opcode = 0, + payload_length = 0, + fin = 0, + rsv1 = 0, + rsv2 = 0, + rsv3 = 0, + mask = 0, + masking_key = None, + length_code = None +): + first_byte = (fin << 7) | (rsv1 << 6) |\ + (rsv2 << 4) | (rsv3 << 4) | opcode + + if length_code is None: + length_code = make_length_code(payload_length) + + second_byte = (mask << 7) | length_code + + b = chr(first_byte) + chr(second_byte) + + if payload_length < 126: + pass + elif payload_length < MAX_16_BIT_INT: + # '!H' pack as 16 bit unsigned short + # add 2 byte extended payload length + b += struct.pack('!H', payload_length) + elif payload_length < MAX_64_BIT_INT: + # '!Q' = pack as 64 bit unsigned long long + # add 8 bytes extended payload length + b += struct.pack('!Q', payload_length) + if masking_key is not None: + b += masking_key + return b + + class Frame(object): """ Represents one websockets frame. @@ -170,43 +303,16 @@ class Frame(object): If you haven't checked is_valid_frame() then there's no guarentees that the serialized bytes will be correct. see safe_to_bytes() """ - - # break down of the bit-math used to construct the first byte from the - # frame's integer values first shift the significant bit into the - # correct position - # 00000001 << 7 = 10000000 - # ... - # then combine: - # - # 10000000 fin - # 01000000 res1 - # 00100000 res2 - # 00010000 res3 - # 00000001 opcode - # -------- OR - # 11110001 = first_byte - - first_byte = (self.fin << 7) | (self.rsv1 << 6) |\ - (self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode - - second_byte = (self.mask_bit << 7) | self.payload_length_code - - b = chr(first_byte) + chr(second_byte) - - if self.actual_payload_length < 126: - pass - elif self.actual_payload_length < MAX_16_BIT_INT: - # '!H' pack as 16 bit unsigned short - # add 2 byte extended payload length - b += struct.pack('!H', self.actual_payload_length) - elif self.actual_payload_length < MAX_64_BIT_INT: - # '!Q' = pack as 64 bit unsigned long long - # add 8 bytes extended payload length - b += struct.pack('!Q', self.actual_payload_length) - - if self.masking_key is not None: - b += self.masking_key - + b = frame_header_bytes( + opcode = self.opcode, + fin = self.fin, + rsv1 = self.rsv1, + rsv2 = self.rsv2, + rsv3 = self.rsv3, + mask = self.mask_bit, + masking_key = self.masking_key, + payload_length = self.actual_payload_length + ) b += self.payload # already will be encoded if neccessary return b @@ -283,86 +389,3 @@ class Frame(object): self.decoded_payload == other.decoded_payload and self.actual_payload_length == other.actual_payload_length ) - - -def apply_mask(message, masking_key): - """ - Data sent from the server must be masked to prevent malicious clients - from sending data over the wire in predictable patterns - - This method both encodes and decodes strings with the provided mask - - Servers do not have to mask data they send to the client. - https://tools.ietf.org/html/rfc6455#section-5.3 - """ - masks = [utils.bytes_to_int(byte) for byte in masking_key] - result = "" - for char in message: - result += chr(ord(char) ^ masks[len(result) % 4]) - return result - - -def client_handshake_headers(key=None, version=VERSION): - """ - Create the headers for a valid HTTP upgrade request. If Key is not - specified, it is generated, and can be found in sec-websocket-key in - the returned header set. - - Returns an instance of ODictCaseless - """ - if not key: - key = base64.b64encode(os.urandom(16)).decode('utf-8') - return odict.ODictCaseless([ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - ('Sec-WebSocket-Key', key), - ('Sec-WebSocket-Version', version) - ]) - - -def server_handshake_headers(key): - """ - The server response is a valid HTTP 101 response. - """ - return odict.ODictCaseless( - [ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - ('Sec-WebSocket-Accept', create_server_nonce(key)) - ] - ) - - -def get_payload_length_pair(payload_bytestring): - """ - A websockets frame contains an initial length_code, and an optional - extended length code to represent the actual length if length code is - larger than 125 - """ - actual_length = len(payload_bytestring) - - if actual_length <= 125: - length_code = actual_length - elif actual_length >= 126 and actual_length <= 65535: - length_code = 126 - else: - length_code = 127 - return (length_code, actual_length) - - -def check_client_handshake(headers): - if headers.get_first("upgrade", None) != "websocket": - return - return headers.get_first('sec-websocket-key') - - -def check_server_handshake(headers): - if headers.get_first("upgrade", None) != "websocket": - return - return headers.get_first('sec-websocket-accept') - - -def create_server_nonce(client_nonce): - return base64.b64encode( - hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') - ) diff --git a/test/test_websockets.py b/test/test_websockets.py index 6f3b429d..17f7f728 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -3,6 +3,10 @@ import os from nose.tools import raises +def test_frame_header_bytes(): + assert websockets.frame_header_bytes() + + class WebSocketsEchoHandler(tcp.BaseHandler): def __init__(self, connection, address, server): super(WebSocketsEchoHandler, self).__init__( -- cgit v1.2.3 From 3519871f340cb0466fc6935d6e8e3b7822d36c52 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 24 Apr 2015 09:21:04 +1200 Subject: websockets: refactor to avoid rundantly specifying payloads and payload lengths --- netlib/websockets.py | 60 +++++++++++++++++++------------------------------ test/test_websockets.py | 17 +++----------- 2 files changed, 26 insertions(+), 51 deletions(-) diff --git a/netlib/websockets.py b/netlib/websockets.py index abf86262..7c127563 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -198,15 +198,13 @@ class Frame(object): self, fin, # decmial integer 1 or 0 opcode, # decmial integer 1 - 4 - mask_bit, # decimal integer 1 or 0 - payload_length_code, # decimal integer 1 - 127 - decoded_payload, # bytestring + payload = "", # bytestring + masking_key = None, # 32 bit byte string + mask_bit = 0, # decimal integer 1 or 0 + payload_length_code = None, # decimal integer 1 - 127 rsv1 = 0, # decimal integer 1 or 0 rsv2 = 0, # decimal integer 1 or 0 rsv3 = 0, # decimal integer 1 or 0 - payload = None, # bytestring - masking_key = None, # 32 bit byte string - actual_payload_length = None, # any decimal integer ): self.fin = fin self.rsv1 = rsv1 @@ -217,8 +215,6 @@ class Frame(object): self.payload_length_code = payload_length_code self.masking_key = masking_key self.payload = payload - self.decoded_payload = decoded_payload - self.actual_payload_length = actual_payload_length @classmethod def default(cls, message, from_client = False): @@ -226,27 +222,19 @@ class Frame(object): Construct a basic websocket frame from some default values. Creates a non-fragmented text frame. """ - length_code, actual_length = get_payload_length_pair(message) - if from_client: mask_bit = 1 - # Random masking key masking_key = os.urandom(4) - payload = apply_mask(message, masking_key) else: mask_bit = 0 masking_key = None - payload = message return cls( fin = 1, # final frame opcode = OPCODE.TEXT, # text mask_bit = mask_bit, - payload_length_code = length_code, - payload = payload, + payload = message, masking_key = masking_key, - decoded_payload = message, - actual_payload_length = actual_length ) def is_valid(self): @@ -261,17 +249,12 @@ class Frame(object): 0 <= self.rsv3 <= 1, 1 <= self.opcode <= 4, 0 <= self.mask_bit <= 1, - 1 <= self.payload_length_code <= 127, - self.actual_payload_length == len(self.payload), + #1 <= self.payload_length_code <= 127, 1 <= len(self.masking_key) <= 4 if self.mask_bit else True, self.masking_key is not None if self.mask_bit else True ] if not all(constraints): return False - elif self.payload and self.masking_key: - decoded = apply_mask(self.payload, self.masking_key) - if decoded != self.decoded_payload: - return False return True def human_readable(self): # pragma: nocover @@ -285,8 +268,6 @@ class Frame(object): ("payload_length_code - " + str(self.payload_length_code)), ("masking_key - " + repr(str(self.masking_key))), ("payload - " + repr(str(self.payload))), - ("decoded_payload - " + repr(str(self.decoded_payload))), - ("actual_payload_length - " + str(self.actual_payload_length)) ]) @classmethod @@ -311,9 +292,12 @@ class Frame(object): rsv3 = self.rsv3, mask = self.mask_bit, masking_key = self.masking_key, - payload_length = self.actual_payload_length + payload_length = len(self.payload) if self.payload else 0 ) - b += self.payload # already will be encoded if neccessary + if self.masking_key: + b += apply_mask(self.payload, self.masking_key) + else: + b += self.payload return b def to_file(self, writer): @@ -359,10 +343,8 @@ class Frame(object): payload = fp.read(actual_payload_length) - if mask_bit == 1: - decoded_payload = apply_mask(payload, masking_key) - else: - decoded_payload = payload + if mask_bit == 1 and masking_key: + payload = apply_mask(payload, masking_key) return cls( fin = fin, @@ -371,11 +353,17 @@ class Frame(object): payload_length_code = payload_length, payload = payload, masking_key = masking_key, - decoded_payload = decoded_payload, - actual_payload_length = actual_payload_length ) def __eq__(self, other): + if self.payload_length_code is None: + myplc = make_length_code(len(self.payload)) + else: + myplc = self.payload_length_code + if other.payload_length_code is None: + otherplc = make_length_code(len(other.payload)) + else: + otherplc = other.payload_length_code return ( self.fin == other.fin and self.rsv1 == other.rsv1 and @@ -383,9 +371,7 @@ class Frame(object): self.rsv3 == other.rsv3 and self.opcode == other.opcode and self.mask_bit == other.mask_bit and - self.payload_length_code == other.payload_length_code and self.masking_key == other.masking_key and - self.payload == other.payload and - self.decoded_payload == other.decoded_payload and - self.actual_payload_length == other.actual_payload_length + self.payload == other.payload, + myplc == otherplc ) diff --git a/test/test_websockets.py b/test/test_websockets.py index 17f7f728..bf8ec5cd 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -23,7 +23,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): def read_next_message(self): frame = websockets.Frame.from_file(self.rfile) - self.on_message(frame.decoded_payload) + self.on_message(frame.payload) def send_message(self, message): frame = websockets.Frame.default(message, from_client = False) @@ -107,7 +107,6 @@ class TestWebSockets(test.ServerTestBase): """ msg = self.random_bytes() client_frame = websockets.Frame.default(msg, from_client = True) - assert client_frame.is_valid() server_frame = websockets.Frame.default(msg, from_client = False) assert server_frame.is_valid() @@ -128,17 +127,6 @@ class TestWebSockets(test.ServerTestBase): frame.masking_key = "foobbarboo" assert not frame.is_valid() - frame = f() - frame.mask_bit = 0 - frame.masking_key = "foob" - assert not frame.is_valid() - - frame = f() - frame.masking_key = "foob" - frame.decoded_payload = "xxxx" - assert not frame.is_valid() - - def test_serialization_bijection(self): """ Ensure that various frame types can be serialized/deserialized back @@ -149,9 +137,10 @@ class TestWebSockets(test.ServerTestBase): frame = websockets.Frame.default( self.random_bytes(num_bytes), is_client ) - assert frame == websockets.Frame.from_bytes( + frame2 = websockets.Frame.from_bytes( frame.to_bytes() ) + assert frame == frame2 bytes = b'\x81\x03cba' assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes -- cgit v1.2.3 From f22bc0b4c74776bcc312fed1f4ceede83f869a6e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 24 Apr 2015 15:09:21 +1200 Subject: websocket: interface refactoring - Separate out FrameHeader. We need to deal with this separately in many circumstances. - Simpler equality scheme. - Bits are now specified by truthiness - we don't care about the integer value. This means lots of validation is not needed any more. --- netlib/utils.py | 16 +++ netlib/websockets.py | 303 +++++++++++++++++++++++------------------------- test/test_websockets.py | 49 ++++++-- 3 files changed, 201 insertions(+), 167 deletions(-) diff --git a/netlib/utils.py b/netlib/utils.py index 66bbdb5e..44bed43a 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -49,3 +49,19 @@ def hexdump(s): (o, x, cleanBin(part, True)) ) return parts + + +def setbit(byte, offset, value): + """ + Set a bit in a byte to 1 if value is truthy, 0 if not. + """ + if value: + return byte | (1 << offset) + else: + return byte & ~(1 << offset) + + +def getbit(byte, offset): + mask = 1 << offset + if byte & mask: + return True diff --git a/netlib/websockets.py b/netlib/websockets.py index 7c127563..016e75c2 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -1,5 +1,4 @@ from __future__ import absolute_import - import base64 import hashlib import os @@ -83,23 +82,6 @@ def server_handshake_headers(key): ) -def get_payload_length_pair(payload_bytestring): - """ - A websockets frame contains an initial length_code, and an optional - extended length code to represent the actual length if length code is - larger than 125 - """ - actual_length = len(payload_bytestring) - - if actual_length <= 125: - length_code = actual_length - elif actual_length >= 126 and actual_length <= 65535: - length_code = 126 - else: - length_code = 127 - return (length_code, actual_length) - - def make_length_code(len): """ A websockets frame contains an initial length_code, and an optional @@ -132,40 +114,113 @@ def create_server_nonce(client_nonce): ) -def frame_header_bytes( - opcode = 0, - payload_length = 0, - fin = 0, - rsv1 = 0, - rsv2 = 0, - rsv3 = 0, - mask = 0, - masking_key = None, - length_code = None -): - first_byte = (fin << 7) | (rsv1 << 6) |\ - (rsv2 << 4) | (rsv3 << 4) | opcode - - if length_code is None: - length_code = make_length_code(payload_length) - - second_byte = (mask << 7) | length_code - - b = chr(first_byte) + chr(second_byte) - - if payload_length < 126: - pass - elif payload_length < MAX_16_BIT_INT: - # '!H' pack as 16 bit unsigned short - # add 2 byte extended payload length - b += struct.pack('!H', payload_length) - elif payload_length < MAX_64_BIT_INT: - # '!Q' = pack as 64 bit unsigned long long - # add 8 bytes extended payload length - b += struct.pack('!Q', payload_length) - if masking_key is not None: - b += masking_key - return b +DEFAULT = object() +class FrameHeader: + def __init__( + self, + opcode = OPCODE.TEXT, + payload_length = 0, + fin = False, + rsv1 = False, + rsv2 = False, + rsv3 = False, + masking_key = None, + mask = DEFAULT, + length_code = DEFAULT + ): + self.opcode = opcode + self.payload_length = payload_length + self.fin = fin + self.rsv1 = rsv1 + self.rsv2 = rsv2 + self.rsv3 = rsv3 + self.mask = mask + self.masking_key = masking_key + self.length_code = length_code + + def to_bytes(self): + first_byte = utils.setbit(0, 7, self.fin) + first_byte = utils.setbit(first_byte, 6, self.rsv1) + first_byte = utils.setbit(first_byte, 5, self.rsv2) + first_byte = utils.setbit(first_byte, 4, self.rsv3) + first_byte = first_byte | self.opcode + + if self.length_code is DEFAULT: + length_code = make_length_code(self.payload_length) + else: + length_code = self.length_code + + if self.mask is DEFAULT: + mask = bool(self.masking_key) + else: + mask = self.mask + + second_byte = (mask << 7) | length_code + + b = chr(first_byte) + chr(second_byte) + + if self.payload_length < 126: + pass + elif self.payload_length < MAX_16_BIT_INT: + # '!H' pack as 16 bit unsigned short + # add 2 byte extended payload length + b += struct.pack('!H', self.payload_length) + elif self.payload_length < MAX_64_BIT_INT: + # '!Q' = pack as 64 bit unsigned long long + # add 8 bytes extended payload length + b += struct.pack('!Q', self.payload_length) + if self.masking_key is not None: + b += self.masking_key + return b + + @classmethod + def from_file(klass, fp): + """ + read a websockets frame header + """ + first_byte = utils.bytes_to_int(fp.read(1)) + second_byte = utils.bytes_to_int(fp.read(1)) + + fin = utils.getbit(first_byte, 7) + rsv1 = utils.getbit(first_byte, 6) + rsv2 = utils.getbit(first_byte, 5) + rsv3 = utils.getbit(first_byte, 4) + # grab right most 4 bits by and-ing with 00001111 + opcode = first_byte & 15 + # grab left most bit + mask_bit = second_byte >> 7 + # grab the next 7 bits + length_code = second_byte & 127 + + # payload_lengthy > 125 indicates you need to read more bytes + # to get the actual payload length + if length_code <= 125: + payload_length = length_code + elif length_code == 126: + payload_length = utils.bytes_to_int(fp.read(2)) + elif length_code == 127: + payload_length = utils.bytes_to_int(fp.read(8)) + + # masking key only present if mask bit set + if mask_bit == 1: + masking_key = fp.read(4) + else: + masking_key = None + + return klass( + fin = fin, + rsv1 = rsv1, + rsv2 = rsv2, + rsv3 = rsv3, + opcode = opcode, + mask = mask_bit, + length_code = length_code, + payload_length = payload_length, + masking_key = masking_key, + ) + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() class Frame(object): @@ -194,27 +249,10 @@ class Frame(object): | Payload Data continued ... | +---------------------------------------------------------------+ """ - def __init__( - self, - fin, # decmial integer 1 or 0 - opcode, # decmial integer 1 - 4 - payload = "", # bytestring - masking_key = None, # 32 bit byte string - mask_bit = 0, # decimal integer 1 or 0 - payload_length_code = None, # decimal integer 1 - 127 - rsv1 = 0, # decimal integer 1 or 0 - rsv2 = 0, # decimal integer 1 or 0 - rsv3 = 0, # decimal integer 1 or 0 - ): - self.fin = fin - self.rsv1 = rsv1 - self.rsv2 = rsv2 - self.rsv3 = rsv3 - self.opcode = opcode - self.mask_bit = mask_bit - self.payload_length_code = payload_length_code - self.masking_key = masking_key + def __init__(self, payload = "", **kwargs): self.payload = payload + kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) + self.header = FrameHeader(**kwargs) @classmethod def default(cls, message, from_client = False): @@ -230,10 +268,10 @@ class Frame(object): masking_key = None return cls( + message, fin = 1, # final frame opcode = OPCODE.TEXT, # text - mask_bit = mask_bit, - payload = message, + mask = mask_bit, masking_key = masking_key, ) @@ -243,30 +281,30 @@ class Frame(object): Frame has not been corrupted. """ constraints = [ - 0 <= self.fin <= 1, - 0 <= self.rsv1 <= 1, - 0 <= self.rsv2 <= 1, - 0 <= self.rsv3 <= 1, - 1 <= self.opcode <= 4, - 0 <= self.mask_bit <= 1, + 0 <= self.header.fin <= 1, + 0 <= self.header.rsv1 <= 1, + 0 <= self.header.rsv2 <= 1, + 0 <= self.header.rsv3 <= 1, + 1 <= self.header.opcode <= 4, + 0 <= self.header.mask <= 1, #1 <= self.payload_length_code <= 127, - 1 <= len(self.masking_key) <= 4 if self.mask_bit else True, - self.masking_key is not None if self.mask_bit else True + 1 <= len(self.header.masking_key) <= 4 if self.header.mask else True, + self.header.masking_key is not None if self.header.mask else True ] if not all(constraints): return False return True - def human_readable(self): # pragma: nocover + def human_readable(self): return "\n".join([ - ("fin - " + str(self.fin)), - ("rsv1 - " + str(self.rsv1)), - ("rsv2 - " + str(self.rsv2)), - ("rsv3 - " + str(self.rsv3)), - ("opcode - " + str(self.opcode)), - ("mask_bit - " + str(self.mask_bit)), - ("payload_length_code - " + str(self.payload_length_code)), - ("masking_key - " + repr(str(self.masking_key))), + ("fin - " + str(self.header.fin)), + ("rsv1 - " + str(self.header.rsv1)), + ("rsv2 - " + str(self.header.rsv2)), + ("rsv3 - " + str(self.header.rsv3)), + ("opcode - " + str(self.header.opcode)), + ("mask - " + str(self.header.mask)), + ("length_code - " + str(self.header.length_code)), + ("masking_key - " + repr(str(self.header.masking_key))), ("payload - " + repr(str(self.payload))), ]) @@ -284,18 +322,9 @@ class Frame(object): If you haven't checked is_valid_frame() then there's no guarentees that the serialized bytes will be correct. see safe_to_bytes() """ - b = frame_header_bytes( - opcode = self.opcode, - fin = self.fin, - rsv1 = self.rsv1, - rsv2 = self.rsv2, - rsv3 = self.rsv3, - mask = self.mask_bit, - masking_key = self.masking_key, - payload_length = len(self.payload) if self.payload else 0 - ) - if self.masking_key: - b += apply_mask(self.payload, self.masking_key) + b = self.header.to_bytes() + if self.header.masking_key: + b += apply_mask(self.payload, self.header.masking_key) else: b += self.payload return b @@ -312,66 +341,20 @@ class Frame(object): fp is a "file like" object that could be backed by a network stream or a disk or an in memory stream reader """ - first_byte = utils.bytes_to_int(fp.read(1)) - second_byte = utils.bytes_to_int(fp.read(1)) - - # grab the left most bit - fin = first_byte >> 7 - # grab right most 4 bits by and-ing with 00001111 - opcode = first_byte & 15 - # grab left most bit - mask_bit = second_byte >> 7 - # grab the next 7 bits - payload_length = second_byte & 127 - - # payload_lengthy > 125 indicates you need to read more bytes - # to get the actual payload length - if payload_length <= 125: - actual_payload_length = payload_length - - elif payload_length == 126: - actual_payload_length = utils.bytes_to_int(fp.read(2)) - - elif payload_length == 127: - actual_payload_length = utils.bytes_to_int(fp.read(8)) - - # masking key only present if mask bit set - if mask_bit == 1: - masking_key = fp.read(4) - else: - masking_key = None - - payload = fp.read(actual_payload_length) + header = FrameHeader.from_file(fp) + payload = fp.read(header.payload_length) - if mask_bit == 1 and masking_key: - payload = apply_mask(payload, masking_key) + if header.mask == 1 and header.masking_key: + payload = apply_mask(payload, header.masking_key) return cls( - fin = fin, - opcode = opcode, - mask_bit = mask_bit, - payload_length_code = payload_length, - payload = payload, - masking_key = masking_key, + payload, + fin = header.fin, + opcode = header.opcode, + mask = header.mask, + payload_length = header.payload_length, + masking_key = header.masking_key, ) def __eq__(self, other): - if self.payload_length_code is None: - myplc = make_length_code(len(self.payload)) - else: - myplc = self.payload_length_code - if other.payload_length_code is None: - otherplc = make_length_code(len(other.payload)) - else: - otherplc = other.payload_length_code - return ( - self.fin == other.fin and - self.rsv1 == other.rsv1 and - self.rsv2 == other.rsv2 and - self.rsv3 == other.rsv3 and - self.opcode == other.opcode and - self.mask_bit == other.mask_bit and - self.masking_key == other.masking_key and - self.payload == other.payload, - myplc == otherplc - ) + return self.to_bytes() == other.to_bytes() diff --git a/test/test_websockets.py b/test/test_websockets.py index bf8ec5cd..06876e0b 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -1,10 +1,9 @@ -from netlib import tcp, test, websockets, http +import cStringIO import os -from nose.tools import raises +from nose.tools import raises -def test_frame_header_bytes(): - assert websockets.frame_header_bytes() +from netlib import tcp, test, websockets, http class WebSocketsEchoHandler(tcp.BaseHandler): @@ -119,12 +118,12 @@ class TestWebSockets(test.ServerTestBase): assert frame.is_valid() frame = f() - frame.fin = 2 + frame.header.fin = 2 assert not frame.is_valid() frame = f() - frame.mask_bit = 1 - frame.masking_key = "foobbarboo" + frame.header.mask_bit = 1 + frame.header.masking_key = "foobbarboo" assert not frame.is_valid() def test_serialization_bijection(self): @@ -181,3 +180,39 @@ class TestBadHandshake(test.ServerTestBase): client = WebSocketsClient(("127.0.0.1", self.port)) client.connect() client.send_message("hello") + + +class TestFrameHeader: + def test_roundtrip(self): + def round(*args, **kwargs): + f = websockets.FrameHeader(*args, **kwargs) + bytes = f.to_bytes() + f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes)) + assert f == f2 + round() + round(fin=1) + round(rsv1=1) + round(rsv2=1) + round(rsv3=1) + round(payload_length=1) + round(payload_length=100) + round(payload_length=1000) + round(payload_length=10000) + round(opcode=websockets.OPCODE.PING) + round(masking_key="test") + + def test_funky(self): + f = websockets.FrameHeader(masking_key="test", mask=False) + bytes = f.to_bytes() + f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes)) + assert not f2.mask + + +class TestFrame: + def test_roundtrip(self): + def round(*args, **kwargs): + f = websockets.Frame(*args, **kwargs) + bytes = f.to_bytes() + f2 = websockets.Frame.from_file(cStringIO.StringIO(bytes)) + assert f == f2 + round("test") -- cgit v1.2.3 From def93ea8cae69676a91b01e149e8a406fa03eacd Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 24 Apr 2015 15:23:00 +1200 Subject: websockets: remove validation We don't really need this any more. The interface is much less error prone because bit flags are no longer integers, we have a range check on opcode on header instantiation, and we've deferred length code calculation and so forth into the byte render methods. --- netlib/websockets.py | 24 ++++-------------------- test/test_websockets.py | 26 ++++++++------------------ 2 files changed, 12 insertions(+), 38 deletions(-) diff --git a/netlib/websockets.py b/netlib/websockets.py index 016e75c2..b1afa620 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -115,6 +115,8 @@ def create_server_nonce(client_nonce): DEFAULT = object() + + class FrameHeader: def __init__( self, @@ -128,6 +130,8 @@ class FrameHeader: mask = DEFAULT, length_code = DEFAULT ): + if not 0 <= opcode < 2 ** 4: + raise ValueError("opcode must be 0-16") self.opcode = opcode self.payload_length = payload_length self.fin = fin @@ -275,26 +279,6 @@ class Frame(object): masking_key = masking_key, ) - def is_valid(self): - """ - Validate websocket frame invariants, call at anytime to ensure the - Frame has not been corrupted. - """ - constraints = [ - 0 <= self.header.fin <= 1, - 0 <= self.header.rsv1 <= 1, - 0 <= self.header.rsv2 <= 1, - 0 <= self.header.rsv3 <= 1, - 1 <= self.header.opcode <= 4, - 0 <= self.header.mask <= 1, - #1 <= self.payload_length_code <= 127, - 1 <= len(self.header.masking_key) <= 4 if self.header.mask else True, - self.header.masking_key is not None if self.header.mask else True - ] - if not all(constraints): - return False - return True - def human_readable(self): return "\n".join([ ("fin - " + str(self.header.fin)), diff --git a/test/test_websockets.py b/test/test_websockets.py index 06876e0b..215b3958 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -4,6 +4,7 @@ import os from nose.tools import raises from netlib import tcp, test, websockets, http +import tutils class WebSocketsEchoHandler(tcp.BaseHandler): @@ -106,25 +107,7 @@ class TestWebSockets(test.ServerTestBase): """ msg = self.random_bytes() client_frame = websockets.Frame.default(msg, from_client = True) - server_frame = websockets.Frame.default(msg, from_client = False) - assert server_frame.is_valid() - - def test_is_valid(self): - def f(): - return websockets.Frame.default(self.random_bytes(10), True) - - frame = f() - assert frame.is_valid() - - frame = f() - frame.header.fin = 2 - assert not frame.is_valid() - - frame = f() - frame.header.mask_bit = 1 - frame.header.masking_key = "foobbarboo" - assert not frame.is_valid() def test_serialization_bijection(self): """ @@ -207,6 +190,9 @@ class TestFrameHeader: f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes)) assert not f2.mask + def test_violations(self): + tutils.raises("opcode", websockets.FrameHeader, opcode=17) + class TestFrame: def test_roundtrip(self): @@ -216,3 +202,7 @@ class TestFrame: f2 = websockets.Frame.from_file(cStringIO.StringIO(bytes)) assert f == f2 round("test") + + def test_human_readable(self): + f = websockets.Frame() + assert f.human_readable() -- cgit v1.2.3 From 192fd1db7f233b71398c5255cbdebe1928768b55 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 24 Apr 2015 15:31:14 +1200 Subject: websockets: include all header values in frame roundtrip --- netlib/websockets.py | 27 +++++++++++++++------------ test/test_websockets.py | 4 ++++ 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/netlib/websockets.py b/netlib/websockets.py index b1afa620..85aad9c6 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -159,7 +159,7 @@ class FrameHeader: else: mask = self.mask - second_byte = (mask << 7) | length_code + second_byte = utils.setbit(length_code, 7, mask) b = chr(first_byte) + chr(second_byte) @@ -189,10 +189,9 @@ class FrameHeader: rsv1 = utils.getbit(first_byte, 6) rsv2 = utils.getbit(first_byte, 5) rsv3 = utils.getbit(first_byte, 4) - # grab right most 4 bits by and-ing with 00001111 + # grab right-most 4 bits opcode = first_byte & 15 - # grab left most bit - mask_bit = second_byte >> 7 + mask_bit = utils.getbit(second_byte, 7) # grab the next 7 bits length_code = second_byte & 127 @@ -279,6 +278,14 @@ class Frame(object): masking_key = masking_key, ) + @classmethod + def from_bytes(cls, bytestring): + """ + Construct a websocket frame from an in-memory bytestring + to construct a frame from a stream of bytes, use from_file() directly + """ + return cls.from_file(io.BytesIO(bytestring)) + def human_readable(self): return "\n".join([ ("fin - " + str(self.header.fin)), @@ -292,14 +299,6 @@ class Frame(object): ("payload - " + repr(str(self.payload))), ]) - @classmethod - def from_bytes(cls, bytestring): - """ - Construct a websocket frame from an in-memory bytestring - to construct a frame from a stream of bytes, use from_file() directly - """ - return cls.from_file(io.BytesIO(bytestring)) - def to_bytes(self): """ Serialize the frame back into the wire format, returns a bytestring @@ -338,6 +337,10 @@ class Frame(object): mask = header.mask, payload_length = header.payload_length, masking_key = header.masking_key, + rsv1 = header.rsv1, + rsv2 = header.rsv2, + rsv3 = header.rsv3, + length_code = header.length_code ) def __eq__(self, other): diff --git a/test/test_websockets.py b/test/test_websockets.py index 215b3958..8ae18edd 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -202,6 +202,10 @@ class TestFrame: f2 = websockets.Frame.from_file(cStringIO.StringIO(bytes)) assert f == f2 round("test") + round("test", fin=1) + round("test", rsv1=1) + round("test", opcode=websockets.OPCODE.PING) + round("test", masking_key="test") def test_human_readable(self): f = websockets.Frame() -- cgit v1.2.3 From 18df329930eb822395caf279862589d2a40413c9 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 24 Apr 2015 15:42:31 +1200 Subject: websockets: nicer frame construction - Resolve unspecified values on instantiation - Add a check for masking key length - Smarter resolution for masking_key and mask values. Do the right thing unless told not to. --- netlib/websockets.py | 38 +++++++++++++++++++++++--------------- test/test_websockets.py | 12 ++++++++++++ 2 files changed, 35 insertions(+), 15 deletions(-) diff --git a/netlib/websockets.py b/netlib/websockets.py index 85aad9c6..493bb18a 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -126,7 +126,7 @@ class FrameHeader: rsv1 = False, rsv2 = False, rsv3 = False, - masking_key = None, + masking_key = DEFAULT, mask = DEFAULT, length_code = DEFAULT ): @@ -138,9 +138,27 @@ class FrameHeader: self.rsv1 = rsv1 self.rsv2 = rsv2 self.rsv3 = rsv3 - self.mask = mask - self.masking_key = masking_key - self.length_code = length_code + + if length_code is DEFAULT: + self.length_code = make_length_code(self.payload_length) + else: + self.length_code = length_code + + if mask is DEFAULT and masking_key is DEFAULT: + self.mask = False + self.masking_key = "" + elif mask is DEFAULT: + self.mask = 1 + self.masking_key = masking_key + elif masking_key is DEFAULT: + self.mask = mask + self.masking_key = os.urandom(4) + else: + self.mask = mask + self.masking_key = masking_key + + if self.masking_key and len(self.masking_key) != 4: + raise ValueError("Masking key must be 4 bytes.") def to_bytes(self): first_byte = utils.setbit(0, 7, self.fin) @@ -149,17 +167,7 @@ class FrameHeader: first_byte = utils.setbit(first_byte, 4, self.rsv3) first_byte = first_byte | self.opcode - if self.length_code is DEFAULT: - length_code = make_length_code(self.payload_length) - else: - length_code = self.length_code - - if self.mask is DEFAULT: - mask = bool(self.masking_key) - else: - mask = self.mask - - second_byte = utils.setbit(length_code, 7, mask) + second_byte = utils.setbit(self.length_code, 7, self.mask) b = chr(first_byte) + chr(second_byte) diff --git a/test/test_websockets.py b/test/test_websockets.py index 8ae18edd..4b286b6f 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -192,6 +192,18 @@ class TestFrameHeader: def test_violations(self): tutils.raises("opcode", websockets.FrameHeader, opcode=17) + tutils.raises("masking key", websockets.FrameHeader, masking_key="x") + + def test_automask(self): + f = websockets.FrameHeader(mask=True) + assert f.masking_key + + f = websockets.FrameHeader(masking_key="foob") + assert f.mask + + f = websockets.FrameHeader(masking_key="foob", mask=0) + assert not f.mask + assert f.masking_key class TestFrame: -- cgit v1.2.3 From b7a2fc85537dca60fb18d25965289d876bd3bd38 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 30 Apr 2015 08:41:13 +1200 Subject: testing: http read_request corner case --- test/test_http.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/test_http.py b/test/test_http.py index 962eb9cb..f1a31b93 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -440,6 +440,11 @@ class TestReadRequest(): self.tst, "get / HTTP/1.1\r\nfoo" ) + tutils.raises( + tcp.NetLibDisconnect, + self.tst, + "\r\n" + ) def test_asterisk_form_in(self): v = self.tst("OPTIONS * HTTP/1.1") -- cgit v1.2.3 From 80860229209b4c6eb8384e1bca3cabdbe062fe6e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 30 Apr 2015 09:04:22 +1200 Subject: Add a tiny utility class for keeping bi-directional mappings. Use it in websocket and socks. --- netlib/socks.py | 60 ++++++++++++++++++++++++++----------------------- netlib/utils.py | 26 +++++++++++++++++++++ netlib/websockets.py | 25 ++++++++++++++++----- test/test_utils.py | 11 ++++++++- test/test_websockets.py | 4 ++++ 5 files changed, 91 insertions(+), 35 deletions(-) diff --git a/netlib/socks.py b/netlib/socks.py index a3c4e9a2..497b8eef 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -2,7 +2,7 @@ from __future__ import (absolute_import, print_function, division) import socket import struct import array -from . import tcp +from . import tcp, utils class SocksError(Exception): @@ -11,40 +11,45 @@ class SocksError(Exception): self.code = code -class VERSION(object): - SOCKS4 = 0x04 +VERSION = utils.BiDi( + SOCKS4 = 0x04, SOCKS5 = 0x05 +) -class CMD(object): - CONNECT = 0x01 - BIND = 0x02 +CMD = utils.BiDi( + CONNECT = 0x01, + BIND = 0x02, UDP_ASSOCIATE = 0x03 +) -class ATYP(object): - IPV4_ADDRESS = 0x01 - DOMAINNAME = 0x03 +ATYP = utils.BiDi( + IPV4_ADDRESS = 0x01, + DOMAINNAME = 0x03, IPV6_ADDRESS = 0x04 - - -class REP(object): - SUCCEEDED = 0x00 - GENERAL_SOCKS_SERVER_FAILURE = 0x01 - CONNECTION_NOT_ALLOWED_BY_RULESET = 0x02 - NETWORK_UNREACHABLE = 0x03 - HOST_UNREACHABLE = 0x04 - CONNECTION_REFUSED = 0x05 - TTL_EXPIRED = 0x06 - COMMAND_NOT_SUPPORTED = 0x07 - ADDRESS_TYPE_NOT_SUPPORTED = 0x08 - - -class METHOD(object): - NO_AUTHENTICATION_REQUIRED = 0x00 - GSSAPI = 0x01 - USERNAME_PASSWORD = 0x02 +) + + +REP = utils.BiDi( + SUCCEEDED = 0x00, + GENERAL_SOCKS_SERVER_FAILURE = 0x01, + CONNECTION_NOT_ALLOWED_BY_RULESET = 0x02, + NETWORK_UNREACHABLE = 0x03, + HOST_UNREACHABLE = 0x04, + CONNECTION_REFUSED = 0x05, + TTL_EXPIRED = 0x06, + COMMAND_NOT_SUPPORTED = 0x07, + ADDRESS_TYPE_NOT_SUPPORTED = 0x08, +) + + +METHOD = utils.BiDi( + NO_AUTHENTICATION_REQUIRED = 0x00, + GSSAPI = 0x01, + USERNAME_PASSWORD = 0x02, NO_ACCEPTABLE_METHODS = 0xFF +) def _read(f, n): @@ -146,4 +151,3 @@ class Message(object): "Unknown ATYP: %s" % self.atyp ) f.write(struct.pack("!H", self.addr.port)) - diff --git a/netlib/utils.py b/netlib/utils.py index 44bed43a..905d948f 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -65,3 +65,29 @@ def getbit(byte, offset): mask = 1 << offset if byte & mask: return True + + +class BiDi: + """ + A wee utility class for keeping bi-directional mappings, like field + constants in protocols: + + CONST = BiDi(a=1, b=2) + assert CONST.a == 1 + assert CONST[1] == "a" + """ + def __init__(self, **kwargs): + self.names = kwargs + self.values = {} + for k, v in kwargs.items(): + self.values[v] = k + if len(self.names) != len(self.values): + raise ValueError("Duplicate values not allowed.") + + def __getattr__(self, k): + if k in self.names: + return self.names[k] + raise AttributeError("No such attribute: %s", k) + + def __getitem__(self, k): + return self.values[k] diff --git a/netlib/websockets.py b/netlib/websockets.py index 493bb18a..d358ed53 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -25,13 +25,14 @@ MAX_16_BIT_INT = (1 << 16) MAX_64_BIT_INT = (1 << 64) -class OPCODE: - CONTINUE = 0x00 - TEXT = 0x01 - BINARY = 0x02 - CLOSE = 0x08 - PING = 0x09 +OPCODE = utils.BiDi( + CONTINUE = 0x00, + TEXT = 0x01, + BINARY = 0x02, + CLOSE = 0x08, + PING = 0x09, PONG = 0x0a +) def apply_mask(message, masking_key): @@ -160,6 +161,18 @@ class FrameHeader: if self.masking_key and len(self.masking_key) != 4: raise ValueError("Masking key must be 4 bytes.") + def human_readable(self): + return "\n".join([ + ("fin - " + str(self.fin)), + ("rsv1 - " + str(self.rsv1)), + ("rsv2 - " + str(self.rsv2)), + ("rsv3 - " + str(self.rsv3)), + ("opcode - " + str(self.opcode)), + ("mask - " + str(self.mask)), + ("length_code - " + str(self.length_code)), + ("masking_key - " + repr(str(self.masking_key))), + ]) + def to_bytes(self): first_byte = utils.setbit(0, 7, self.fin) first_byte = utils.setbit(first_byte, 6, self.rsv1) diff --git a/test/test_utils.py b/test/test_utils.py index 971e5076..0cdd3fae 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,5 +1,14 @@ from netlib import utils -import socket +import tutils + + +def test_bidi(): + b = utils.BiDi(a=1, b=2) + assert b.a == 1 + assert b[1] == "a" + tutils.raises(AttributeError, getattr, b, "c") + tutils.raises(KeyError, b.__getitem__, 5) + def test_hexdump(): assert utils.hexdump("one\0"*10) diff --git a/test/test_websockets.py b/test/test_websockets.py index 4b286b6f..9266d93e 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -184,6 +184,10 @@ class TestFrameHeader: round(opcode=websockets.OPCODE.PING) round(masking_key="test") + def test_human_readable(self): + f = websockets.FrameHeader(masking_key="test", mask=False) + assert f.human_readable() + def test_funky(self): f = websockets.FrameHeader(masking_key="test", mask=False) bytes = f.to_bytes() -- cgit v1.2.3 From 4dce7ee074c242f5b6530ff64879875d98c1d255 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 30 Apr 2015 12:10:08 +1200 Subject: websockets: more compact and legible human_readable --- netlib/utils.py | 25 +++++++++++++++++++++---- netlib/websockets.py | 38 +++++++++++++++++--------------------- test/test_utils.py | 12 ++++++++++-- test/test_websockets.py | 8 +++++++- 4 files changed, 55 insertions(+), 28 deletions(-) diff --git a/netlib/utils.py b/netlib/utils.py index 905d948f..7e539977 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -70,11 +70,12 @@ def getbit(byte, offset): class BiDi: """ A wee utility class for keeping bi-directional mappings, like field - constants in protocols: + constants in protocols. Names are attributes on the object, dict-like + access maps values to names: CONST = BiDi(a=1, b=2) assert CONST.a == 1 - assert CONST[1] == "a" + assert CONST.get_name(1) == "a" """ def __init__(self, **kwargs): self.names = kwargs @@ -89,5 +90,21 @@ class BiDi: return self.names[k] raise AttributeError("No such attribute: %s", k) - def __getitem__(self, k): - return self.values[k] + def get_name(self, n, default=None): + return self.values.get(n, default) + + +def pretty_size(size): + suffixes = [ + ("B", 2**10), + ("kB", 2**20), + ("MB", 2**30), + ] + for suf, lim in suffixes: + if size >= lim: + continue + else: + x = round(size/float(lim/2**10), 2) + if x == int(x): + x = int(x) + return str(x) + suf diff --git a/netlib/websockets.py b/netlib/websockets.py index d358ed53..1d02d684 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -162,16 +162,21 @@ class FrameHeader: raise ValueError("Masking key must be 4 bytes.") def human_readable(self): - return "\n".join([ - ("fin - " + str(self.fin)), - ("rsv1 - " + str(self.rsv1)), - ("rsv2 - " + str(self.rsv2)), - ("rsv3 - " + str(self.rsv3)), - ("opcode - " + str(self.opcode)), - ("mask - " + str(self.mask)), - ("length_code - " + str(self.length_code)), - ("masking_key - " + repr(str(self.masking_key))), - ]) + vals = [ + "wf:", + OPCODE.get_name(self.opcode, hex(self.opcode)).lower() + ] + flags = [] + for i in ["fin", "rsv1", "rsv2", "rsv3", "mask"]: + if getattr(self, i): + flags.append(i) + if flags: + vals.extend([":", "|".join(flags)]) + if self.masking_key: + vals.append(":key=%s"%repr(self.masking_key)) + if self.payload_length: + vals.append(" %s"%utils.pretty_size(self.payload_length)) + return "".join(vals) def to_bytes(self): first_byte = utils.setbit(0, 7, self.fin) @@ -308,17 +313,8 @@ class Frame(object): return cls.from_file(io.BytesIO(bytestring)) def human_readable(self): - return "\n".join([ - ("fin - " + str(self.header.fin)), - ("rsv1 - " + str(self.header.rsv1)), - ("rsv2 - " + str(self.header.rsv2)), - ("rsv3 - " + str(self.header.rsv3)), - ("opcode - " + str(self.header.opcode)), - ("mask - " + str(self.header.mask)), - ("length_code - " + str(self.header.length_code)), - ("masking_key - " + repr(str(self.header.masking_key))), - ("payload - " + repr(str(self.payload))), - ]) + hdr = self.header.human_readable() + return hdr + "\n" + repr(self.payload) def to_bytes(self): """ diff --git a/test/test_utils.py b/test/test_utils.py index 0cdd3fae..942136fd 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -5,9 +5,10 @@ import tutils def test_bidi(): b = utils.BiDi(a=1, b=2) assert b.a == 1 - assert b[1] == "a" + assert b.get_name(1) == "a" + assert b.get_name(5) is None tutils.raises(AttributeError, getattr, b, "c") - tutils.raises(KeyError, b.__getitem__, 5) + tutils.raises(ValueError, utils.BiDi, one=1, two=1) def test_hexdump(): @@ -19,3 +20,10 @@ def test_cleanBin(): assert utils.cleanBin("\00ne") == ".ne" assert utils.cleanBin("\nne") == "\nne" assert utils.cleanBin("\nne", True) == ".ne" + + +def test_pretty_size(): + assert utils.pretty_size(100) == "100B" + assert utils.pretty_size(1024) == "1kB" + assert utils.pretty_size(1024 + (1024/2.0)) == "1.5kB" + assert utils.pretty_size(1024*1024) == "1MB" diff --git a/test/test_websockets.py b/test/test_websockets.py index 9266d93e..d8e56a8f 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -185,7 +185,13 @@ class TestFrameHeader: round(masking_key="test") def test_human_readable(self): - f = websockets.FrameHeader(masking_key="test", mask=False) + f = websockets.FrameHeader( + masking_key="test", + fin=True, + payload_length=10 + ) + assert f.human_readable() + f = websockets.FrameHeader() assert f.human_readable() def test_funky(self): -- cgit v1.2.3 From 7d9e38ffb10e92b5127f203c2d8a524da8698b00 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 1 May 2015 10:09:35 +1200 Subject: websockets: A progressive masker. --- netlib/websockets.py | 32 ++++++++++++++++++-------------- test/test_websockets.py | 16 ++++++++++++++++ 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/netlib/websockets.py b/netlib/websockets.py index 1d02d684..84eb03ba 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -35,21 +35,25 @@ OPCODE = utils.BiDi( ) -def apply_mask(message, masking_key): +class Masker: """ - Data sent from the server must be masked to prevent malicious clients - from sending data over the wire in predictable patterns + Data sent from the server must be masked to prevent malicious clients + from sending data over the wire in predictable patterns - This method both encodes and decodes strings with the provided mask - - Servers do not have to mask data they send to the client. - https://tools.ietf.org/html/rfc6455#section-5.3 + Servers do not have to mask data they send to the client. + https://tools.ietf.org/html/rfc6455#section-5.3 """ - masks = [utils.bytes_to_int(byte) for byte in masking_key] - result = "" - for char in message: - result += chr(ord(char) ^ masks[len(result) % 4]) - return result + def __init__(self, key): + self.key = key + self.masks = [utils.bytes_to_int(byte) for byte in key] + self.offset = 0 + + def __call__(self, data): + result = "" + for c in data: + result += chr(ord(c) ^ self.masks[self.offset % 4]) + self.offset += 1 + return result def client_handshake_headers(key=None, version=VERSION): @@ -324,7 +328,7 @@ class Frame(object): """ b = self.header.to_bytes() if self.header.masking_key: - b += apply_mask(self.payload, self.header.masking_key) + b += Masker(self.header.masking_key)(self.payload) else: b += self.payload return b @@ -345,7 +349,7 @@ class Frame(object): payload = fp.read(header.payload_length) if header.mask == 1 and header.masking_key: - payload = apply_mask(payload, header.masking_key) + payload = Masker(header.masking_key)(payload) return cls( payload, diff --git a/test/test_websockets.py b/test/test_websockets.py index d8e56a8f..428f7c61 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -232,3 +232,19 @@ class TestFrame: def test_human_readable(self): f = websockets.Frame() assert f.human_readable() + + +def test_masker(): + tests = [ + ["a"], + ["four"], + ["fourf"], + ["fourfive"], + ["a", "aasdfasdfa", "asdf"], + ["a"*50, "aasdfasdfa", "asdf"], + ] + for i in tests: + m = websockets.Masker("abcd") + data = "".join([m(t) for t in i]) + data2 = websockets.Masker("abcd")(data) + assert data2 == "".join(i) -- cgit v1.2.3 From 08b2e2a6a98fd175e1b49d62dffde34e91c77b1c Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 1 May 2015 10:31:20 +1200 Subject: websockets: more flexible masking interface. --- netlib/websockets.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/netlib/websockets.py b/netlib/websockets.py index 84eb03ba..0ad0e294 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -48,13 +48,18 @@ class Masker: self.masks = [utils.bytes_to_int(byte) for byte in key] self.offset = 0 - def __call__(self, data): + def mask(self, offset, data): result = "" for c in data: - result += chr(ord(c) ^ self.masks[self.offset % 4]) - self.offset += 1 + result += chr(ord(c) ^ self.masks[offset % 4]) + offset += 1 return result + def __call__(self, data): + ret = self.mask(self.offset, data) + self.offset += len(ret) + return ret + def client_handshake_headers(key=None, version=VERSION): """ -- cgit v1.2.3 From f2bc58cdd2f2b9b0025a88c0faccf55e10b29353 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 5 May 2015 10:47:02 +1200 Subject: Add tcp.Reader.safe_read, use it in socks and websockets safe_read is guaranteed to raise or return a byte string of the requested length. It's particularly useful for implementing binary protocols. --- netlib/socks.py | 32 +++++++++----------------------- netlib/tcp.py | 48 ++++++++++++++++++++++++++++++++++-------------- netlib/websockets.py | 16 ++++++++-------- test/test_http.py | 8 ++++---- test/test_socks.py | 16 ++++++++-------- test/test_websockets.py | 7 +++---- test/tutils.py | 15 ++++++++++++++- 7 files changed, 80 insertions(+), 62 deletions(-) diff --git a/netlib/socks.py b/netlib/socks.py index 497b8eef..6f9f57bd 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -52,20 +52,6 @@ METHOD = utils.BiDi( ) -def _read(f, n): - try: - d = f.read(n) - if len(d) == n: - return d - else: - raise SocksError( - REP.GENERAL_SOCKS_SERVER_FAILURE, - "Incomplete Read" - ) - except socket.error as e: - raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, str(e)) - - class ClientGreeting(object): __slots__ = ("ver", "methods") @@ -75,9 +61,9 @@ class ClientGreeting(object): @classmethod def from_file(cls, f): - ver, nmethods = struct.unpack("!BB", _read(f, 2)) + ver, nmethods = struct.unpack("!BB", f.safe_read(2)) methods = array.array("B") - methods.fromstring(_read(f, nmethods)) + methods.fromstring(f.safe_read(nmethods)) return cls(ver, methods) def to_file(self, f): @@ -94,7 +80,7 @@ class ServerGreeting(object): @classmethod def from_file(cls, f): - ver, method = struct.unpack("!BB", _read(f, 2)) + ver, method = struct.unpack("!BB", f.safe_read(2)) return cls(ver, method) def to_file(self, f): @@ -112,27 +98,27 @@ class Message(object): @classmethod def from_file(cls, f): - ver, msg, rsv, atyp = struct.unpack("!BBBB", _read(f, 4)) + ver, msg, rsv, atyp = struct.unpack("!BBBB", f.safe_read(4)) if rsv != 0x00: raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Socks Request: Invalid reserved byte: %s" % rsv) if atyp == ATYP.IPV4_ADDRESS: # We use tnoa here as ntop is not commonly available on Windows. - host = socket.inet_ntoa(_read(f, 4)) + host = socket.inet_ntoa(f.safe_read(4)) use_ipv6 = False elif atyp == ATYP.IPV6_ADDRESS: - host = socket.inet_ntop(socket.AF_INET6, _read(f, 16)) + host = socket.inet_ntop(socket.AF_INET6, f.safe_read(16)) use_ipv6 = True elif atyp == ATYP.DOMAINNAME: - length, = struct.unpack("!B", _read(f, 1)) - host = _read(f, length) + length, = struct.unpack("!B", f.safe_read(1)) + host = f.safe_read(length) use_ipv6 = False else: raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, "Socks Request: Unknown ATYP: %s" % atyp) - port, = struct.unpack("!H", _read(f, 2)) + port, = struct.unpack("!H", f.safe_read(2)) addr = tcp.Address((host, port), use_ipv6=use_ipv6) return cls(ver, msg, atyp, addr) diff --git a/netlib/tcp.py b/netlib/tcp.py index 84008e2c..dbe114a1 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -24,6 +24,7 @@ OP_NO_SSLv3 = SSL.OP_NO_SSLv3 class NetLibError(Exception): pass class NetLibDisconnect(NetLibError): pass +class NetLibIncomplete(NetLibError): pass class NetLibTimeout(NetLibError): pass class NetLibSSLError(NetLibError): pass @@ -195,10 +196,23 @@ class Reader(_FileLike): break return result + def safe_read(self, length): + """ + Like .read, but is guaranteed to either return length bytes, or + raise an exception. + """ + result = self.read(length) + if length != -1 and len(result) != length: + raise NetLibIncomplete( + "Expected %s bytes, got %s"%(length, len(result)) + ) + return result + class Address(object): """ - This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. + This class wraps an IPv4/IPv6 tuple to provide named attributes and + ipv6 information. """ def __init__(self, address, use_ipv6=False): self.address = tuple(address) @@ -247,22 +261,28 @@ def close_socket(sock): """ try: # We already indicate that we close our end. - sock.shutdown(socket.SHUT_WR) # may raise "Transport endpoint is not connected" on Linux + # may raise "Transport endpoint is not connected" on Linux + sock.shutdown(socket.SHUT_WR) - # Section 4.2.2.13 of RFC 1122 tells us that a close() with any - # pending readable data could lead to an immediate RST being sent (which is the case on Windows). + # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending + # readable data could lead to an immediate RST being sent (which is the + # case on Windows). # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html # - # This in turn results in the following issue: If we send an error page to the client and then close the socket, - # the RST may be received by the client before the error page and the users sees a connection error rather than - # the error page. Thus, we try to empty the read buffer on Windows first. - # (see https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988) + # This in turn results in the following issue: If we send an error page + # to the client and then close the socket, the RST may be received by + # the client before the error page and the users sees a connection + # error rather than the error page. Thus, we try to empty the read + # buffer on Windows first. (see + # https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988) # + if os.name == "nt": # pragma: no cover - # We cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above: - # Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that - # recv() would block infinitely. - # As a workaround, we set a timeout here even if we are in blocking mode. + # We cannot rely on the shutdown()-followed-by-read()-eof technique + # proposed by the page above: Some remote machines just don't send + # a TCP FIN, which would leave us in the unfortunate situation that + # recv() would block infinitely. As a workaround, we set a timeout + # here even if we are in blocking mode. sock.settimeout(sock.gettimeout() or 20) # limit at a megabyte so that we don't read infinitely @@ -292,10 +312,10 @@ class _Connection(object): def finish(self): self.finished = True - # If we have an SSL connection, wfile.close == connection.close # (We call _FileLike.set_descriptor(conn)) - # Closing the socket is not our task, therefore we don't call close then. + # Closing the socket is not our task, therefore we don't call close + # then. if type(self.connection) != SSL.Connection: if not getattr(self.wfile, "closed", False): try: diff --git a/netlib/websockets.py b/netlib/websockets.py index 0ad0e294..6d08e101 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -5,7 +5,7 @@ import os import struct import io -from . import utils, odict +from . import utils, odict, tcp # Colleciton of utility functions that implement small portions of the RFC6455 # WebSockets Protocol Useful for building WebSocket clients and servers. @@ -217,8 +217,8 @@ class FrameHeader: """ read a websockets frame header """ - first_byte = utils.bytes_to_int(fp.read(1)) - second_byte = utils.bytes_to_int(fp.read(1)) + first_byte = utils.bytes_to_int(fp.safe_read(1)) + second_byte = utils.bytes_to_int(fp.safe_read(1)) fin = utils.getbit(first_byte, 7) rsv1 = utils.getbit(first_byte, 6) @@ -235,13 +235,13 @@ class FrameHeader: if length_code <= 125: payload_length = length_code elif length_code == 126: - payload_length = utils.bytes_to_int(fp.read(2)) + payload_length = utils.bytes_to_int(fp.safe_read(2)) elif length_code == 127: - payload_length = utils.bytes_to_int(fp.read(8)) + payload_length = utils.bytes_to_int(fp.safe_read(8)) # masking key only present if mask bit set if mask_bit == 1: - masking_key = fp.read(4) + masking_key = fp.safe_read(4) else: masking_key = None @@ -319,7 +319,7 @@ class Frame(object): Construct a websocket frame from an in-memory bytestring to construct a frame from a stream of bytes, use from_file() directly """ - return cls.from_file(io.BytesIO(bytestring)) + return cls.from_file(tcp.Reader(io.BytesIO(bytestring))) def human_readable(self): hdr = self.header.human_readable() @@ -351,7 +351,7 @@ class Frame(object): stream or a disk or an in memory stream reader """ header = FrameHeader.from_file(fp) - payload = fp.read(header.payload_length) + payload = fp.safe_read(header.payload_length) if header.mask == 1 and header.masking_key: payload = Masker(header.masking_key)(payload) diff --git a/test/test_http.py b/test/test_http.py index f1a31b93..63b39f08 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -91,7 +91,7 @@ def test_read_http_body_request(): def test_read_http_body_response(): h = odict.ODictCaseless() - s = cStringIO.StringIO("testing") + s = tcp.Reader(cStringIO.StringIO("testing")) assert http.read_http_body(s, h, None, "GET", 200, False) == "testing" @@ -135,11 +135,11 @@ def test_read_http_body(): # test no content length: limit > actual content h = odict.ODictCaseless() - s = cStringIO.StringIO("testing") + s = tcp.Reader(cStringIO.StringIO("testing")) assert len(http.read_http_body(s, h, 100, "GET", 200, False)) == 7 # test no content length: limit < actual content - s = cStringIO.StringIO("testing") + s = tcp.Reader(cStringIO.StringIO("testing")) tutils.raises( http.HttpError, http.read_http_body, @@ -149,7 +149,7 @@ def test_read_http_body(): # test chunked h = odict.ODictCaseless() h["transfer-encoding"] = ["chunked"] - s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") + s = tcp.Reader(cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")) assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa" diff --git a/test/test_socks.py b/test/test_socks.py index aa4f9c11..6e522826 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -7,7 +7,7 @@ import tutils def test_client_greeting(): - raw = StringIO("\x05\x02\x00\xBE\xEF") + raw = tutils.treader("\x05\x02\x00\xBE\xEF") out = StringIO() msg = socks.ClientGreeting.from_file(raw) msg.to_file(out) @@ -20,7 +20,7 @@ def test_client_greeting(): def test_server_greeting(): - raw = StringIO("\x05\x02") + raw = tutils.treader("\x05\x02") out = StringIO() msg = socks.ServerGreeting.from_file(raw) msg.to_file(out) @@ -31,7 +31,7 @@ def test_server_greeting(): def test_message(): - raw = StringIO("\x05\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") + raw = tutils.treader("\x05\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") out = StringIO() msg = socks.Message.from_file(raw) assert raw.read(2) == "\xBE\xEF" @@ -46,7 +46,7 @@ def test_message(): def test_message_ipv4(): # Test ATYP=0x01 (IPV4) - raw = StringIO("\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") + raw = tutils.treader("\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") out = StringIO() msg = socks.Message.from_file(raw) assert raw.read(2) == "\xBE\xEF" @@ -62,7 +62,7 @@ def test_message_ipv6(): # Test ATYP=0x04 (IPV6) ipv6_addr = "2001:db8:85a3:8d3:1319:8a2e:370:7344" - raw = StringIO("\x05\x01\x00\x04" + socket.inet_pton(socket.AF_INET6, ipv6_addr) + "\xDE\xAD\xBE\xEF") + raw = tutils.treader("\x05\x01\x00\x04" + socket.inet_pton(socket.AF_INET6, ipv6_addr) + "\xDE\xAD\xBE\xEF") out = StringIO() msg = socks.Message.from_file(raw) assert raw.read(2) == "\xBE\xEF" @@ -73,12 +73,12 @@ def test_message_ipv6(): def test_message_invalid_rsv(): - raw = StringIO("\x05\x01\xFF\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") + raw = tutils.treader("\x05\x01\xFF\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") tutils.raises(socks.SocksError, socks.Message.from_file, raw) def test_message_unknown_atyp(): - raw = StringIO("\x05\x02\x00\x02\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") + raw = tutils.treader("\x05\x02\x00\x02\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") tutils.raises(socks.SocksError, socks.Message.from_file, raw) m = socks.Message(5, 1, 0x02, tcp.Address(("example.com", 5050))) @@ -93,4 +93,4 @@ def test_read(): cs = mock.Mock() cs.read = mock.Mock(side_effect=socket.error) - tutils.raises(socks.SocksError, socks._read, cs, 4) \ No newline at end of file + tutils.raises(socks.SocksError, socks._read, cs, 4) diff --git a/test/test_websockets.py b/test/test_websockets.py index 428f7c61..7bd5d74e 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -1,4 +1,3 @@ -import cStringIO import os from nose.tools import raises @@ -170,7 +169,7 @@ class TestFrameHeader: def round(*args, **kwargs): f = websockets.FrameHeader(*args, **kwargs) bytes = f.to_bytes() - f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes)) + f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) assert f == f2 round() round(fin=1) @@ -197,7 +196,7 @@ class TestFrameHeader: def test_funky(self): f = websockets.FrameHeader(masking_key="test", mask=False) bytes = f.to_bytes() - f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes)) + f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) assert not f2.mask def test_violations(self): @@ -221,7 +220,7 @@ class TestFrame: def round(*args, **kwargs): f = websockets.Frame(*args, **kwargs) bytes = f.to_bytes() - f2 = websockets.Frame.from_file(cStringIO.StringIO(bytes)) + f2 = websockets.Frame.from_file(tutils.treader(bytes)) assert f == f2 round("test") round("test", fin=1) diff --git a/test/tutils.py b/test/tutils.py index ea30f59c..141979f8 100644 --- a/test/tutils.py +++ b/test/tutils.py @@ -1,7 +1,20 @@ -import tempfile, os, shutil +import cStringIO +import tempfile +import os +import shutil from contextlib import contextmanager from libpathod import utils +from netlib import tcp + + +def treader(bytes): + """ + Construct a tcp.Read object from bytes. + """ + fp = cStringIO.StringIO(bytes) + return tcp.Reader(fp) + @contextmanager def tmpdir(*args, **kwargs): -- cgit v1.2.3 From dabb356c15bf0e51ae37b3c5fb3c04fd5b944afd Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 5 May 2015 10:52:50 +1200 Subject: Zap a left-over test --- test/test_socks.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/test/test_socks.py b/test/test_socks.py index 6e522826..a596dedf 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -1,6 +1,5 @@ from cStringIO import StringIO import socket -import mock from nose.plugins.skip import SkipTest from netlib import socks, tcp import tutils @@ -83,14 +82,3 @@ def test_message_unknown_atyp(): m = socks.Message(5, 1, 0x02, tcp.Address(("example.com", 5050))) tutils.raises(socks.SocksError, m.to_file, StringIO()) - -def test_read(): - cs = StringIO("1234") - assert socks._read(cs, 3) == "123" - - cs = StringIO("123") - tutils.raises(socks.SocksError, socks._read, cs, 4) - - cs = mock.Mock() - cs.read = mock.Mock(side_effect=socket.error) - tutils.raises(socks.SocksError, socks._read, cs, 4) -- cgit v1.2.3 From ace4454523a81303b6432714f8ff73dab02a7e33 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 16 May 2015 11:32:18 +1200 Subject: Zap outdated comment --- netlib/websockets.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/netlib/websockets.py b/netlib/websockets.py index 6d08e101..a2d55c19 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -327,9 +327,7 @@ class Frame(object): def to_bytes(self): """ - Serialize the frame back into the wire format, returns a bytestring - If you haven't checked is_valid_frame() then there's no guarentees - that the serialized bytes will be correct. see safe_to_bytes() + Serialize the frame to wire format. Returns a string. """ b = self.header.to_bytes() if self.header.masking_key: -- cgit v1.2.3 From f40bf865b1e767d4f15e0e829b9ca3132c33d11d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 18 May 2015 10:46:00 +1200 Subject: release prep: bump version --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/version.py b/netlib/version.py index 826c66fe..502dce3a 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 11, 2) +IVERSION = (0, 12, 0) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From 46fadfc82386265c26b77ea0d8c3801585c84fbc Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 18 May 2015 17:16:42 +0200 Subject: improve displaying tcp addresses --- netlib/tcp.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index dbe114a1..a5f43ea3 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -245,7 +245,10 @@ class Address(object): self.family = socket.AF_INET6 if b else socket.AF_INET def __repr__(self): - return repr(self.address) + return "{}:{}".format(self.host, self.port) + + def __str__(self): + return str(self.address) def __eq__(self, other): other = Address.wrap(other) -- cgit v1.2.3 From ae749975e537990f3db767b4d0d4c6ec2321a088 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 26 May 2015 10:43:28 +1200 Subject: Post release version bump. --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/version.py b/netlib/version.py index 502dce3a..3eb0ffc9 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 12, 0) +IVERSION = (0, 12, 1) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From 4ce6f43616db9c23a29484610045aecd88ed2cfc Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 25 May 2015 12:10:21 +0200 Subject: implement basic HTTP/2 frame classes --- netlib/h2/__init__.py | 1 + netlib/h2/frame.py | 375 +++++++++++++++++++++++++++++++++++++++++++++++++ netlib/h2/h2.py | 25 ++++ test/h2/__init__.py | 0 test/h2/test_frames.py | 341 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 742 insertions(+) create mode 100644 netlib/h2/__init__.py create mode 100644 netlib/h2/frame.py create mode 100644 netlib/h2/h2.py create mode 100644 test/h2/__init__.py create mode 100644 test/h2/test_frames.py diff --git a/netlib/h2/__init__.py b/netlib/h2/__init__.py new file mode 100644 index 00000000..9b4faa33 --- /dev/null +++ b/netlib/h2/__init__.py @@ -0,0 +1 @@ +from __future__ import (absolute_import, print_function, division) diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py new file mode 100644 index 00000000..52cc2992 --- /dev/null +++ b/netlib/h2/frame.py @@ -0,0 +1,375 @@ +import base64 +import hashlib +import os +import struct +import io + +from .. import utils, odict, tcp + +class Frame(object): + """ + Baseclass Frame + contains header + payload is defined in subclasses + """ + + FLAG_NO_FLAGS = 0x0 + FLAG_ACK = 0x1 + FLAG_END_STREAM = 0x1 + FLAG_END_HEADERS = 0x4 + FLAG_PADDED = 0x8 + FLAG_PRIORITY = 0x20 + + def __init__(self, length, flags, stream_id): + valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) + if flags | valid_flags != valid_flags: + raise ValueError('invalid flags detected.') + + self.length = length + self.flags = flags + self.stream_id = stream_id + + @classmethod + def from_bytes(self, data): + fields = struct.unpack("!HBBBL", data[:9]) + length = (fields[0] << 8) + fields[1] + # type is already deducted from class + flags = fields[3] + stream_id = fields[4] + return FRAMES[fields[2]].from_bytes(length, flags, stream_id, data[9:]) + + def to_bytes(self): + payload = self.payload_bytes() + self.length = len(payload) + + b = struct.pack('!HB', self.length & 0xFFFF00, self.length & 0x0000FF) + b += struct.pack('!B', self.TYPE) + b += struct.pack('!B', self.flags) + b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) + b += payload + + return b + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() + +class DataFrame(Frame): + TYPE = 0x0 + VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] + + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, payload=b'', pad_length=0): + super(DataFrame, self).__init__(length, flags, stream_id) + self.payload = payload + self.pad_length = pad_length + + @classmethod + def from_bytes(self, length, flags, stream_id, payload): + f = self(length=length, flags=flags, stream_id=stream_id) + + if f.flags & self.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.payload = payload[1:-f.pad_length] + else: + f.payload = payload + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('DATA frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += bytes(self.payload) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + +class HeadersFrame(Frame): + TYPE = 0x1 + VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED, Frame.FLAG_PRIORITY] + + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b'', pad_length=0, exclusive=False, stream_dependency=0x0, weight=0): + super(HeadersFrame, self).__init__(length, flags, stream_id) + self.header_block_fragment = header_block_fragment + self.pad_length = pad_length + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(self, length, flags, stream_id, payload): + f = self(length=length, flags=flags, stream_id=stream_id) + + if f.flags & self.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.header_block_fragment = payload[1:-f.pad_length] + else: + f.header_block_fragment = payload[0:] + + if f.flags & self.FLAG_PRIORITY: + f.stream_dependency, f.weight = struct.unpack('!LB', f.header_block_fragment[:5]) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + f.header_block_fragment = f.header_block_fragment[5:] + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('HEADERS frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + if self.flags & self.FLAG_PRIORITY: + b += struct.pack('!LB', (int(self.exclusive) << 31) | self.stream_dependency, self.weight) + + b += bytes(self.header_block_fragment) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + +class PriorityFrame(Frame): + TYPE = 0x2 + VALID_FLAGS = [] + + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, exclusive=False, stream_dependency=0x0, weight=0): + super(PriorityFrame, self).__init__(length, flags, stream_id) + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(self, length, flags, stream_id, payload): + f = self(length=length, flags=flags, stream_id=stream_id) + + f.stream_dependency, f.weight = struct.unpack('!LB', payload) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('PRIORITY frames MUST be associated with a stream.') + + if self.stream_dependency == 0x0: + raise ValueError('stream dependency is invalid.') + + return struct.pack('!LB', (int(self.exclusive) << 31) | self.stream_dependency, self.weight) + +class RstStreamFrame(Frame): + TYPE = 0x3 + VALID_FLAGS = [] + + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, error_code=0x0): + super(RstStreamFrame, self).__init__(length, flags, stream_id) + self.error_code = error_code + + @classmethod + def from_bytes(self, length, flags, stream_id, payload): + f = self(length=length, flags=flags, stream_id=stream_id) + f.error_code = struct.unpack('!L', payload)[0] + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('RST_STREAM frames MUST be associated with a stream.') + + return struct.pack('!L', self.error_code) + +class SettingsFrame(Frame): + TYPE = 0x4 + VALID_FLAGS = [Frame.FLAG_ACK] + + SETTINGS = utils.BiDi( + SETTINGS_HEADER_TABLE_SIZE = 0x1, + SETTINGS_ENABLE_PUSH = 0x2, + SETTINGS_MAX_CONCURRENT_STREAMS = 0x3, + SETTINGS_INITIAL_WINDOW_SIZE = 0x4, + SETTINGS_MAX_FRAME_SIZE = 0x5, + SETTINGS_MAX_HEADER_LIST_SIZE = 0x6, + ) + + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, settings={}): + super(SettingsFrame, self).__init__(length, flags, stream_id) + self.settings = settings + + @classmethod + def from_bytes(self, length, flags, stream_id, payload): + f = self(length=length, flags=flags, stream_id=stream_id) + + for i in xrange(0, len(payload), 6): + identifier, value = struct.unpack("!HL", payload[i:i+6]) + f.settings[identifier] = value + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError('SETTINGS frames MUST NOT be associated with a stream.') + + b = b'' + for identifier, value in self.settings.items(): + b += struct.pack("!HL", identifier & 0xFF, value) + + return b + +class PushPromiseFrame(Frame): + TYPE = 0x5 + VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] + + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, promised_stream=0x0, header_block_fragment=b'', pad_length=0): + super(PushPromiseFrame, self).__init__(length, flags, stream_id) + self.pad_length = pad_length + self.promised_stream = promised_stream + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(self, length, flags, stream_id, payload): + f = self(length=length, flags=flags, stream_id=stream_id) + + if f.flags & self.FLAG_PADDED: + f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) + f.header_block_fragment = payload[5:-f.pad_length] + else: + f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) + f.header_block_fragment = payload[4:] + + f.promised_stream &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('PUSH_PROMISE frames MUST be associated with a stream.') + + if self.promised_stream == 0x0: + raise ValueError('Promised stream id not valid.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) + b += bytes(self.header_block_fragment) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + +class PingFrame(Frame): + TYPE = 0x6 + VALID_FLAGS = [Frame.FLAG_ACK] + + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, payload=b''): + super(PingFrame, self).__init__(length, flags, stream_id) + self.payload = payload + + @classmethod + def from_bytes(self, length, flags, stream_id, payload): + f = self(length=length, flags=flags, stream_id=stream_id) + f.payload = payload + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError('PING frames MUST NOT be associated with a stream.') + + b = self.payload[0:8] + b += b'\0' * (8 - len(b)) + return b + +class GoAwayFrame(Frame): + TYPE = 0x7 + VALID_FLAGS = [] + + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, last_stream=0x0, error_code=0x0, data=b''): + super(GoAwayFrame, self).__init__(length, flags, stream_id) + self.last_stream = last_stream + self.error_code = error_code + self.data = data + + @classmethod + def from_bytes(self, length, flags, stream_id, payload): + f = self(length=length, flags=flags, stream_id=stream_id) + + f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) + f.last_stream &= 0x7FFFFFFF + f.data = payload[8:] + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError('GOAWAY frames MUST NOT be associated with a stream.') + + b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) + b += bytes(self.data) + return b + +class WindowUpdateFrame(Frame): + TYPE = 0x8 + VALID_FLAGS = [] + + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, window_size_increment=0x0): + super(WindowUpdateFrame, self).__init__(length, flags, stream_id) + self.window_size_increment = window_size_increment + + @classmethod + def from_bytes(self, length, flags, stream_id, payload): + f = self(length=length, flags=flags, stream_id=stream_id) + + f.window_size_increment = struct.unpack("!L", payload)[0] + f.window_size_increment &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.window_size_increment <= 0 or self.window_size_increment >= 2**31: + raise ValueError('Window Szie Increment MUST be greater than 0 and less than 2^31.') + + return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) + +class ContinuationFrame(Frame): + TYPE = 0x9 + VALID_FLAGS = [Frame.FLAG_END_HEADERS] + + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b''): + super(ContinuationFrame, self).__init__(length, flags, stream_id) + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(self, length, flags, stream_id, payload): + f = self(length=length, flags=flags, stream_id=stream_id) + f.header_block_fragment = payload + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('CONTINUATION frames MUST be associated with a stream.') + + return self.header_block_fragment + +_FRAME_CLASSES = [ + DataFrame, + HeadersFrame, + PriorityFrame, + RstStreamFrame, + SettingsFrame, + PushPromiseFrame, + PingFrame, + GoAwayFrame, + WindowUpdateFrame, + ContinuationFrame +] +FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} diff --git a/netlib/h2/h2.py b/netlib/h2/h2.py new file mode 100644 index 00000000..5d74c1c8 --- /dev/null +++ b/netlib/h2/h2.py @@ -0,0 +1,25 @@ +import base64 +import hashlib +import os +import struct +import io + +# "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" +CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' + +ERROR_CODES = utils.BiDi( + NO_ERROR = 0x0, + PROTOCOL_ERROR = 0x1, + INTERNAL_ERROR = 0x2, + FLOW_CONTROL_ERROR = 0x3, + SETTINGS_TIMEOUT = 0x4, + STREAM_CLOSED = 0x5, + FRAME_SIZE_ERROR = 0x6, + REFUSED_STREAM = 0x7, + CANCEL = 0x8, + COMPRESSION_ERROR = 0x9, + CONNECT_ERROR = 0xa, + ENHANCE_YOUR_CALM = 0xb, + INADEQUATE_SECURITY = 0xc, + HTTP_1_1_REQUIRED = 0xd + ) diff --git a/test/h2/__init__.py b/test/h2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/h2/test_frames.py b/test/h2/test_frames.py new file mode 100644 index 00000000..d04c7c8b --- /dev/null +++ b/test/h2/test_frames.py @@ -0,0 +1,341 @@ +from netlib.h2.frame import * +import tutils + +from nose.tools import assert_equal + + + +# TODO test stream association if valid or not + +def test_invalid_flags(): + tutils.raises(ValueError, DataFrame, ContinuationFrame.FLAG_END_HEADERS, 0x1234567, 'foobar') + +def test_frame_equality(): + a = DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567, 'foobar') + b = DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567, 'foobar') + assert_equal(a, b) + +def test_too_large_frames(): + DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567) + +def test_data_frame_to_bytes(): + f = DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567, 'foobar') + assert_equal(f.to_bytes().encode('hex'), '000006000101234567666f6f626172') + + f = DataFrame(11, Frame.FLAG_END_STREAM | Frame.FLAG_PADDED, 0x1234567, 'foobar', pad_length=3) + assert_equal(f.to_bytes().encode('hex'), '00000a00090123456703666f6f626172000000') + + f = DataFrame(6, Frame.FLAG_NO_FLAGS, 0x0, 'foobar') + tutils.raises(ValueError, f.to_bytes) + +def test_data_frame_from_bytes(): + f = Frame.from_bytes('000006000101234567666f6f626172'.decode('hex')) + assert isinstance(f, DataFrame) + assert_equal(f.length, 6) + assert_equal(f.TYPE, DataFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_END_STREAM) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.payload, 'foobar') + + f = Frame.from_bytes('00000a00090123456703666f6f626172000000'.decode('hex')) + assert isinstance(f, DataFrame) + assert_equal(f.length, 10) + assert_equal(f.TYPE, DataFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_END_STREAM | Frame.FLAG_PADDED) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.payload, 'foobar') + +def test_headers_frame_to_bytes(): + f = HeadersFrame(6, Frame.FLAG_NO_FLAGS, 0x1234567, 'foobar') + assert_equal(f.to_bytes().encode('hex'), '000006010001234567666f6f626172') + + f = HeadersFrame(10, HeadersFrame.FLAG_PADDED, 0x1234567, 'foobar', pad_length=3) + assert_equal(f.to_bytes().encode('hex'), '00000a01080123456703666f6f626172000000') + + f = HeadersFrame(10, HeadersFrame.FLAG_PRIORITY, 0x1234567, 'foobar', exclusive=True, stream_dependency=0x7654321, weight=42) + assert_equal(f.to_bytes().encode('hex'), '00000b012001234567876543212a666f6f626172') + + f = HeadersFrame(14, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, 0x1234567, 'foobar', pad_length=3, exclusive=True, stream_dependency=0x7654321, weight=42) + assert_equal(f.to_bytes().encode('hex'), '00000f01280123456703876543212a666f6f626172000000') + + f = HeadersFrame(14, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, 0x1234567, 'foobar', pad_length=3, exclusive=False, stream_dependency=0x7654321, weight=42) + assert_equal(f.to_bytes().encode('hex'), '00000f01280123456703076543212a666f6f626172000000') + + f = HeadersFrame(6, Frame.FLAG_NO_FLAGS, 0x0, 'foobar') + tutils.raises(ValueError, f.to_bytes) + +def test_headers_frame_from_bytes(): + f = Frame.from_bytes('000006010001234567666f6f626172'.decode('hex')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 6) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, 'foobar') + + f = Frame.from_bytes('00000a01080123456703666f6f626172000000'.decode('hex')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 10) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, HeadersFrame.FLAG_PADDED) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, 'foobar') + + f = Frame.from_bytes('00000b012001234567876543212a666f6f626172'.decode('hex')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 11) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, HeadersFrame.FLAG_PRIORITY) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, 'foobar') + assert_equal(f.exclusive, True) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 42) + + f = Frame.from_bytes('00000f01280123456703876543212a666f6f626172000000'.decode('hex')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 15) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, 'foobar') + assert_equal(f.exclusive, True) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 42) + + f = Frame.from_bytes('00000f01280123456703076543212a666f6f626172000000'.decode('hex')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 15) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, 'foobar') + assert_equal(f.exclusive, False) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 42) + +def test_priority_frame_to_bytes(): + f = PriorityFrame(5, Frame.FLAG_NO_FLAGS, 0x1234567, exclusive=True, stream_dependency=0x7654321, weight=42) + assert_equal(f.to_bytes().encode('hex'), '000005020001234567876543212a') + + f = PriorityFrame(5, Frame.FLAG_NO_FLAGS, 0x1234567, exclusive=False, stream_dependency=0x7654321, weight=21) + assert_equal(f.to_bytes().encode('hex'), '0000050200012345670765432115') + + f = PriorityFrame(5, Frame.FLAG_NO_FLAGS, 0x0, stream_dependency=0x1234567) + tutils.raises(ValueError, f.to_bytes) + + f = PriorityFrame(5, Frame.FLAG_NO_FLAGS, 0x1234567, stream_dependency=0x0) + tutils.raises(ValueError, f.to_bytes) + +def test_priority_frame_from_bytes(): + f = Frame.from_bytes('000005020001234567876543212a'.decode('hex')) + assert isinstance(f, PriorityFrame) + assert_equal(f.length, 5) + assert_equal(f.TYPE, PriorityFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.exclusive, True) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 42) + + f = Frame.from_bytes('0000050200012345670765432115'.decode('hex')) + assert isinstance(f, PriorityFrame) + assert_equal(f.length, 5) + assert_equal(f.TYPE, PriorityFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.exclusive, False) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 21) + +def test_rst_stream_frame_to_bytes(): + f = RstStreamFrame(4, Frame.FLAG_NO_FLAGS, 0x1234567, error_code=0x7654321) + assert_equal(f.to_bytes().encode('hex'), '00000403000123456707654321') + + f = RstStreamFrame(4, Frame.FLAG_NO_FLAGS, 0x0) + tutils.raises(ValueError, f.to_bytes) + +def test_rst_stream_frame_from_bytes(): + f = Frame.from_bytes('00000403000123456707654321'.decode('hex')) + assert isinstance(f, RstStreamFrame) + assert_equal(f.length, 4) + assert_equal(f.TYPE, RstStreamFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.error_code, 0x07654321) + +def test_settings_frame_to_bytes(): + f = SettingsFrame(0, Frame.FLAG_NO_FLAGS, 0x0) + assert_equal(f.to_bytes().encode('hex'), '000000040000000000') + + f = SettingsFrame(0, SettingsFrame.FLAG_ACK, 0x0) + assert_equal(f.to_bytes().encode('hex'), '000000040100000000') + + f = SettingsFrame(6, SettingsFrame.FLAG_ACK, 0x0, settings={SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1}) + assert_equal(f.to_bytes().encode('hex'), '000006040100000000000200000001') + + f = SettingsFrame(12, Frame.FLAG_NO_FLAGS, 0x0, settings={SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) + assert_equal(f.to_bytes().encode('hex'), '00000c040000000000000200000001000312345678') + + f = SettingsFrame(0, Frame.FLAG_NO_FLAGS, 0x1234567) + tutils.raises(ValueError, f.to_bytes) + +def test_settings_frame_from_bytes(): + f = Frame.from_bytes('000000040000000000'.decode('hex')) + assert isinstance(f, SettingsFrame) + assert_equal(f.length, 0) + assert_equal(f.TYPE, SettingsFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + + f = Frame.from_bytes('000000040100000000'.decode('hex')) + assert isinstance(f, SettingsFrame) + assert_equal(f.length, 0) + assert_equal(f.TYPE, SettingsFrame.TYPE) + assert_equal(f.flags, SettingsFrame.FLAG_ACK) + assert_equal(f.stream_id, 0x0) + + f = Frame.from_bytes('000006040100000000000200000001'.decode('hex')) + assert isinstance(f, SettingsFrame) + assert_equal(f.length, 6) + assert_equal(f.TYPE, SettingsFrame.TYPE) + assert_equal(f.flags, SettingsFrame.FLAG_ACK, 0x0) + assert_equal(f.stream_id, 0x0) + assert_equal(len(f.settings), 1) + assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) + + f = Frame.from_bytes('00000c040000000000000200000001000312345678'.decode('hex')) + assert isinstance(f, SettingsFrame) + assert_equal(f.length, 12) + assert_equal(f.TYPE, SettingsFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(len(f.settings), 2) + assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) + assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS], 0x12345678) + +def test_push_promise_frame_to_bytes(): + f = PushPromiseFrame(10, Frame.FLAG_NO_FLAGS, 0x1234567, 0x7654321, 'foobar') + assert_equal(f.to_bytes().encode('hex'), '00000a05000123456707654321666f6f626172') + + f = PushPromiseFrame(14, HeadersFrame.FLAG_PADDED, 0x1234567, 0x7654321, 'foobar', pad_length=3) + assert_equal(f.to_bytes().encode('hex'), '00000e0508012345670307654321666f6f626172000000') + + f = PushPromiseFrame(4, Frame.FLAG_NO_FLAGS, 0x0, 0x1234567) + tutils.raises(ValueError, f.to_bytes) + + f = PushPromiseFrame(4, Frame.FLAG_NO_FLAGS, 0x1234567, 0x0) + tutils.raises(ValueError, f.to_bytes) + +def test_push_promise_frame_from_bytes(): + f = Frame.from_bytes('00000a05000123456707654321666f6f626172'.decode('hex')) + assert isinstance(f, PushPromiseFrame) + assert_equal(f.length, 10) + assert_equal(f.TYPE, PushPromiseFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, 'foobar') + + f = Frame.from_bytes('00000e0508012345670307654321666f6f626172000000'.decode('hex')) + assert isinstance(f, PushPromiseFrame) + assert_equal(f.length, 14) + assert_equal(f.TYPE, PushPromiseFrame.TYPE) + assert_equal(f.flags, PushPromiseFrame.FLAG_PADDED) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, 'foobar') + +def test_ping_frame_to_bytes(): + f = PingFrame(8, PingFrame.FLAG_ACK, 0x0, payload=b'foobar') + assert_equal(f.to_bytes().encode('hex'), '000008060100000000666f6f6261720000') + + f = PingFrame(8, Frame.FLAG_NO_FLAGS, 0x0, payload=b'foobardeadbeef') + assert_equal(f.to_bytes().encode('hex'), '000008060000000000666f6f6261726465') + + f = PingFrame(8, Frame.FLAG_NO_FLAGS, 0x1234567) + tutils.raises(ValueError, f.to_bytes) + +def test_ping_frame_from_bytes(): + f = Frame.from_bytes('000008060100000000666f6f6261720000'.decode('hex')) + assert isinstance(f, PingFrame) + assert_equal(f.length, 8) + assert_equal(f.TYPE, PingFrame.TYPE) + assert_equal(f.flags, PingFrame.FLAG_ACK) + assert_equal(f.stream_id, 0x0) + assert_equal(f.payload, b'foobar\0\0') + + f = Frame.from_bytes('000008060000000000666f6f6261726465'.decode('hex')) + assert isinstance(f, PingFrame) + assert_equal(f.length, 8) + assert_equal(f.TYPE, PingFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(f.payload, b'foobarde') + +def test_goaway_frame_to_bytes(): + f = GoAwayFrame(8, Frame.FLAG_NO_FLAGS, 0x0, last_stream=0x1234567, error_code=0x87654321, data=b'') + assert_equal(f.to_bytes().encode('hex'), '0000080700000000000123456787654321') + + f = GoAwayFrame(14, Frame.FLAG_NO_FLAGS, 0x0, last_stream=0x1234567, error_code=0x87654321, data=b'foobar') + assert_equal(f.to_bytes().encode('hex'), '00000e0700000000000123456787654321666f6f626172') + + f = GoAwayFrame(8, Frame.FLAG_NO_FLAGS, 0x1234567, last_stream=0x1234567, error_code=0x87654321) + tutils.raises(ValueError, f.to_bytes) + +def test_goaway_frame_from_bytes(): + f = Frame.from_bytes('0000080700000000000123456787654321'.decode('hex')) + assert isinstance(f, GoAwayFrame) + assert_equal(f.length, 8) + assert_equal(f.TYPE, GoAwayFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(f.last_stream, 0x1234567) + assert_equal(f.error_code, 0x87654321) + assert_equal(f.data, b'') + + f = Frame.from_bytes('00000e0700000000000123456787654321666f6f626172'.decode('hex')) + assert isinstance(f, GoAwayFrame) + assert_equal(f.length, 14) + assert_equal(f.TYPE, GoAwayFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(f.last_stream, 0x1234567) + assert_equal(f.error_code, 0x87654321) + assert_equal(f.data, b'foobar') + +def test_window_update_frame_to_bytes(): + f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0x1234567) + assert_equal(f.to_bytes().encode('hex'), '00000408000000000001234567') + + f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x1234567, window_size_increment=0x7654321) + assert_equal(f.to_bytes().encode('hex'), '00000408000123456707654321') + + f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0xdeadbeef) + tutils.raises(ValueError, f.to_bytes) + + f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0) + tutils.raises(ValueError, f.to_bytes) + +def test_window_update_frame_from_bytes(): + f = Frame.from_bytes('00000408000000000001234567'.decode('hex')) + assert isinstance(f, WindowUpdateFrame) + assert_equal(f.length, 4) + assert_equal(f.TYPE, WindowUpdateFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(f.window_size_increment, 0x1234567) + +def test_continuation_frame_to_bytes(): + f = ContinuationFrame(6, ContinuationFrame.FLAG_END_HEADERS, 0x1234567, 'foobar') + assert_equal(f.to_bytes().encode('hex'), '000006090401234567666f6f626172') + + f = ContinuationFrame(6, ContinuationFrame.FLAG_END_HEADERS, 0x0, 'foobar') + tutils.raises(ValueError, f.to_bytes) + +def test_continuation_frame_from_bytes(): + f = Frame.from_bytes('000006090401234567666f6f626172'.decode('hex')) + assert isinstance(f, ContinuationFrame) + assert_equal(f.length, 6) + assert_equal(f.TYPE, ContinuationFrame.TYPE) + assert_equal(f.flags, ContinuationFrame.FLAG_END_HEADERS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, 'foobar') -- cgit v1.2.3 From 1967a49cd997bf188bd63066e688e979d73759f9 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 27 May 2015 10:21:28 +0200 Subject: bump pyOpenSSL and cryptography dependencies --- setup.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 8e3d51b8..86a55c4c 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,8 @@ setup( install_requires=[ "pyasn1>=0.1.7", - "pyOpenSSL>=0.14", + "pyOpenSSL>=0.15.1", + "cryptography>=0.9", "passlib>=1.6.2" ], extras_require={ @@ -52,4 +53,4 @@ setup( "pathod>=%s, <%s" % (version.MINORVERSION, version.NEXT_MINORVERSION) ] } -) \ No newline at end of file +) -- cgit v1.2.3 From d6a68e1394ac57854ac1fa09fd19b88d015789e1 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 27 May 2015 10:21:50 +0200 Subject: remove outdated workarounds --- netlib/tcp.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index a5f43ea3..399203bb 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -307,10 +307,10 @@ class _Connection(object): def get_current_cipher(self): if not self.ssl_established: return None - c = SSL._lib.SSL_get_current_cipher(self.connection._ssl) - name = SSL._native(SSL._ffi.string(SSL._lib.SSL_CIPHER_get_name(c))) - bits = SSL._lib.SSL_CIPHER_get_bits(c, SSL._ffi.NULL) - version = SSL._native(SSL._ffi.string(SSL._lib.SSL_CIPHER_get_version(c))) + + name = self.connection.get_cipher_name() + bits = self.connection.get_cipher_bits() + version = self.connection.get_cipher_version() return name, bits, version def finish(self): @@ -333,10 +333,6 @@ class _Connection(object): self.connection.shutdown() except SSL.Error: pass - except KeyError as e: # pragma: no cover - # Workaround for https://github.com/pyca/pyopenssl/pull/183 - if OpenSSL.__version__ != "0.14": - raise e """ Creates an SSL Context. -- cgit v1.2.3 From 041ca5c499369ffbf115e4451b85aee77e3095c0 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 27 May 2015 10:53:23 +0200 Subject: update TLS defaults: signature hash and DH params * SHA1 is deprecated (use SHA256) * increase RSA key to 2048 bits * increase DH params to 4096 bits (LogJam attack) --- netlib/certutils.py | 32 +++++++++++++++++++++----------- test/data/dhparam.pem | 14 +++++++++++--- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index f5375c03..507241b2 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -8,15 +8,25 @@ import OpenSSL DEFAULT_EXP = 157680000 # = 24 * 60 * 60 * 365 * 5 # Generated with "openssl dhparam". It's too slow to generate this on startup. -DEFAULT_DHPARAM = """-----BEGIN DH PARAMETERS----- -MIGHAoGBAOdPzMbYgoYfO3YBYauCLRlE8X1XypTiAjoeCFD0qWRx8YUsZ6Sj20W5 -zsfQxlZfKovo3f2MftjkDkbI/C/tDgxoe0ZPbjy5CjdOhkzxn0oTbKTs16Rw8DyK -1LjTR65sQJkJEdgsX8TSi/cicCftJZl9CaZEaObF2bdgSgGK+PezAgEC ------END DH PARAMETERS-----""" +DEFAULT_DHPARAM = """ +-----BEGIN DH PARAMETERS----- +MIICCAKCAgEAyT6LzpwVFS3gryIo29J5icvgxCnCebcdSe/NHMkD8dKJf8suFCg3 +O2+dguLakSVif/t6dhImxInJk230HmfC8q93hdcg/j8rLGJYDKu3ik6H//BAHKIv +j5O9yjU3rXCfmVJQic2Nne39sg3CreAepEts2TvYHhVv3TEAzEqCtOuTjgDv0ntJ +Gwpj+BJBRQGG9NvprX1YGJ7WOFBP/hWU7d6tgvE6Xa7T/u9QIKpYHMIkcN/l3ZFB +chZEqVlyrcngtSXCROTPcDOQ6Q8QzhaBJS+Z6rcsd7X+haiQqvoFcmaJ08Ks6LQC +ZIL2EtYJw8V8z7C0igVEBIADZBI6OTbuuhDwRw//zU1uq52Oc48CIZlGxTYG/Evq +o9EWAXUYVzWkDSTeBH1r4z/qLPE2cnhtMxbFxuvK53jGB0emy2y1Ei6IhKshJ5qX +IB/aE7SSHyQ3MDHHkCmQJCsOd4Mo26YX61NZ+n501XjqpCBQ2+DfZCBh8Va2wDyv +A2Ryg9SUz8j0AXViRNMJgJrr446yro/FuJZwnQcO3WQnXeqSBnURqKjmqkeFP+d8 +6mk2tqJaY507lRNqtGlLnj7f5RNoBFJDCLBNurVgfvq9TCVWKDIFD4vZRjCrnl6I +rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI= +-----END DH PARAMETERS----- +""" def create_ca(o, cn, exp): key = OpenSSL.crypto.PKey() - key.generate_key(OpenSSL.crypto.TYPE_RSA, 1024) + key.generate_key(OpenSSL.crypto.TYPE_RSA, 2048) cert = OpenSSL.crypto.X509() cert.set_serial_number(int(time.time()*10000)) cert.set_version(2) @@ -39,7 +49,7 @@ def create_ca(o, cn, exp): OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash", subject=cert), ]) - cert.sign(key, "sha1") + cert.sign(key, "sha256") return key, cert @@ -69,7 +79,7 @@ def dummy_cert(privkey, cacert, commonname, sans): cert.set_version(2) cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", False, ss)]) cert.set_pubkey(cacert.get_pubkey()) - cert.sign(privkey, "sha1") + cert.sign(privkey, "sha256") return SSLCert(cert) @@ -124,7 +134,7 @@ class CertStore(object): """ Implements an in-memory certificate store. """ - def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams=None): + def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams): self.default_privatekey = default_privatekey self.default_ca = default_ca self.default_chain_file = default_chain_file @@ -148,7 +158,7 @@ class CertStore(object): ) dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free) return dh - + @classmethod def from_store(cls, path, basename): ca_path = os.path.join(path, basename + "-ca.pem") @@ -296,7 +306,7 @@ class SSLCert(object): self.x509 = cert def __eq__(self, other): - return self.digest("sha1") == other.digest("sha1") + return self.digest("sha256") == other.digest("sha256") def __ne__(self, other): return not self.__eq__(other) diff --git a/test/data/dhparam.pem b/test/data/dhparam.pem index 6f2526e1..afb41672 100644 --- a/test/data/dhparam.pem +++ b/test/data/dhparam.pem @@ -1,5 +1,13 @@ -----BEGIN DH PARAMETERS----- -MIGHAoGBAOdPzMbYgoYfO3YBYauCLRlE8X1XypTiAjoeCFD0qWRx8YUsZ6Sj20W5 -zsfQxlZfKovo3f2MftjkDkbI/C/tDgxoe0ZPbjy5CjdOhkzxn0oTbKTs16Rw8DyK -1LjTR65sQJkJEdgsX8TSi/cicCftJZl9CaZEaObF2bdgSgGK+PezAgEC +MIICCAKCAgEAyT6LzpwVFS3gryIo29J5icvgxCnCebcdSe/NHMkD8dKJf8suFCg3 +O2+dguLakSVif/t6dhImxInJk230HmfC8q93hdcg/j8rLGJYDKu3ik6H//BAHKIv +j5O9yjU3rXCfmVJQic2Nne39sg3CreAepEts2TvYHhVv3TEAzEqCtOuTjgDv0ntJ +Gwpj+BJBRQGG9NvprX1YGJ7WOFBP/hWU7d6tgvE6Xa7T/u9QIKpYHMIkcN/l3ZFB +chZEqVlyrcngtSXCROTPcDOQ6Q8QzhaBJS+Z6rcsd7X+haiQqvoFcmaJ08Ks6LQC +ZIL2EtYJw8V8z7C0igVEBIADZBI6OTbuuhDwRw//zU1uq52Oc48CIZlGxTYG/Evq +o9EWAXUYVzWkDSTeBH1r4z/qLPE2cnhtMxbFxuvK53jGB0emy2y1Ei6IhKshJ5qX +IB/aE7SSHyQ3MDHHkCmQJCsOd4Mo26YX61NZ+n501XjqpCBQ2+DfZCBh8Va2wDyv +A2Ryg9SUz8j0AXViRNMJgJrr446yro/FuJZwnQcO3WQnXeqSBnURqKjmqkeFP+d8 +6mk2tqJaY507lRNqtGlLnj7f5RNoBFJDCLBNurVgfvq9TCVWKDIFD4vZRjCrnl6I +rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI= -----END DH PARAMETERS----- -- cgit v1.2.3 From e3d390e036430b9d7cc4b93679229fe118eb583a Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 27 May 2015 11:18:54 +0200 Subject: cleanup code with autopep8 run the following command: $ autopep8 -i -r -a -a . --- netlib/certutils.py | 56 +++++++------ netlib/h2/frame.py | 34 +++++--- netlib/h2/h2.py | 30 +++---- netlib/http.py | 13 +-- netlib/http_auth.py | 22 +++-- netlib/http_cookies.py | 10 +-- netlib/http_status.py | 84 ++++++++++---------- netlib/odict.py | 10 ++- netlib/socks.py | 43 +++++----- netlib/tcp.py | 62 ++++++++++----- netlib/test.py | 24 +++--- netlib/utils.py | 10 ++- netlib/websockets.py | 87 ++++++++++---------- netlib/wsgi.py | 52 ++++++------ setup.cfg | 7 ++ test/h2/test_frames.py | 32 +++++++- test/test_certutils.py | 5 +- test/test_http.py | 4 +- test/test_http_auth.py | 20 +++-- test/test_http_uastrings.py | 1 - test/test_odict.py | 7 +- test/test_tcp.py | 190 +++++++++++++++++++++++++++----------------- test/test_utils.py | 6 +- test/test_websockets.py | 18 +++-- test/test_wsgi.py | 5 +- test/tutils.py | 8 +- tools/getcertnames | 2 +- 27 files changed, 510 insertions(+), 332 deletions(-) create mode 100644 setup.cfg diff --git a/netlib/certutils.py b/netlib/certutils.py index f5375c03..da0e3355 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -1,12 +1,15 @@ from __future__ import (absolute_import, print_function, division) -import os, ssl, time, datetime +import os +import ssl +import time +import datetime import itertools from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL -DEFAULT_EXP = 157680000 # = 24 * 60 * 60 * 365 * 5 +DEFAULT_EXP = 157680000 # = 24 * 60 * 60 * 365 * 5 # Generated with "openssl dhparam". It's too slow to generate this on startup. DEFAULT_DHPARAM = """-----BEGIN DH PARAMETERS----- MIGHAoGBAOdPzMbYgoYfO3YBYauCLRlE8X1XypTiAjoeCFD0qWRx8YUsZ6Sj20W5 @@ -14,31 +17,32 @@ zsfQxlZfKovo3f2MftjkDkbI/C/tDgxoe0ZPbjy5CjdOhkzxn0oTbKTs16Rw8DyK 1LjTR65sQJkJEdgsX8TSi/cicCftJZl9CaZEaObF2bdgSgGK+PezAgEC -----END DH PARAMETERS-----""" + def create_ca(o, cn, exp): key = OpenSSL.crypto.PKey() key.generate_key(OpenSSL.crypto.TYPE_RSA, 1024) cert = OpenSSL.crypto.X509() - cert.set_serial_number(int(time.time()*10000)) + cert.set_serial_number(int(time.time() * 10000)) cert.set_version(2) cert.get_subject().CN = cn cert.get_subject().O = o - cert.gmtime_adj_notBefore(-3600*48) + cert.gmtime_adj_notBefore(-3600 * 48) cert.gmtime_adj_notAfter(exp) cert.set_issuer(cert.get_subject()) cert.set_pubkey(key) cert.add_extensions([ - OpenSSL.crypto.X509Extension("basicConstraints", True, - "CA:TRUE"), - OpenSSL.crypto.X509Extension("nsCertType", False, - "sslCA"), - OpenSSL.crypto.X509Extension("extendedKeyUsage", False, - "serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" - ), - OpenSSL.crypto.X509Extension("keyUsage", True, - "keyCertSign, cRLSign"), - OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash", - subject=cert), - ]) + OpenSSL.crypto.X509Extension("basicConstraints", True, + "CA:TRUE"), + OpenSSL.crypto.X509Extension("nsCertType", False, + "sslCA"), + OpenSSL.crypto.X509Extension("extendedKeyUsage", False, + "serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" + ), + OpenSSL.crypto.X509Extension("keyUsage", True, + "keyCertSign, cRLSign"), + OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash", + subject=cert), + ]) cert.sign(key, "sha1") return key, cert @@ -56,15 +60,15 @@ def dummy_cert(privkey, cacert, commonname, sans): """ ss = [] for i in sans: - ss.append("DNS: %s"%i) + ss.append("DNS: %s" % i) ss = ", ".join(ss) cert = OpenSSL.crypto.X509() - cert.gmtime_adj_notBefore(-3600*48) + cert.gmtime_adj_notBefore(-3600 * 48) cert.gmtime_adj_notAfter(DEFAULT_EXP) cert.set_issuer(cacert.get_subject()) cert.get_subject().CN = commonname - cert.set_serial_number(int(time.time()*10000)) + cert.set_serial_number(int(time.time() * 10000)) if ss: cert.set_version(2) cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", False, ss)]) @@ -114,6 +118,7 @@ def dummy_cert(privkey, cacert, commonname, sans): class CertStoreEntry(object): + def __init__(self, cert, privatekey, chain_file): self.cert = cert self.privatekey = privatekey @@ -121,9 +126,11 @@ class CertStoreEntry(object): class CertStore(object): + """ Implements an in-memory certificate store. """ + def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams=None): self.default_privatekey = default_privatekey self.default_ca = default_ca @@ -144,11 +151,11 @@ class CertStore(object): if bio != OpenSSL.SSL._ffi.NULL: bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free) dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams( - bio, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL - ) + bio, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL + ) dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free) return dh - + @classmethod def from_store(cls, path, basename): ca_path = os.path.join(path, basename + "-ca.pem") @@ -277,8 +284,8 @@ class _GeneralName(univ.Choice): # other types. componentType = namedtype.NamedTypes( namedtype.NamedType('dNSName', char.IA5String().subtype( - implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) - ) + implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) + ) ), ) @@ -289,6 +296,7 @@ class _GeneralNames(univ.SequenceOf): class SSLCert(object): + def __init__(self, cert): """ Returns a (common name, [subject alternative names]) tuple. diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index 52cc2992..d846b3b9 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -5,8 +5,11 @@ import struct import io from .. import utils, odict, tcp +from functools import reduce + class Frame(object): + """ Baseclass Frame contains header @@ -53,6 +56,7 @@ class Frame(object): def __eq__(self, other): return self.to_bytes() == other.to_bytes() + class DataFrame(Frame): TYPE = 0x0 VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] @@ -89,11 +93,13 @@ class DataFrame(Frame): return b + class HeadersFrame(Frame): TYPE = 0x1 VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED, Frame.FLAG_PRIORITY] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b'', pad_length=0, exclusive=False, stream_dependency=0x0, weight=0): + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b'', + pad_length=0, exclusive=False, stream_dependency=0x0, weight=0): super(HeadersFrame, self).__init__(length, flags, stream_id) self.header_block_fragment = header_block_fragment self.pad_length = pad_length @@ -137,6 +143,7 @@ class HeadersFrame(Frame): return b + class PriorityFrame(Frame): TYPE = 0x2 VALID_FLAGS = [] @@ -166,6 +173,7 @@ class PriorityFrame(Frame): return struct.pack('!LB', (int(self.exclusive) << 31) | self.stream_dependency, self.weight) + class RstStreamFrame(Frame): TYPE = 0x3 VALID_FLAGS = [] @@ -186,18 +194,19 @@ class RstStreamFrame(Frame): return struct.pack('!L', self.error_code) + class SettingsFrame(Frame): TYPE = 0x4 VALID_FLAGS = [Frame.FLAG_ACK] SETTINGS = utils.BiDi( - SETTINGS_HEADER_TABLE_SIZE = 0x1, - SETTINGS_ENABLE_PUSH = 0x2, - SETTINGS_MAX_CONCURRENT_STREAMS = 0x3, - SETTINGS_INITIAL_WINDOW_SIZE = 0x4, - SETTINGS_MAX_FRAME_SIZE = 0x5, - SETTINGS_MAX_HEADER_LIST_SIZE = 0x6, - ) + SETTINGS_HEADER_TABLE_SIZE=0x1, + SETTINGS_ENABLE_PUSH=0x2, + SETTINGS_MAX_CONCURRENT_STREAMS=0x3, + SETTINGS_INITIAL_WINDOW_SIZE=0x4, + SETTINGS_MAX_FRAME_SIZE=0x5, + SETTINGS_MAX_HEADER_LIST_SIZE=0x6, + ) def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, settings={}): super(SettingsFrame, self).__init__(length, flags, stream_id) @@ -208,7 +217,7 @@ class SettingsFrame(Frame): f = self(length=length, flags=flags, stream_id=stream_id) for i in xrange(0, len(payload), 6): - identifier, value = struct.unpack("!HL", payload[i:i+6]) + identifier, value = struct.unpack("!HL", payload[i:i + 6]) f.settings[identifier] = value return f @@ -223,6 +232,7 @@ class SettingsFrame(Frame): return b + class PushPromiseFrame(Frame): TYPE = 0x5 VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] @@ -267,6 +277,7 @@ class PushPromiseFrame(Frame): return b + class PingFrame(Frame): TYPE = 0x6 VALID_FLAGS = [Frame.FLAG_ACK] @@ -289,6 +300,7 @@ class PingFrame(Frame): b += b'\0' * (8 - len(b)) return b + class GoAwayFrame(Frame): TYPE = 0x7 VALID_FLAGS = [] @@ -317,6 +329,7 @@ class GoAwayFrame(Frame): b += bytes(self.data) return b + class WindowUpdateFrame(Frame): TYPE = 0x8 VALID_FLAGS = [] @@ -335,11 +348,12 @@ class WindowUpdateFrame(Frame): return f def payload_bytes(self): - if self.window_size_increment <= 0 or self.window_size_increment >= 2**31: + if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: raise ValueError('Window Szie Increment MUST be greater than 0 and less than 2^31.') return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) + class ContinuationFrame(Frame): TYPE = 0x9 VALID_FLAGS = [Frame.FLAG_END_HEADERS] diff --git a/netlib/h2/h2.py b/netlib/h2/h2.py index 5d74c1c8..1a39a635 100644 --- a/netlib/h2/h2.py +++ b/netlib/h2/h2.py @@ -8,18 +8,18 @@ import io CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' ERROR_CODES = utils.BiDi( - NO_ERROR = 0x0, - PROTOCOL_ERROR = 0x1, - INTERNAL_ERROR = 0x2, - FLOW_CONTROL_ERROR = 0x3, - SETTINGS_TIMEOUT = 0x4, - STREAM_CLOSED = 0x5, - FRAME_SIZE_ERROR = 0x6, - REFUSED_STREAM = 0x7, - CANCEL = 0x8, - COMPRESSION_ERROR = 0x9, - CONNECT_ERROR = 0xa, - ENHANCE_YOUR_CALM = 0xb, - INADEQUATE_SECURITY = 0xc, - HTTP_1_1_REQUIRED = 0xd - ) + NO_ERROR=0x0, + PROTOCOL_ERROR=0x1, + INTERNAL_ERROR=0x2, + FLOW_CONTROL_ERROR=0x3, + SETTINGS_TIMEOUT=0x4, + STREAM_CLOSED=0x5, + FRAME_SIZE_ERROR=0x6, + REFUSED_STREAM=0x7, + CANCEL=0x8, + COMPRESSION_ERROR=0x9, + CONNECT_ERROR=0xa, + ENHANCE_YOUR_CALM=0xb, + INADEQUATE_SECURITY=0xc, + HTTP_1_1_REQUIRED=0xd +) diff --git a/netlib/http.py b/netlib/http.py index 43155486..47658097 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -8,6 +8,7 @@ from . import odict, utils, tcp, http_status class HttpError(Exception): + def __init__(self, code, message): super(HttpError, self).__init__(message) self.code = code @@ -95,7 +96,7 @@ def read_headers(fp): """ ret = [] name = '' - while 1: + while True: line = fp.readline() if not line or line == '\r\n' or line == '\n': break @@ -337,7 +338,7 @@ def read_http_body_chunked( otherwise """ if max_chunk_size is None: - max_chunk_size = limit or sys.maxint + max_chunk_size = limit or sys.maxsize expected_size = expected_http_body_size( headers, is_request, request_method, response_code @@ -399,10 +400,10 @@ def expected_http_body_size(headers, is_request, request_method, response_code): request_method = request_method.upper() if (not is_request and ( - request_method == "HEAD" or - (request_method == "CONNECT" and response_code == 200) or - response_code in [204, 304] or - 100 <= response_code <= 199)): + request_method == "HEAD" or + (request_method == "CONNECT" and response_code == 200) or + response_code in [204, 304] or + 100 <= response_code <= 199)): return 0 if has_chunked_encoding(headers): return None diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 296e094c..261b6654 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -4,9 +4,11 @@ from . import http class NullProxyAuth(object): + """ No proxy auth at all (returns empty challange headers) """ + def __init__(self, password_manager): self.password_manager = password_manager @@ -48,7 +50,7 @@ class BasicProxyAuth(NullProxyAuth): if not parts: return False scheme, username, password = parts - if scheme.lower()!='basic': + if scheme.lower() != 'basic': return False if not self.password_manager.test(username, password): return False @@ -56,18 +58,21 @@ class BasicProxyAuth(NullProxyAuth): return True def auth_challenge_headers(self): - return {self.CHALLENGE_HEADER:'Basic realm="%s"'%self.realm} + return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm} class PassMan(object): + def test(self, username, password_token): return False class PassManNonAnon(PassMan): + """ Ensure the user specifies a username, accept any password. """ + def test(self, username, password_token): if username: return True @@ -75,9 +80,11 @@ class PassManNonAnon(PassMan): class PassManHtpasswd(PassMan): + """ Read usernames and passwords from an htpasswd file """ + def __init__(self, path): """ Raises ValueError if htpasswd file is invalid. @@ -90,14 +97,16 @@ class PassManHtpasswd(PassMan): class PassManSingleUser(PassMan): + def __init__(self, username, password): self.username, self.password = username, password def test(self, username, password_token): - return self.username==username and self.password==password_token + return self.username == username and self.password == password_token class AuthAction(Action): + """ Helper class to allow seamless integration int argparse. Example usage: parser.add_argument( @@ -106,16 +115,18 @@ class AuthAction(Action): help="Allow access to any user long as a credentials are specified." ) """ + def __call__(self, parser, namespace, values, option_string=None): passman = self.getPasswordManager(values) authenticator = BasicProxyAuth(passman, "mitmproxy") setattr(namespace, self.dest, authenticator) - def getPasswordManager(self, s): # pragma: nocover + def getPasswordManager(self, s): # pragma: nocover raise NotImplementedError() class SingleuserAuthAction(AuthAction): + def getPasswordManager(self, s): if len(s.split(':')) != 2: raise ArgumentTypeError( @@ -126,11 +137,12 @@ class SingleuserAuthAction(AuthAction): class NonanonymousAuthAction(AuthAction): + def getPasswordManager(self, s): return PassManNonAnon() class HtpasswdAuthAction(AuthAction): + def getPasswordManager(self, s): return PassManHtpasswd(s) - diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index 8e245891..73e3f589 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -96,7 +96,7 @@ def _read_pairs(s, off=0, specials=()): specials: a lower-cased list of keys that may contain commas """ vals = [] - while 1: + while True: lhs, off = _read_token(s, off) lhs = lhs.lstrip() if lhs: @@ -135,15 +135,15 @@ def _format_pairs(lst, specials=(), sep="; "): else: if k.lower() not in specials and _has_special(v): v = ESCAPE.sub(r"\\\1", v) - v = '"%s"'%v - vals.append("%s=%s"%(k, v)) + v = '"%s"' % v + vals.append("%s=%s" % (k, v)) return sep.join(vals) def _format_set_cookie_pairs(lst): return _format_pairs( lst, - specials = ("expires", "path") + specials=("expires", "path") ) @@ -154,7 +154,7 @@ def _parse_set_cookie_pairs(s): """ pairs, off = _read_pairs( s, - specials = ("expires", "path") + specials=("expires", "path") ) return pairs diff --git a/netlib/http_status.py b/netlib/http_status.py index 7dba2d56..dc09f465 100644 --- a/netlib/http_status.py +++ b/netlib/http_status.py @@ -1,51 +1,51 @@ from __future__ import (absolute_import, print_function, division) -CONTINUE = 100 -SWITCHING = 101 -OK = 200 -CREATED = 201 -ACCEPTED = 202 -NON_AUTHORITATIVE_INFORMATION = 203 -NO_CONTENT = 204 -RESET_CONTENT = 205 -PARTIAL_CONTENT = 206 -MULTI_STATUS = 207 +CONTINUE = 100 +SWITCHING = 101 +OK = 200 +CREATED = 201 +ACCEPTED = 202 +NON_AUTHORITATIVE_INFORMATION = 203 +NO_CONTENT = 204 +RESET_CONTENT = 205 +PARTIAL_CONTENT = 206 +MULTI_STATUS = 207 -MULTIPLE_CHOICE = 300 -MOVED_PERMANENTLY = 301 -FOUND = 302 -SEE_OTHER = 303 -NOT_MODIFIED = 304 -USE_PROXY = 305 -TEMPORARY_REDIRECT = 307 +MULTIPLE_CHOICE = 300 +MOVED_PERMANENTLY = 301 +FOUND = 302 +SEE_OTHER = 303 +NOT_MODIFIED = 304 +USE_PROXY = 305 +TEMPORARY_REDIRECT = 307 -BAD_REQUEST = 400 -UNAUTHORIZED = 401 -PAYMENT_REQUIRED = 402 -FORBIDDEN = 403 -NOT_FOUND = 404 -NOT_ALLOWED = 405 -NOT_ACCEPTABLE = 406 -PROXY_AUTH_REQUIRED = 407 -REQUEST_TIMEOUT = 408 -CONFLICT = 409 -GONE = 410 -LENGTH_REQUIRED = 411 -PRECONDITION_FAILED = 412 -REQUEST_ENTITY_TOO_LARGE = 413 -REQUEST_URI_TOO_LONG = 414 -UNSUPPORTED_MEDIA_TYPE = 415 +BAD_REQUEST = 400 +UNAUTHORIZED = 401 +PAYMENT_REQUIRED = 402 +FORBIDDEN = 403 +NOT_FOUND = 404 +NOT_ALLOWED = 405 +NOT_ACCEPTABLE = 406 +PROXY_AUTH_REQUIRED = 407 +REQUEST_TIMEOUT = 408 +CONFLICT = 409 +GONE = 410 +LENGTH_REQUIRED = 411 +PRECONDITION_FAILED = 412 +REQUEST_ENTITY_TOO_LARGE = 413 +REQUEST_URI_TOO_LONG = 414 +UNSUPPORTED_MEDIA_TYPE = 415 REQUESTED_RANGE_NOT_SATISFIABLE = 416 -EXPECTATION_FAILED = 417 +EXPECTATION_FAILED = 417 -INTERNAL_SERVER_ERROR = 500 -NOT_IMPLEMENTED = 501 -BAD_GATEWAY = 502 -SERVICE_UNAVAILABLE = 503 -GATEWAY_TIMEOUT = 504 -HTTP_VERSION_NOT_SUPPORTED = 505 -INSUFFICIENT_STORAGE_SPACE = 507 -NOT_EXTENDED = 510 +INTERNAL_SERVER_ERROR = 500 +NOT_IMPLEMENTED = 501 +BAD_GATEWAY = 502 +SERVICE_UNAVAILABLE = 503 +GATEWAY_TIMEOUT = 504 +HTTP_VERSION_NOT_SUPPORTED = 505 +INSUFFICIENT_STORAGE_SPACE = 507 +NOT_EXTENDED = 510 RESPONSES = { # 100 diff --git a/netlib/odict.py b/netlib/odict.py index dd738c55..f52acd50 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,5 +1,6 @@ from __future__ import (absolute_import, print_function, division) -import re, copy +import re +import copy def safe_subn(pattern, repl, target, *args, **kwargs): @@ -12,10 +13,12 @@ def safe_subn(pattern, repl, target, *args, **kwargs): class ODict(object): + """ A dictionary-like object for managing ordered (key, value) data. Think about it as a convenient interface to a list of (key, value) tuples. """ + def __init__(self, lst=None): self.lst = lst or [] @@ -157,7 +160,7 @@ class ODict(object): "key: value" """ for k, v in self.lst: - s = "%s: %s"%(k, v) + s = "%s: %s" % (k, v) if re.search(expr, s): return True return False @@ -192,11 +195,12 @@ class ODict(object): return klass([list(i) for i in state]) - class ODictCaseless(ODict): + """ A variant of ODict with "caseless" keys. This version _preserves_ key case, but does not consider case when setting or getting items. """ + def _kconv(self, s): return s.lower() diff --git a/netlib/socks.py b/netlib/socks.py index 6f9f57bd..5a73c61a 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -6,49 +6,50 @@ from . import tcp, utils class SocksError(Exception): + def __init__(self, code, message): super(SocksError, self).__init__(message) self.code = code VERSION = utils.BiDi( - SOCKS4 = 0x04, - SOCKS5 = 0x05 + SOCKS4=0x04, + SOCKS5=0x05 ) CMD = utils.BiDi( - CONNECT = 0x01, - BIND = 0x02, - UDP_ASSOCIATE = 0x03 + CONNECT=0x01, + BIND=0x02, + UDP_ASSOCIATE=0x03 ) ATYP = utils.BiDi( - IPV4_ADDRESS = 0x01, - DOMAINNAME = 0x03, - IPV6_ADDRESS = 0x04 + IPV4_ADDRESS=0x01, + DOMAINNAME=0x03, + IPV6_ADDRESS=0x04 ) REP = utils.BiDi( - SUCCEEDED = 0x00, - GENERAL_SOCKS_SERVER_FAILURE = 0x01, - CONNECTION_NOT_ALLOWED_BY_RULESET = 0x02, - NETWORK_UNREACHABLE = 0x03, - HOST_UNREACHABLE = 0x04, - CONNECTION_REFUSED = 0x05, - TTL_EXPIRED = 0x06, - COMMAND_NOT_SUPPORTED = 0x07, - ADDRESS_TYPE_NOT_SUPPORTED = 0x08, + SUCCEEDED=0x00, + GENERAL_SOCKS_SERVER_FAILURE=0x01, + CONNECTION_NOT_ALLOWED_BY_RULESET=0x02, + NETWORK_UNREACHABLE=0x03, + HOST_UNREACHABLE=0x04, + CONNECTION_REFUSED=0x05, + TTL_EXPIRED=0x06, + COMMAND_NOT_SUPPORTED=0x07, + ADDRESS_TYPE_NOT_SUPPORTED=0x08, ) METHOD = utils.BiDi( - NO_AUTHENTICATION_REQUIRED = 0x00, - GSSAPI = 0x01, - USERNAME_PASSWORD = 0x02, - NO_ACCEPTABLE_METHODS = 0xFF + NO_AUTHENTICATION_REQUIRED=0x00, + GSSAPI=0x01, + USERNAME_PASSWORD=0x02, + NO_ACCEPTABLE_METHODS=0xFF ) diff --git a/netlib/tcp.py b/netlib/tcp.py index 399203bb..7c115554 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -22,14 +22,28 @@ OP_NO_SSLv2 = SSL.OP_NO_SSLv2 OP_NO_SSLv3 = SSL.OP_NO_SSLv3 -class NetLibError(Exception): pass -class NetLibDisconnect(NetLibError): pass -class NetLibIncomplete(NetLibError): pass -class NetLibTimeout(NetLibError): pass -class NetLibSSLError(NetLibError): pass +class NetLibError(Exception): + pass + + +class NetLibDisconnect(NetLibError): + pass + + +class NetLibIncomplete(NetLibError): + pass + + +class NetLibTimeout(NetLibError): + pass + + +class NetLibSSLError(NetLibError): + pass class SSLKeyLogger(object): + def __init__(self, filename): self.filename = filename self.f = None @@ -67,6 +81,7 @@ log_ssl_key = SSLKeyLogger.create_logfun(os.getenv("MITMPROXY_SSLKEYLOGFILE") or class _FileLike(object): BLOCKSIZE = 1024 * 32 + def __init__(self, o): self.o = o self._log = None @@ -112,6 +127,7 @@ class _FileLike(object): class Writer(_FileLike): + def flush(self): """ May raise NetLibDisconnect @@ -119,7 +135,7 @@ class Writer(_FileLike): if hasattr(self.o, "flush"): try: self.o.flush() - except (socket.error, IOError), v: + except (socket.error, IOError) as v: raise NetLibDisconnect(str(v)) def write(self, v): @@ -135,11 +151,12 @@ class Writer(_FileLike): r = self.o.write(v) self.add_log(v[:r]) return r - except (SSL.Error, socket.error) as e: + except (SSL.Error, socket.error) as e: raise NetLibDisconnect(str(e)) class Reader(_FileLike): + def read(self, length): """ If length is -1, we read until connection closes. @@ -180,7 +197,7 @@ class Reader(_FileLike): self.add_log(result) return result - def readline(self, size = None): + def readline(self, size=None): result = '' bytes_read = 0 while True: @@ -204,16 +221,18 @@ class Reader(_FileLike): result = self.read(length) if length != -1 and len(result) != length: raise NetLibIncomplete( - "Expected %s bytes, got %s"%(length, len(result)) + "Expected %s bytes, got %s" % (length, len(result)) ) return result class Address(object): + """ This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. """ + def __init__(self, address, use_ipv6=False): self.address = tuple(address) self.use_ipv6 = use_ipv6 @@ -304,6 +323,7 @@ def close_socket(sock): class _Connection(object): + def get_current_cipher(self): if not self.ssl_established: return None @@ -319,7 +339,7 @@ class _Connection(object): # (We call _FileLike.set_descriptor(conn)) # Closing the socket is not our task, therefore we don't call close # then. - if type(self.connection) != SSL.Connection: + if not isinstance(self.connection, SSL.Connection): if not getattr(self.wfile, "closed", False): try: self.wfile.flush() @@ -337,6 +357,7 @@ class _Connection(object): """ Creates an SSL Context. """ + def _create_ssl_context(self, method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), @@ -362,8 +383,8 @@ class _Connection(object): if cipher_list: try: context.set_cipher_list(cipher_list) - except SSL.Error, v: - raise NetLibError("SSL cipher specification error: %s"%str(v)) + except SSL.Error as v: + raise NetLibError("SSL cipher specification error: %s" % str(v)) # SSLKEYLOGFILE if log_ssl_key: @@ -380,7 +401,7 @@ class TCPClient(_Connection): # Make sure to close the real socket, not the SSL proxy. # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, # it tries to renegotiate... - if type(self.connection) == SSL.Connection: + if isinstance(self.connection, SSL.Connection): close_socket(self.connection._socket) else: close_socket(self.connection) @@ -400,8 +421,8 @@ class TCPClient(_Connection): try: context.use_privatekey_file(cert) context.use_certificate_file(cert) - except SSL.Error, v: - raise NetLibError("SSL client certificate error: %s"%str(v)) + except SSL.Error as v: + raise NetLibError("SSL client certificate error: %s" % str(v)) return context def convert_to_ssl(self, sni=None, **sslctx_kwargs): @@ -418,8 +439,8 @@ class TCPClient(_Connection): self.connection.set_connect_state() try: self.connection.do_handshake() - except SSL.Error, v: - raise NetLibError("SSL handshake error: %s"%repr(v)) + except SSL.Error as v: + raise NetLibError("SSL handshake error: %s" % repr(v)) self.ssl_established = True self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) self.rfile.set_descriptor(self.connection) @@ -435,7 +456,7 @@ class TCPClient(_Connection): self.source_address = Address(connection.getsockname()) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) - except (socket.error, IOError), err: + except (socket.error, IOError) as err: raise NetLibError('Error connecting to "%s": %s' % (self.address.host, err)) self.connection = connection @@ -447,6 +468,7 @@ class TCPClient(_Connection): class BaseHandler(_Connection): + """ The instantiator is expected to call the handle() and finish() methods. @@ -531,8 +553,8 @@ class BaseHandler(_Connection): self.connection.set_accept_state() try: self.connection.do_handshake() - except SSL.Error, v: - raise NetLibError("SSL handshake error: %s"%repr(v)) + except SSL.Error as v: + raise NetLibError("SSL handshake error: %s" % repr(v)) self.ssl_established = True self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) diff --git a/netlib/test.py b/netlib/test.py index db30c0e6..b6f94273 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -1,9 +1,13 @@ from __future__ import (absolute_import, print_function, division) -import threading, Queue, cStringIO +import threading +import Queue +import cStringIO import OpenSSL from . import tcp, certutils + class ServerThread(threading.Thread): + def __init__(self, server): self.server = server threading.Thread.__init__(self) @@ -19,6 +23,7 @@ class ServerTestBase(object): ssl = None handler = None addr = ("localhost", 0) + @classmethod def setupAll(cls): cls.q = Queue.Queue() @@ -41,10 +46,11 @@ class ServerTestBase(object): class TServer(tcp.TCPServer): + def __init__(self, ssl, q, handler_klass, addr): """ ssl: A dictionary of SSL parameters: - + cert, key, request_client_cert, cipher_list, dhparams, v3_only """ @@ -70,13 +76,13 @@ class TServer(tcp.TCPServer): options = None h.convert_to_ssl( cert, key, - method = method, - options = options, - handle_sni = getattr(h, "handle_sni", None), - request_client_cert = self.ssl["request_client_cert"], - cipher_list = self.ssl.get("cipher_list", None), - dhparams = self.ssl.get("dhparams", None), - chain_file = self.ssl.get("chain_file", None) + method=method, + options=options, + handle_sni=getattr(h, "handle_sni", None), + request_client_cert=self.ssl["request_client_cert"], + cipher_list=self.ssl.get("cipher_list", None), + dhparams=self.ssl.get("dhparams", None), + chain_file=self.ssl.get("chain_file", None) ) h.handle() h.finish() diff --git a/netlib/utils.py b/netlib/utils.py index 7e539977..9c5404e6 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -68,6 +68,7 @@ def getbit(byte, offset): class BiDi: + """ A wee utility class for keeping bi-directional mappings, like field constants in protocols. Names are attributes on the object, dict-like @@ -77,6 +78,7 @@ class BiDi: assert CONST.a == 1 assert CONST.get_name(1) == "a" """ + def __init__(self, **kwargs): self.names = kwargs self.values = {} @@ -96,15 +98,15 @@ class BiDi: def pretty_size(size): suffixes = [ - ("B", 2**10), - ("kB", 2**20), - ("MB", 2**30), + ("B", 2 ** 10), + ("kB", 2 ** 20), + ("MB", 2 ** 30), ] for suf, lim in suffixes: if size >= lim: continue else: - x = round(size/float(lim/2**10), 2) + x = round(size / float(lim / 2 ** 10), 2) if x == int(x): x = int(x) return str(x) + suf diff --git a/netlib/websockets.py b/netlib/websockets.py index a2d55c19..63dc03f1 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -26,16 +26,17 @@ MAX_64_BIT_INT = (1 << 64) OPCODE = utils.BiDi( - CONTINUE = 0x00, - TEXT = 0x01, - BINARY = 0x02, - CLOSE = 0x08, - PING = 0x09, - PONG = 0x0a + CONTINUE=0x00, + TEXT=0x01, + BINARY=0x02, + CLOSE=0x08, + PING=0x09, + PONG=0x0a ) class Masker: + """ Data sent from the server must be masked to prevent malicious clients from sending data over the wire in predictable patterns @@ -43,6 +44,7 @@ class Masker: Servers do not have to mask data they send to the client. https://tools.ietf.org/html/rfc6455#section-5.3 """ + def __init__(self, key): self.key = key self.masks = [utils.bytes_to_int(byte) for byte in key] @@ -128,17 +130,18 @@ DEFAULT = object() class FrameHeader: + def __init__( self, - opcode = OPCODE.TEXT, - payload_length = 0, - fin = False, - rsv1 = False, - rsv2 = False, - rsv3 = False, - masking_key = DEFAULT, - mask = DEFAULT, - length_code = DEFAULT + opcode=OPCODE.TEXT, + payload_length=0, + fin=False, + rsv1=False, + rsv2=False, + rsv3=False, + masking_key=DEFAULT, + mask=DEFAULT, + length_code=DEFAULT ): if not 0 <= opcode < 2 ** 4: raise ValueError("opcode must be 0-16") @@ -182,9 +185,9 @@ class FrameHeader: if flags: vals.extend([":", "|".join(flags)]) if self.masking_key: - vals.append(":key=%s"%repr(self.masking_key)) + vals.append(":key=%s" % repr(self.masking_key)) if self.payload_length: - vals.append(" %s"%utils.pretty_size(self.payload_length)) + vals.append(" %s" % utils.pretty_size(self.payload_length)) return "".join(vals) def to_bytes(self): @@ -246,15 +249,15 @@ class FrameHeader: masking_key = None return klass( - fin = fin, - rsv1 = rsv1, - rsv2 = rsv2, - rsv3 = rsv3, - opcode = opcode, - mask = mask_bit, - length_code = length_code, - payload_length = payload_length, - masking_key = masking_key, + fin=fin, + rsv1=rsv1, + rsv2=rsv2, + rsv3=rsv3, + opcode=opcode, + mask=mask_bit, + length_code=length_code, + payload_length=payload_length, + masking_key=masking_key, ) def __eq__(self, other): @@ -262,6 +265,7 @@ class FrameHeader: class Frame(object): + """ Represents one websockets frame. Constructor takes human readable forms of the frame components @@ -287,13 +291,14 @@ class Frame(object): | Payload Data continued ... | +---------------------------------------------------------------+ """ - def __init__(self, payload = "", **kwargs): + + def __init__(self, payload="", **kwargs): self.payload = payload kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) self.header = FrameHeader(**kwargs) @classmethod - def default(cls, message, from_client = False): + def default(cls, message, from_client=False): """ Construct a basic websocket frame from some default values. Creates a non-fragmented text frame. @@ -307,10 +312,10 @@ class Frame(object): return cls( message, - fin = 1, # final frame - opcode = OPCODE.TEXT, # text - mask = mask_bit, - masking_key = masking_key, + fin=1, # final frame + opcode=OPCODE.TEXT, # text + mask=mask_bit, + masking_key=masking_key, ) @classmethod @@ -356,15 +361,15 @@ class Frame(object): return cls( payload, - fin = header.fin, - opcode = header.opcode, - mask = header.mask, - payload_length = header.payload_length, - masking_key = header.masking_key, - rsv1 = header.rsv1, - rsv2 = header.rsv2, - rsv3 = header.rsv3, - length_code = header.length_code + fin=header.fin, + opcode=header.opcode, + mask=header.mask, + payload_length=header.payload_length, + masking_key=header.masking_key, + rsv1=header.rsv1, + rsv2=header.rsv2, + rsv3=header.rsv3, + length_code=header.length_code ) def __eq__(self, other): diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 1b979608..f393039a 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -7,17 +7,20 @@ from . import odict, tcp class ClientConn(object): + def __init__(self, address): self.address = tcp.Address.wrap(address) class Flow(object): + def __init__(self, address, request): self.client_conn = ClientConn(address) self.request = request class Request(object): + def __init__(self, scheme, method, path, headers, content): self.scheme, self.method, self.path = scheme, method, path self.headers, self.content = headers, content @@ -42,6 +45,7 @@ def date_time_string(): class WSGIAdaptor(object): + def __init__(self, app, domain, port, sversion): self.app, self.domain, self.port, self.sversion = app, domain, port, sversion @@ -52,24 +56,24 @@ class WSGIAdaptor(object): path_info = flow.request.path query = '' environ = { - 'wsgi.version': (1, 0), - 'wsgi.url_scheme': flow.request.scheme, - 'wsgi.input': cStringIO.StringIO(flow.request.content), - 'wsgi.errors': errsoc, - 'wsgi.multithread': True, - 'wsgi.multiprocess': False, - 'wsgi.run_once': False, - 'SERVER_SOFTWARE': self.sversion, - 'REQUEST_METHOD': flow.request.method, - 'SCRIPT_NAME': '', - 'PATH_INFO': urllib.unquote(path_info), - 'QUERY_STRING': query, - 'CONTENT_TYPE': flow.request.headers.get('Content-Type', [''])[0], - 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', [''])[0], - 'SERVER_NAME': self.domain, - 'SERVER_PORT': str(self.port), + 'wsgi.version': (1, 0), + 'wsgi.url_scheme': flow.request.scheme, + 'wsgi.input': cStringIO.StringIO(flow.request.content), + 'wsgi.errors': errsoc, + 'wsgi.multithread': True, + 'wsgi.multiprocess': False, + 'wsgi.run_once': False, + 'SERVER_SOFTWARE': self.sversion, + 'REQUEST_METHOD': flow.request.method, + 'SCRIPT_NAME': '', + 'PATH_INFO': urllib.unquote(path_info), + 'QUERY_STRING': query, + 'CONTENT_TYPE': flow.request.headers.get('Content-Type', [''])[0], + 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', [''])[0], + 'SERVER_NAME': self.domain, + 'SERVER_PORT': str(self.port), # FIXME: We need to pick up the protocol read from the request. - 'SERVER_PROTOCOL': "HTTP/1.1", + 'SERVER_PROTOCOL': "HTTP/1.1", } environ.update(extra) if flow.client_conn.address: @@ -91,25 +95,25 @@ class WSGIAdaptor(object):

Internal Server Error

%s"
- """%s + """ % s if not headers_sent: soc.write("HTTP/1.1 500 Internal Server Error\r\n") soc.write("Content-Type: text/html\r\n") - soc.write("Content-Length: %s\r\n"%len(c)) + soc.write("Content-Length: %s\r\n" % len(c)) soc.write("\r\n") soc.write(c) def serve(self, request, soc, **env): state = dict( - response_started = False, - headers_sent = False, - status = None, - headers = None + response_started=False, + headers_sent=False, + status=None, + headers=None ) def write(data): if not state["headers_sent"]: - soc.write("HTTP/1.1 %s\r\n"%state["status"]) + soc.write("HTTP/1.1 %s\r\n" % state["status"]) h = state["headers"] if 'server' not in h: h["Server"] = [self.sversion] diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..1ba84a24 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,7 @@ +[flake8] +max-line-length = 160 +max-complexity = 15 + +[pep8] +max-line-length = 160 +max-complexity = 15 diff --git a/test/h2/test_frames.py b/test/h2/test_frames.py index d04c7c8b..90162984 100644 --- a/test/h2/test_frames.py +++ b/test/h2/test_frames.py @@ -4,20 +4,22 @@ import tutils from nose.tools import assert_equal - # TODO test stream association if valid or not def test_invalid_flags(): tutils.raises(ValueError, DataFrame, ContinuationFrame.FLAG_END_HEADERS, 0x1234567, 'foobar') + def test_frame_equality(): a = DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567, 'foobar') b = DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567, 'foobar') assert_equal(a, b) + def test_too_large_frames(): DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567) + def test_data_frame_to_bytes(): f = DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567, 'foobar') assert_equal(f.to_bytes().encode('hex'), '000006000101234567666f6f626172') @@ -28,6 +30,7 @@ def test_data_frame_to_bytes(): f = DataFrame(6, Frame.FLAG_NO_FLAGS, 0x0, 'foobar') tutils.raises(ValueError, f.to_bytes) + def test_data_frame_from_bytes(): f = Frame.from_bytes('000006000101234567666f6f626172'.decode('hex')) assert isinstance(f, DataFrame) @@ -45,6 +48,7 @@ def test_data_frame_from_bytes(): assert_equal(f.stream_id, 0x1234567) assert_equal(f.payload, 'foobar') + def test_headers_frame_to_bytes(): f = HeadersFrame(6, Frame.FLAG_NO_FLAGS, 0x1234567, 'foobar') assert_equal(f.to_bytes().encode('hex'), '000006010001234567666f6f626172') @@ -55,15 +59,18 @@ def test_headers_frame_to_bytes(): f = HeadersFrame(10, HeadersFrame.FLAG_PRIORITY, 0x1234567, 'foobar', exclusive=True, stream_dependency=0x7654321, weight=42) assert_equal(f.to_bytes().encode('hex'), '00000b012001234567876543212a666f6f626172') - f = HeadersFrame(14, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, 0x1234567, 'foobar', pad_length=3, exclusive=True, stream_dependency=0x7654321, weight=42) + f = HeadersFrame(14, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, 0x1234567, + 'foobar', pad_length=3, exclusive=True, stream_dependency=0x7654321, weight=42) assert_equal(f.to_bytes().encode('hex'), '00000f01280123456703876543212a666f6f626172000000') - f = HeadersFrame(14, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, 0x1234567, 'foobar', pad_length=3, exclusive=False, stream_dependency=0x7654321, weight=42) + f = HeadersFrame(14, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, 0x1234567, 'foobar', + pad_length=3, exclusive=False, stream_dependency=0x7654321, weight=42) assert_equal(f.to_bytes().encode('hex'), '00000f01280123456703076543212a666f6f626172000000') f = HeadersFrame(6, Frame.FLAG_NO_FLAGS, 0x0, 'foobar') tutils.raises(ValueError, f.to_bytes) + def test_headers_frame_from_bytes(): f = Frame.from_bytes('000006010001234567666f6f626172'.decode('hex')) assert isinstance(f, HeadersFrame) @@ -114,6 +121,7 @@ def test_headers_frame_from_bytes(): assert_equal(f.stream_dependency, 0x7654321) assert_equal(f.weight, 42) + def test_priority_frame_to_bytes(): f = PriorityFrame(5, Frame.FLAG_NO_FLAGS, 0x1234567, exclusive=True, stream_dependency=0x7654321, weight=42) assert_equal(f.to_bytes().encode('hex'), '000005020001234567876543212a') @@ -127,6 +135,7 @@ def test_priority_frame_to_bytes(): f = PriorityFrame(5, Frame.FLAG_NO_FLAGS, 0x1234567, stream_dependency=0x0) tutils.raises(ValueError, f.to_bytes) + def test_priority_frame_from_bytes(): f = Frame.from_bytes('000005020001234567876543212a'.decode('hex')) assert isinstance(f, PriorityFrame) @@ -148,6 +157,7 @@ def test_priority_frame_from_bytes(): assert_equal(f.stream_dependency, 0x7654321) assert_equal(f.weight, 21) + def test_rst_stream_frame_to_bytes(): f = RstStreamFrame(4, Frame.FLAG_NO_FLAGS, 0x1234567, error_code=0x7654321) assert_equal(f.to_bytes().encode('hex'), '00000403000123456707654321') @@ -155,6 +165,7 @@ def test_rst_stream_frame_to_bytes(): f = RstStreamFrame(4, Frame.FLAG_NO_FLAGS, 0x0) tutils.raises(ValueError, f.to_bytes) + def test_rst_stream_frame_from_bytes(): f = Frame.from_bytes('00000403000123456707654321'.decode('hex')) assert isinstance(f, RstStreamFrame) @@ -164,6 +175,7 @@ def test_rst_stream_frame_from_bytes(): assert_equal(f.stream_id, 0x1234567) assert_equal(f.error_code, 0x07654321) + def test_settings_frame_to_bytes(): f = SettingsFrame(0, Frame.FLAG_NO_FLAGS, 0x0) assert_equal(f.to_bytes().encode('hex'), '000000040000000000') @@ -174,12 +186,14 @@ def test_settings_frame_to_bytes(): f = SettingsFrame(6, SettingsFrame.FLAG_ACK, 0x0, settings={SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1}) assert_equal(f.to_bytes().encode('hex'), '000006040100000000000200000001') - f = SettingsFrame(12, Frame.FLAG_NO_FLAGS, 0x0, settings={SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) + f = SettingsFrame(12, Frame.FLAG_NO_FLAGS, 0x0, settings={ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) assert_equal(f.to_bytes().encode('hex'), '00000c040000000000000200000001000312345678') f = SettingsFrame(0, Frame.FLAG_NO_FLAGS, 0x1234567) tutils.raises(ValueError, f.to_bytes) + def test_settings_frame_from_bytes(): f = Frame.from_bytes('000000040000000000'.decode('hex')) assert isinstance(f, SettingsFrame) @@ -214,6 +228,7 @@ def test_settings_frame_from_bytes(): assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS], 0x12345678) + def test_push_promise_frame_to_bytes(): f = PushPromiseFrame(10, Frame.FLAG_NO_FLAGS, 0x1234567, 0x7654321, 'foobar') assert_equal(f.to_bytes().encode('hex'), '00000a05000123456707654321666f6f626172') @@ -227,6 +242,7 @@ def test_push_promise_frame_to_bytes(): f = PushPromiseFrame(4, Frame.FLAG_NO_FLAGS, 0x1234567, 0x0) tutils.raises(ValueError, f.to_bytes) + def test_push_promise_frame_from_bytes(): f = Frame.from_bytes('00000a05000123456707654321666f6f626172'.decode('hex')) assert isinstance(f, PushPromiseFrame) @@ -244,6 +260,7 @@ def test_push_promise_frame_from_bytes(): assert_equal(f.stream_id, 0x1234567) assert_equal(f.header_block_fragment, 'foobar') + def test_ping_frame_to_bytes(): f = PingFrame(8, PingFrame.FLAG_ACK, 0x0, payload=b'foobar') assert_equal(f.to_bytes().encode('hex'), '000008060100000000666f6f6261720000') @@ -254,6 +271,7 @@ def test_ping_frame_to_bytes(): f = PingFrame(8, Frame.FLAG_NO_FLAGS, 0x1234567) tutils.raises(ValueError, f.to_bytes) + def test_ping_frame_from_bytes(): f = Frame.from_bytes('000008060100000000666f6f6261720000'.decode('hex')) assert isinstance(f, PingFrame) @@ -271,6 +289,7 @@ def test_ping_frame_from_bytes(): assert_equal(f.stream_id, 0x0) assert_equal(f.payload, b'foobarde') + def test_goaway_frame_to_bytes(): f = GoAwayFrame(8, Frame.FLAG_NO_FLAGS, 0x0, last_stream=0x1234567, error_code=0x87654321, data=b'') assert_equal(f.to_bytes().encode('hex'), '0000080700000000000123456787654321') @@ -281,6 +300,7 @@ def test_goaway_frame_to_bytes(): f = GoAwayFrame(8, Frame.FLAG_NO_FLAGS, 0x1234567, last_stream=0x1234567, error_code=0x87654321) tutils.raises(ValueError, f.to_bytes) + def test_goaway_frame_from_bytes(): f = Frame.from_bytes('0000080700000000000123456787654321'.decode('hex')) assert isinstance(f, GoAwayFrame) @@ -302,6 +322,7 @@ def test_goaway_frame_from_bytes(): assert_equal(f.error_code, 0x87654321) assert_equal(f.data, b'foobar') + def test_window_update_frame_to_bytes(): f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0x1234567) assert_equal(f.to_bytes().encode('hex'), '00000408000000000001234567') @@ -315,6 +336,7 @@ def test_window_update_frame_to_bytes(): f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0) tutils.raises(ValueError, f.to_bytes) + def test_window_update_frame_from_bytes(): f = Frame.from_bytes('00000408000000000001234567'.decode('hex')) assert isinstance(f, WindowUpdateFrame) @@ -324,6 +346,7 @@ def test_window_update_frame_from_bytes(): assert_equal(f.stream_id, 0x0) assert_equal(f.window_size_increment, 0x1234567) + def test_continuation_frame_to_bytes(): f = ContinuationFrame(6, ContinuationFrame.FLAG_END_HEADERS, 0x1234567, 'foobar') assert_equal(f.to_bytes().encode('hex'), '000006090401234567666f6f626172') @@ -331,6 +354,7 @@ def test_continuation_frame_to_bytes(): f = ContinuationFrame(6, ContinuationFrame.FLAG_END_HEADERS, 0x0, 'foobar') tutils.raises(ValueError, f.to_bytes) + def test_continuation_frame_from_bytes(): f = Frame.from_bytes('000006090401234567666f6f626172'.decode('hex')) assert isinstance(f, ContinuationFrame) diff --git a/test/test_certutils.py b/test/test_certutils.py index c96c5087..4af0197f 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -34,6 +34,7 @@ import tutils class TestCertStore: + def test_create_explicit(self): with tutils.tmpdir() as d: ca = certutils.CertStore.from_store(d, "test") @@ -102,6 +103,7 @@ class TestCertStore: class TestDummyCert: + def test_with_ca(self): with tutils.tmpdir() as d: ca = certutils.CertStore.from_store(d, "test") @@ -115,6 +117,7 @@ class TestDummyCert: class TestSSLCert: + def test_simple(self): with open(tutils.test_data.path("data/text_cert"), "rb") as f: d = f.read() @@ -152,5 +155,3 @@ class TestSSLCert: d = f.read() s = certutils.SSLCert.from_der(d) assert s.cn - - diff --git a/test/test_http.py b/test/test_http.py index 63b39f08..0a9e276f 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -230,6 +230,7 @@ def test_parse_init_http(): class TestReadHeaders: + def _read(self, data, verbatim=False): if not verbatim: data = textwrap.dedent(data) @@ -277,6 +278,7 @@ class TestReadHeaders: class NoContentLengthHTTPHandler(tcp.BaseHandler): + def handle(self): self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n") self.wfile.flush() @@ -297,7 +299,7 @@ def test_read_response(): data = textwrap.dedent(data) r = cStringIO.StringIO(data) return http.read_response( - r, method, limit, include_body = include_body + r, method, limit, include_body=include_body ) tutils.raises("server disconnect", tst, "", "GET", None) diff --git a/test/test_http_auth.py b/test/test_http_auth.py index 176aa3ff..25df5410 100644 --- a/test/test_http_auth.py +++ b/test/test_http_auth.py @@ -1,9 +1,12 @@ -import binascii, cStringIO +import binascii +import cStringIO from netlib import odict, http_auth, http import mock import tutils + class TestPassManNonAnon: + def test_simple(self): p = http_auth.PassManNonAnon() assert not p.test("", "") @@ -11,6 +14,7 @@ class TestPassManNonAnon: class TestPassManHtpasswd: + def test_file_errors(self): tutils.raises("malformed htpasswd file", http_auth.PassManHtpasswd, tutils.test_data.path("data/server.crt")) @@ -27,6 +31,7 @@ class TestPassManHtpasswd: class TestPassManSingleUser: + def test_simple(self): pm = http_auth.PassManSingleUser("test", "test") assert pm.test("test", "test") @@ -35,6 +40,7 @@ class TestPassManSingleUser: class TestNullProxyAuth: + def test_simple(self): na = http_auth.NullProxyAuth(http_auth.PassManNonAnon()) assert not na.auth_challenge_headers() @@ -43,6 +49,7 @@ class TestNullProxyAuth: class TestBasicProxyAuth: + def test_simple(self): ba = http_auth.BasicProxyAuth(http_auth.PassManNonAnon(), "test") h = odict.ODictCaseless() @@ -60,7 +67,6 @@ class TestBasicProxyAuth: ba.clean(hdrs) assert not ba.AUTH_HEADER in hdrs - hdrs[ba.AUTH_HEADER] = [""] assert not ba.authenticate(hdrs) @@ -77,25 +83,27 @@ class TestBasicProxyAuth: assert not ba.authenticate(hdrs) -class Bunch: pass +class Bunch: + pass class TestAuthAction: + def test_nonanonymous(self): m = Bunch() aa = http_auth.NonanonymousAuthAction(None, "authenticator") aa(None, m, None, None) - assert m.authenticator + assert m.authenticator def test_singleuser(self): m = Bunch() aa = http_auth.SingleuserAuthAction(None, "authenticator") aa(None, m, "foo:bar", None) - assert m.authenticator + assert m.authenticator tutils.raises("invalid", aa, None, m, "foo", None) def test_httppasswd(self): m = Bunch() aa = http_auth.HtpasswdAuthAction(None, "authenticator") aa(None, m, tutils.test_data.path("data/htpasswd"), None) - assert m.authenticator + assert m.authenticator diff --git a/test/test_http_uastrings.py b/test/test_http_uastrings.py index c70b7048..3fa4f359 100644 --- a/test/test_http_uastrings.py +++ b/test/test_http_uastrings.py @@ -4,4 +4,3 @@ from netlib import http_uastrings def test_get_shortcut(): assert http_uastrings.get_by_shortcut("c")[0] == "chrome" assert not http_uastrings.get_by_shortcut("_") - diff --git a/test/test_odict.py b/test/test_odict.py index c01c4dbe..d66ae59b 100644 --- a/test/test_odict.py +++ b/test/test_odict.py @@ -3,6 +3,7 @@ import tutils class TestODict: + def setUp(self): self.od = odict.ODict() @@ -106,13 +107,13 @@ class TestODict: def test_get(self): self.od.add("one", "two") assert self.od.get("one") == ["two"] - assert self.od.get("two") == None + assert self.od.get("two") is None def test_get_first(self): self.od.add("one", "two") self.od.add("one", "three") assert self.od.get_first("one") == "two" - assert self.od.get_first("two") == None + assert self.od.get_first("two") is None def test_extend(self): a = odict.ODict([["a", "b"], ["c", "d"]]) @@ -121,7 +122,9 @@ class TestODict: assert len(a) == 4 assert a["a"] == ["b", "b"] + class TestODictCaseless: + def setUp(self): self.od = odict.ODictCaseless() diff --git a/test/test_tcp.py b/test/test_tcp.py index 4dbdd780..ef00e029 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -1,4 +1,8 @@ -import cStringIO, Queue, time, socket, random +import cStringIO +import Queue +import time +import socket +import random import os from netlib import tcp, certutils, test, certffi import threading @@ -6,8 +10,10 @@ import mock import tutils from OpenSSL import SSL + class EchoHandler(tcp.BaseHandler): sni = None + def handle_sni(self, connection): self.sni = connection.get_servername() @@ -19,19 +25,22 @@ class EchoHandler(tcp.BaseHandler): class ClientCipherListHandler(tcp.BaseHandler): sni = None + def handle(self): - self.wfile.write("%s"%self.connection.get_cipher_list()) + self.wfile.write("%s" % self.connection.get_cipher_list()) self.wfile.flush() class HangHandler(tcp.BaseHandler): + def handle(self): - while 1: + while True: time.sleep(1) class TestServer(test.ServerTestBase): handler = EchoHandler + def test_echo(self): testval = "echo!\n" c = tcp.TCPClient(("127.0.0.1", self.port)) @@ -51,7 +60,9 @@ class TestServer(test.ServerTestBase): class TestServerBind(test.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): self.wfile.write(str(self.connection.getpeername())) self.wfile.flush() @@ -65,7 +76,7 @@ class TestServerBind(test.ServerTestBase): c.connect() assert c.rfile.readline() == str(("127.0.0.1", random_port)) return - except tcp.NetLibError: # port probably already in use + except tcp.NetLibError: # port probably already in use pass @@ -84,6 +95,7 @@ class TestServerIPv6(test.ServerTestBase): class TestEcho(test.ServerTestBase): handler = EchoHandler + def test_echo(self): testval = "echo!\n" c = tcp.TCPClient(("127.0.0.1", self.port)) @@ -94,16 +106,19 @@ class TestEcho(test.ServerTestBase): class HardDisconnectHandler(tcp.BaseHandler): + def handle(self): self.connection.close() class TestFinishFail(test.ServerTestBase): + """ This tests a difficult-to-trigger exception in the .finish() method of the handler. """ handler = EchoHandler + def test_disconnect_in_finish(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() @@ -115,13 +130,14 @@ class TestFinishFail(test.ServerTestBase): class TestServerSSL(test.ServerTestBase): handler = EchoHandler ssl = dict( - cert = tutils.test_data.path("data/server.crt"), - key = tutils.test_data.path("data/server.key"), - request_client_cert = False, - v3_only = False, - cipher_list = "AES256-SHA", - chain_file=tutils.test_data.path("data/server.crt") - ) + cert=tutils.test_data.path("data/server.crt"), + key=tutils.test_data.path("data/server.key"), + request_client_cert=False, + v3_only=False, + cipher_list="AES256-SHA", + chain_file=tutils.test_data.path("data/server.crt") + ) + def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() @@ -144,11 +160,12 @@ class TestServerSSL(test.ServerTestBase): class TestSSLv3Only(test.ServerTestBase): handler = EchoHandler ssl = dict( - cert = tutils.test_data.path("data/server.crt"), - key = tutils.test_data.path("data/server.key"), - request_client_cert = False, - v3_only = True + cert=tutils.test_data.path("data/server.crt"), + key=tutils.test_data.path("data/server.key"), + request_client_cert=False, + v3_only=True ) + def test_failure(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() @@ -156,20 +173,23 @@ class TestSSLv3Only(test.ServerTestBase): class TestSSLClientCert(test.ServerTestBase): + class handler(tcp.BaseHandler): sni = None + def handle_sni(self, connection): self.sni = connection.get_servername() def handle(self): - self.wfile.write("%s\n"%self.clientcert.serial) + self.wfile.write("%s\n" % self.clientcert.serial) self.wfile.flush() ssl = dict( - cert = tutils.test_data.path("data/server.crt"), - key = tutils.test_data.path("data/server.key"), - request_client_cert = True, - v3_only = False + cert=tutils.test_data.path("data/server.crt"), + key=tutils.test_data.path("data/server.key"), + request_client_cert=True, + v3_only=False ) + def test_clientcert(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() @@ -187,8 +207,10 @@ class TestSSLClientCert(test.ServerTestBase): class TestSNI(test.ServerTestBase): + class handler(tcp.BaseHandler): sni = None + def handle_sni(self, connection): self.sni = connection.get_servername() @@ -197,11 +219,12 @@ class TestSNI(test.ServerTestBase): self.wfile.flush() ssl = dict( - cert = tutils.test_data.path("data/server.crt"), - key = tutils.test_data.path("data/server.key"), - request_client_cert = False, - v3_only = False + cert=tutils.test_data.path("data/server.crt"), + key=tutils.test_data.path("data/server.key"), + request_client_cert=False, + v3_only=False ) + def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() @@ -213,12 +236,13 @@ class TestSNI(test.ServerTestBase): class TestServerCipherList(test.ServerTestBase): handler = ClientCipherListHandler ssl = dict( - cert = tutils.test_data.path("data/server.crt"), - key = tutils.test_data.path("data/server.key"), - request_client_cert = False, - v3_only = False, - cipher_list = 'RC4-SHA' + cert=tutils.test_data.path("data/server.crt"), + key=tutils.test_data.path("data/server.key"), + request_client_cert=False, + v3_only=False, + cipher_list='RC4-SHA' ) + def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() @@ -227,18 +251,21 @@ class TestServerCipherList(test.ServerTestBase): class TestServerCurrentCipher(test.ServerTestBase): + class handler(tcp.BaseHandler): sni = None + def handle(self): - self.wfile.write("%s"%str(self.get_current_cipher())) + self.wfile.write("%s" % str(self.get_current_cipher())) self.wfile.flush() ssl = dict( - cert = tutils.test_data.path("data/server.crt"), - key = tutils.test_data.path("data/server.key"), - request_client_cert = False, - v3_only = False, - cipher_list = 'RC4-SHA' + cert=tutils.test_data.path("data/server.crt"), + key=tutils.test_data.path("data/server.key"), + request_client_cert=False, + v3_only=False, + cipher_list='RC4-SHA' ) + def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() @@ -249,12 +276,13 @@ class TestServerCurrentCipher(test.ServerTestBase): class TestServerCipherListError(test.ServerTestBase): handler = ClientCipherListHandler ssl = dict( - cert = tutils.test_data.path("data/server.crt"), - key = tutils.test_data.path("data/server.key"), - request_client_cert = False, - v3_only = False, - cipher_list = 'bogus' + cert=tutils.test_data.path("data/server.crt"), + key=tutils.test_data.path("data/server.key"), + request_client_cert=False, + v3_only=False, + cipher_list='bogus' ) + def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() @@ -264,12 +292,13 @@ class TestServerCipherListError(test.ServerTestBase): class TestClientCipherListError(test.ServerTestBase): handler = ClientCipherListHandler ssl = dict( - cert = tutils.test_data.path("data/server.crt"), - key = tutils.test_data.path("data/server.key"), - request_client_cert = False, - v3_only = False, - cipher_list = 'RC4-SHA' + cert=tutils.test_data.path("data/server.crt"), + key=tutils.test_data.path("data/server.key"), + request_client_cert=False, + v3_only=False, + cipher_list='RC4-SHA' ) + def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() @@ -277,15 +306,18 @@ class TestClientCipherListError(test.ServerTestBase): class TestSSLDisconnect(test.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): self.finish() ssl = dict( - cert = tutils.test_data.path("data/server.crt"), - key = tutils.test_data.path("data/server.key"), - request_client_cert = False, - v3_only = False + cert=tutils.test_data.path("data/server.crt"), + key=tutils.test_data.path("data/server.key"), + request_client_cert=False, + v3_only=False ) + def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() @@ -300,11 +332,12 @@ class TestSSLDisconnect(test.ServerTestBase): class TestSSLHardDisconnect(test.ServerTestBase): handler = HardDisconnectHandler ssl = dict( - cert = tutils.test_data.path("data/server.crt"), - key = tutils.test_data.path("data/server.key"), - request_client_cert = False, - v3_only = False + cert=tutils.test_data.path("data/server.crt"), + key=tutils.test_data.path("data/server.key"), + request_client_cert=False, + v3_only=False ) + def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() @@ -316,6 +349,7 @@ class TestSSLHardDisconnect(test.ServerTestBase): class TestDisconnect(test.ServerTestBase): + def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() @@ -326,7 +360,9 @@ class TestDisconnect(test.ServerTestBase): class TestServerTimeOut(test.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): self.timeout = False self.settimeout(0.01) @@ -344,6 +380,7 @@ class TestServerTimeOut(test.ServerTestBase): class TestTimeOut(test.ServerTestBase): handler = HangHandler + def test_timeout(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() @@ -355,11 +392,12 @@ class TestTimeOut(test.ServerTestBase): class TestSSLTimeOut(test.ServerTestBase): handler = HangHandler ssl = dict( - cert = tutils.test_data.path("data/server.crt"), - key = tutils.test_data.path("data/server.key"), - request_client_cert = False, - v3_only = False + cert=tutils.test_data.path("data/server.crt"), + key=tutils.test_data.path("data/server.key"), + request_client_cert=False, + v3_only=False ) + def test_timeout_client(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() @@ -371,15 +409,16 @@ class TestSSLTimeOut(test.ServerTestBase): class TestDHParams(test.ServerTestBase): handler = HangHandler ssl = dict( - cert = tutils.test_data.path("data/server.crt"), - key = tutils.test_data.path("data/server.key"), - request_client_cert = False, - v3_only = False, - dhparams = certutils.CertStore.load_dhparam( + cert=tutils.test_data.path("data/server.crt"), + key=tutils.test_data.path("data/server.key"), + request_client_cert=False, + v3_only=False, + dhparams=certutils.CertStore.load_dhparam( tutils.test_data.path("data/dhparam.pem"), ), - cipher_list = "DHE-RSA-AES256-SHA" + cipher_list="DHE-RSA-AES256-SHA" ) + def test_dhparams(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() @@ -395,7 +434,9 @@ class TestDHParams(test.ServerTestBase): class TestPrivkeyGen(test.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): with tutils.tmpdir() as d: ca1 = certutils.CertStore.from_store(d, "test2") @@ -411,7 +452,9 @@ class TestPrivkeyGen(test.ServerTestBase): class TestPrivkeyGenNoFlags(test.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): with tutils.tmpdir() as d: ca1 = certutils.CertStore.from_store(d, "test2") @@ -426,14 +469,15 @@ class TestPrivkeyGenNoFlags(test.ServerTestBase): tutils.raises("sslv3 alert handshake failure", c.convert_to_ssl) - class TestTCPClient: + def test_conerr(self): c = tcp.TCPClient(("127.0.0.1", 0)) tutils.raises(tcp.NetLibError, c.connect) class TestFileLike: + def test_blocksize(self): s = cStringIO.StringIO("1234567890abcdefghijklmnopqrstuvwxyz") s = tcp.Reader(s) @@ -460,7 +504,7 @@ class TestFileLike: assert s.readline(3) == "foo" def test_limitless(self): - s = cStringIO.StringIO("f"*(50*1024)) + s = cStringIO.StringIO("f" * (50 * 1024)) s = tcp.Reader(s) ret = s.read(-1) assert len(ret) == 50 * 1024 @@ -551,7 +595,9 @@ class TestFileLike: s = tcp.Reader(o) tutils.raises(tcp.NetLibDisconnect, s.readline, 10) + class TestAddress: + def test_simple(self): a = tcp.Address("localhost", True) assert a.use_ipv6 @@ -566,12 +612,12 @@ class TestAddress: class TestSSLKeyLogger(test.ServerTestBase): handler = EchoHandler ssl = dict( - cert = tutils.test_data.path("data/server.crt"), - key = tutils.test_data.path("data/server.key"), - request_client_cert = False, - v3_only = False, - cipher_list = "AES256-SHA" - ) + cert=tutils.test_data.path("data/server.crt"), + key=tutils.test_data.path("data/server.key"), + request_client_cert=False, + v3_only=False, + cipher_list="AES256-SHA" + ) def test_log(self): testval = "echo!\n" @@ -597,4 +643,4 @@ class TestSSLKeyLogger(test.ServerTestBase): def test_create_logfun(self): assert isinstance(tcp.SSLKeyLogger.create_logfun("test"), tcp.SSLKeyLogger) - assert not tcp.SSLKeyLogger.create_logfun(False) \ No newline at end of file + assert not tcp.SSLKeyLogger.create_logfun(False) diff --git a/test/test_utils.py b/test/test_utils.py index 942136fd..8e66bce4 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -12,7 +12,7 @@ def test_bidi(): def test_hexdump(): - assert utils.hexdump("one\0"*10) + assert utils.hexdump("one\0" * 10) def test_cleanBin(): @@ -25,5 +25,5 @@ def test_cleanBin(): def test_pretty_size(): assert utils.pretty_size(100) == "100B" assert utils.pretty_size(1024) == "1kB" - assert utils.pretty_size(1024 + (1024/2.0)) == "1.5kB" - assert utils.pretty_size(1024*1024) == "1MB" + assert utils.pretty_size(1024 + (1024 / 2.0)) == "1.5kB" + assert utils.pretty_size(1024 * 1024) == "1MB" diff --git a/test/test_websockets.py b/test/test_websockets.py index 7bd5d74e..38947295 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -7,6 +7,7 @@ import tutils class WebSocketsEchoHandler(tcp.BaseHandler): + def __init__(self, connection, address, server): super(WebSocketsEchoHandler, self).__init__( connection, address, server @@ -25,7 +26,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): self.on_message(frame.payload) def send_message(self, message): - frame = websockets.Frame.default(message, from_client = False) + frame = websockets.Frame.default(message, from_client=False) frame.to_file(self.wfile) def handshake(self): @@ -44,6 +45,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): class WebSocketsClient(tcp.TCPClient): + def __init__(self, address, source_address=None): super(WebSocketsClient, self).__init__(address, source_address) self.client_nonce = None @@ -68,14 +70,14 @@ class WebSocketsClient(tcp.TCPClient): return websockets.Frame.from_file(self.rfile).payload def send_message(self, message): - frame = websockets.Frame.default(message, from_client = True) + frame = websockets.Frame.default(message, from_client=True) frame.to_file(self.wfile) class TestWebSockets(test.ServerTestBase): handler = WebSocketsEchoHandler - def random_bytes(self, n = 100): + def random_bytes(self, n=100): return os.urandom(n) def echo(self, msg): @@ -105,8 +107,8 @@ class TestWebSockets(test.ServerTestBase): default builder should always generate valid frames """ msg = self.random_bytes() - client_frame = websockets.Frame.default(msg, from_client = True) - server_frame = websockets.Frame.default(msg, from_client = False) + client_frame = websockets.Frame.default(msg, from_client=True) + server_frame = websockets.Frame.default(msg, from_client=False) def test_serialization_bijection(self): """ @@ -140,6 +142,7 @@ class TestWebSockets(test.ServerTestBase): class BadHandshakeHandler(WebSocketsEchoHandler): + def handshake(self): client_hs = http.read_request(self.rfile) websockets.check_client_handshake(client_hs.headers) @@ -152,6 +155,7 @@ class BadHandshakeHandler(WebSocketsEchoHandler): class TestBadHandshake(test.ServerTestBase): + """ Ensure that the client disconnects if the server handshake is malformed """ @@ -165,6 +169,7 @@ class TestBadHandshake(test.ServerTestBase): class TestFrameHeader: + def test_roundtrip(self): def round(*args, **kwargs): f = websockets.FrameHeader(*args, **kwargs) @@ -216,6 +221,7 @@ class TestFrameHeader: class TestFrame: + def test_roundtrip(self): def round(*args, **kwargs): f = websockets.Frame(*args, **kwargs) @@ -240,7 +246,7 @@ def test_masker(): ["fourf"], ["fourfive"], ["a", "aasdfasdfa", "asdf"], - ["a"*50, "aasdfasdfa", "asdf"], + ["a" * 50, "aasdfasdfa", "asdf"], ] for i in tests: m = websockets.Masker("abcd") diff --git a/test/test_wsgi.py b/test/test_wsgi.py index 1c8c5263..68a47769 100644 --- a/test/test_wsgi.py +++ b/test/test_wsgi.py @@ -1,4 +1,5 @@ -import cStringIO, sys +import cStringIO +import sys from netlib import wsgi, odict @@ -10,6 +11,7 @@ def tflow(): class TestApp: + def __init__(self): self.called = False @@ -22,6 +24,7 @@ class TestApp: class TestWSGI: + def test_make_environ(self): w = wsgi.WSGIAdaptor(None, "foo", 80, "version") tf = tflow() diff --git a/test/tutils.py b/test/tutils.py index 141979f8..95c8b80a 100644 --- a/test/tutils.py +++ b/test/tutils.py @@ -44,14 +44,14 @@ def raises(exc, obj, *args, **kwargs): :kwargs Arguments to be passed to the callable. """ try: - ret = apply(obj, args, kwargs) - except Exception, v: + ret = obj(*args, **kwargs) + except Exception as v: if isinstance(exc, basestring): if exc.lower() in str(v).lower(): return else: raise AssertionError( - "Expected %s, but caught %s"%( + "Expected %s, but caught %s" % ( repr(str(exc)), v ) ) @@ -60,7 +60,7 @@ def raises(exc, obj, *args, **kwargs): return else: raise AssertionError( - "Expected %s, but caught %s %s"%( + "Expected %s, but caught %s %s" % ( exc.__name__, v.__class__.__name__, str(v) ) ) diff --git a/tools/getcertnames b/tools/getcertnames index d22f4980..e33619f7 100755 --- a/tools/getcertnames +++ b/tools/getcertnames @@ -22,6 +22,6 @@ else: cert = get_remote_cert(sys.argv[1], port, sni) print "CN:", cert.cn if cert.altnames: - print "SANs:", + print "SANs:", for i in cert.altnames: print "\t", i -- cgit v1.2.3 From 161bc2cfaa8b70b4c2cab5562784df34013452e1 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 27 May 2015 11:25:33 +0200 Subject: cleanup code with autoflake run the following command: $ autoflake -r -i --remove-all-unused-imports --remove-unused-variables . --- netlib/h2/frame.py | 6 +----- netlib/h2/h2.py | 5 ----- netlib/http_auth.py | 1 - netlib/http_cookies.py | 1 - netlib/tcp.py | 2 -- test/test_certutils.py | 7 +++---- test/test_http_auth.py | 5 +---- test/test_http_cookies.py | 3 +-- test/test_imports.py | 2 -- test/test_wsgi.py | 4 ++-- 10 files changed, 8 insertions(+), 28 deletions(-) diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index d846b3b9..a7e81f48 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -1,10 +1,6 @@ -import base64 -import hashlib -import os import struct -import io -from .. import utils, odict, tcp +from .. import utils from functools import reduce diff --git a/netlib/h2/h2.py b/netlib/h2/h2.py index 1a39a635..7a85226f 100644 --- a/netlib/h2/h2.py +++ b/netlib/h2/h2.py @@ -1,8 +1,3 @@ -import base64 -import hashlib -import os -import struct -import io # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 261b6654..0143760c 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -16,7 +16,6 @@ class NullProxyAuth(object): """ Clean up authentication headers, so they're not passed upstream. """ - pass def authenticate(self, headers): """ diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index 73e3f589..5cb39e5c 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -67,7 +67,6 @@ def _read_quoted_string(s, start): break elif s[i] == "\\": escaping = True - pass else: ret.append(s[i]) return "".join(ret), i + 1 diff --git a/netlib/tcp.py b/netlib/tcp.py index 7c115554..49f92e4a 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -7,7 +7,6 @@ import threading import time import traceback from OpenSSL import SSL -import OpenSSL from . import certutils @@ -650,4 +649,3 @@ class TCPServer(object): """ Called after server shutdown. """ - pass diff --git a/test/test_certutils.py b/test/test_certutils.py index 4af0197f..115cac4d 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -1,6 +1,5 @@ import os from netlib import certutils, certffi -import OpenSSL import tutils # class TestDNTree: @@ -57,13 +56,13 @@ class TestCertStore: def test_add_cert(self): with tutils.tmpdir() as d: - ca = certutils.CertStore.from_store(d, "test") + certutils.CertStore.from_store(d, "test") def test_sans(self): with tutils.tmpdir() as d: ca = certutils.CertStore.from_store(d, "test") c1 = ca.get_cert("foo.com", ["*.bar.com"]) - c2 = ca.get_cert("foo.bar.com", []) + ca.get_cert("foo.bar.com", []) # assert c1 == c2 c3 = ca.get_cert("bar.com", []) assert not c1 == c3 @@ -71,7 +70,7 @@ class TestCertStore: def test_sans_change(self): with tutils.tmpdir() as d: ca = certutils.CertStore.from_store(d, "test") - _ = ca.get_cert("foo.com", ["*.bar.com"]) + ca.get_cert("foo.com", ["*.bar.com"]) cert, key, chain_file = ca.get_cert("foo.bar.com", ["*.baz.com"]) assert "*.baz.com" in cert.altnames diff --git a/test/test_http_auth.py b/test/test_http_auth.py index 25df5410..045fb13e 100644 --- a/test/test_http_auth.py +++ b/test/test_http_auth.py @@ -1,7 +1,4 @@ -import binascii -import cStringIO from netlib import odict, http_auth, http -import mock import tutils @@ -22,7 +19,7 @@ class TestPassManHtpasswd: pm = http_auth.PassManHtpasswd(tutils.test_data.path("data/htpasswd")) vals = ("basic", "test", "test") - p = http.assemble_http_basic_auth(*vals) + http.assemble_http_basic_auth(*vals) assert pm.test("test", "test") assert not pm.test("test", "foo") assert not pm.test("foo", "test") diff --git a/test/test_http_cookies.py b/test/test_http_cookies.py index 7438af7c..070849cf 100644 --- a/test/test_http_cookies.py +++ b/test/test_http_cookies.py @@ -1,7 +1,6 @@ -import pprint import nose.tools -from netlib import http_cookies, odict +from netlib import http_cookies def test_read_token(): diff --git a/test/test_imports.py b/test/test_imports.py index 7b8a643b..b88ef26d 100644 --- a/test/test_imports.py +++ b/test/test_imports.py @@ -1,3 +1 @@ # These are actually tests! -import netlib.http_status -import netlib.version diff --git a/test/test_wsgi.py b/test/test_wsgi.py index 68a47769..41572d49 100644 --- a/test/test_wsgi.py +++ b/test/test_wsgi.py @@ -56,7 +56,7 @@ class TestWSGI: f.request.host = "foo" f.request.port = 80 wfile = cStringIO.StringIO() - err = w.serve(f, wfile) + w.serve(f, wfile) return wfile.getvalue() def test_serve_empty_body(self): @@ -72,7 +72,7 @@ class TestWSGI: try: raise ValueError("foo") except: - ei = sys.exc_info() + sys.exc_info() status = '200 OK' response_headers = [('Content-type', 'text/plain')] start_response(status, response_headers) -- cgit v1.2.3 From 80378306960379f12aca72309dc47437cd1a825c Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 27 May 2015 11:34:14 +0200 Subject: add pep8 autoformat checks to travis --- .gitignore | 3 ++- .travis.yml | 8 +++++--- setup.py | 2 ++ 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index ef830f75..70059e0f 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ MANIFEST .coverage .idea/ __pycache__ -netlib.egg-info/ \ No newline at end of file +netlib.egg-info/ +pathod/ diff --git a/.travis.yml b/.travis.yml index aac6b272..77ba0bcf 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,9 +7,11 @@ python: install: - "pip install --src . -r requirements.txt" # command to run tests, e.g. python setup.py test -script: +script: - "nosetests --with-cov --cov-report term-missing" -after_success: + - "autopep8 -i -r -a -a . && test -z \"$(git status -s)\"" + - "autoflake -r -i --remove-all-unused-imports --remove-unused-variables . && test -z \"$(git status -s)\"" +after_success: - coveralls notifications: irc: @@ -22,4 +24,4 @@ cache: - /home/travis/virtualenv/python2.7.9/lib/python2.7/site-packages - /home/travis/virtualenv/python2.7.9/bin - /home/travis/virtualenv/pypy-2.5.0/site-packages - - /home/travis/virtualenv/pypy-2.5.0/bin \ No newline at end of file + - /home/travis/virtualenv/pypy-2.5.0/bin diff --git a/setup.py b/setup.py index 86a55c4c..3680889b 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,8 @@ setup( "nose>=1.3.0", "nose-cov>=1.6", "coveralls>=0.4.1", + "autopep8>=1.0.3", + "autoflake>=0.6.6", "pathod>=%s, <%s" % (version.MINORVERSION, version.NEXT_MINORVERSION) ] } -- cgit v1.2.3 From 1dda164d0381161d3d0ad4e65199f6382aa2bf0d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 28 May 2015 12:18:56 +1200 Subject: Satisfy autobots. --- netlib/certutils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 05408a0c..abf1a28b 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -42,7 +42,7 @@ def create_ca(o, cn, exp): cert.set_pubkey(key) cert.add_extensions([ OpenSSL.crypto.X509Extension( - "basicConstraints", + "basicConstraints", True, "CA:TRUE" ), @@ -155,6 +155,7 @@ class CertStore(object): """ Implements an in-memory certificate store. """ + def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams): self.default_privatekey = default_privatekey self.default_ca = default_ca -- cgit v1.2.3 From e805f2d06609a297391e4486f9a8e5394bac5435 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 29 May 2015 11:22:31 +0200 Subject: improve travis coding style checks --- .gitignore | 1 + .travis.yml | 3 +-- check_coding_style.sh | 20 ++++++++++++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) create mode 100755 check_coding_style.sh diff --git a/.gitignore b/.gitignore index 70059e0f..68d71ab6 100644 --- a/.gitignore +++ b/.gitignore @@ -9,5 +9,6 @@ MANIFEST .coverage .idea/ __pycache__ +_cffi__* netlib.egg-info/ pathod/ diff --git a/.travis.yml b/.travis.yml index 77ba0bcf..a1eafcea 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,8 +9,7 @@ install: # command to run tests, e.g. python setup.py test script: - "nosetests --with-cov --cov-report term-missing" - - "autopep8 -i -r -a -a . && test -z \"$(git status -s)\"" - - "autoflake -r -i --remove-all-unused-imports --remove-unused-variables . && test -z \"$(git status -s)\"" + - "./check_coding_style.sh" after_success: - coveralls notifications: diff --git a/check_coding_style.sh b/check_coding_style.sh new file mode 100755 index 00000000..5b38e003 --- /dev/null +++ b/check_coding_style.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +autopep8 -i -r -a -a . +if [[ -n "$(git status -s)" ]]; then + echo "autopep8 yielded the following changes:" + git status -s + git --no-pager diff + exit 1 +fi + +autoflake -i -r --remove-all-unused-imports --remove-unused-variables . +if [[ -n "$(git status -s)" ]]; then + echo "autoflake yielded the following changes:" + git status -s + git --no-pager diff + exit 1 +fi + +echo "Coding style seems to be ok." +exit 0 -- cgit v1.2.3 From bdb62101bbbd4babc3099dd71424f85676866161 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 28 May 2015 17:45:54 +0200 Subject: test Address __str__ --- test/test_tcp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_tcp.py b/test/test_tcp.py index ef00e029..2bf492fa 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -603,6 +603,7 @@ class TestAddress: assert a.use_ipv6 b = tcp.Address("foo.com", True) assert not a == b + assert str(b) == str(tuple("foo.com")) c = tcp.Address("localhost", True) assert a == c assert not a != c -- cgit v1.2.3 From 5288aa36403bc4b350700a0bf97adc4413f2a398 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 27 May 2015 12:58:55 +0200 Subject: add human_readable() to each frame for debugging --- netlib/h2/frame.py | 75 +++++++++++++++++++++++++++++++++++++++++++++++++- test/h2/test_frames.py | 48 ++++++++++++++++++++++---------- 2 files changed, 108 insertions(+), 15 deletions(-) diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index a7e81f48..51de7d4d 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -25,6 +25,7 @@ class Frame(object): raise ValueError('invalid flags detected.') self.length = length + self.type = self.TYPE self.flags = flags self.stream_id = stream_id @@ -49,10 +50,27 @@ class Frame(object): return b + def payload_bytes(self): # pragma: no cover + raise NotImplementedError() + + def payload_human_readable(self): # pragma: no cover + raise NotImplementedError() + + def human_readable(self): + return "\n".join([ + "============================================================", + "length: %d bytes" % self.length, + "type: %s (%#x)" % (self.__class__.__name__, self.TYPE), + "flags: %#x" % self.flags, + "stream_id: %#x" % self.stream_id, + "------------------------------------------------------------", + self.payload_human_readable(), + "============================================================", + ]) + def __eq__(self, other): return self.to_bytes() == other.to_bytes() - class DataFrame(Frame): TYPE = 0x0 VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] @@ -89,6 +107,8 @@ class DataFrame(Frame): return b + def payload_human_readable(self): + return "payload: %s" % str(self.payload) class HeadersFrame(Frame): TYPE = 0x1 @@ -139,6 +159,19 @@ class HeadersFrame(Frame): return b + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PRIORITY: + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append("header_block_fragment: %s" % str(self.header_block_fragment)) + return "\n".join(s) class PriorityFrame(Frame): TYPE = 0x2 @@ -169,6 +202,12 @@ class PriorityFrame(Frame): return struct.pack('!LB', (int(self.exclusive) << 31) | self.stream_dependency, self.weight) + def payload_human_readable(self): + s = [] + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + return "\n".join(s) class RstStreamFrame(Frame): TYPE = 0x3 @@ -190,6 +229,8 @@ class RstStreamFrame(Frame): return struct.pack('!L', self.error_code) + def payload_human_readable(self): + return "error code: %#x" % self.error_code class SettingsFrame(Frame): TYPE = 0x4 @@ -228,6 +269,16 @@ class SettingsFrame(Frame): return b + def payload_human_readable(self): + s = [] + + for identifier, value in self.settings.items(): + s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) + + if not s: + return "settings: None" + else: + return "\n".join(s) class PushPromiseFrame(Frame): TYPE = 0x5 @@ -273,6 +324,15 @@ class PushPromiseFrame(Frame): return b + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append("promised stream: %#x" % self.promised_stream) + s.append("header_block_fragment: %s" % str(self.header_block_fragment)) + return "\n".join(s) class PingFrame(Frame): TYPE = 0x6 @@ -296,6 +356,8 @@ class PingFrame(Frame): b += b'\0' * (8 - len(b)) return b + def payload_human_readable(self): + return "opaque data: %s" % str(self.payload) class GoAwayFrame(Frame): TYPE = 0x7 @@ -325,6 +387,12 @@ class GoAwayFrame(Frame): b += bytes(self.data) return b + def payload_human_readable(self): + s = [] + s.append("last stream: %#x" % self.last_stream) + s.append("error code: %d" % self.error_code) + s.append("debug data: %s" % str(self.data)) + return "\n".join(s) class WindowUpdateFrame(Frame): TYPE = 0x8 @@ -349,6 +417,8 @@ class WindowUpdateFrame(Frame): return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) + def payload_human_readable(self): + return "window size increment: %#x" % self.window_size_increment class ContinuationFrame(Frame): TYPE = 0x9 @@ -370,6 +440,9 @@ class ContinuationFrame(Frame): return self.header_block_fragment + def payload_human_readable(self): + return "header_block_fragment: %s" % str(self.header_block_fragment) + _FRAME_CLASSES = [ DataFrame, HeadersFrame, diff --git a/test/h2/test_frames.py b/test/h2/test_frames.py index 90162984..eb6e2a60 100644 --- a/test/h2/test_frames.py +++ b/test/h2/test_frames.py @@ -3,23 +3,19 @@ import tutils from nose.tools import assert_equal - # TODO test stream association if valid or not def test_invalid_flags(): tutils.raises(ValueError, DataFrame, ContinuationFrame.FLAG_END_HEADERS, 0x1234567, 'foobar') - def test_frame_equality(): a = DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567, 'foobar') b = DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567, 'foobar') assert_equal(a, b) - def test_too_large_frames(): DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567) - def test_data_frame_to_bytes(): f = DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567, 'foobar') assert_equal(f.to_bytes().encode('hex'), '000006000101234567666f6f626172') @@ -30,7 +26,6 @@ def test_data_frame_to_bytes(): f = DataFrame(6, Frame.FLAG_NO_FLAGS, 0x0, 'foobar') tutils.raises(ValueError, f.to_bytes) - def test_data_frame_from_bytes(): f = Frame.from_bytes('000006000101234567666f6f626172'.decode('hex')) assert isinstance(f, DataFrame) @@ -48,6 +43,9 @@ def test_data_frame_from_bytes(): assert_equal(f.stream_id, 0x1234567) assert_equal(f.payload, 'foobar') +def test_data_frame_human_readable(): + f = DataFrame(11, Frame.FLAG_END_STREAM | Frame.FLAG_PADDED, 0x1234567, 'foobar', pad_length=3) + assert f.human_readable() def test_headers_frame_to_bytes(): f = HeadersFrame(6, Frame.FLAG_NO_FLAGS, 0x1234567, 'foobar') @@ -70,7 +68,6 @@ def test_headers_frame_to_bytes(): f = HeadersFrame(6, Frame.FLAG_NO_FLAGS, 0x0, 'foobar') tutils.raises(ValueError, f.to_bytes) - def test_headers_frame_from_bytes(): f = Frame.from_bytes('000006010001234567666f6f626172'.decode('hex')) assert isinstance(f, HeadersFrame) @@ -121,6 +118,9 @@ def test_headers_frame_from_bytes(): assert_equal(f.stream_dependency, 0x7654321) assert_equal(f.weight, 42) +def test_headers_frame_human_readable(): + f = HeadersFrame(14, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, 0x1234567, 'foobar', pad_length=3, exclusive=False, stream_dependency=0x7654321, weight=42) + assert f.human_readable() def test_priority_frame_to_bytes(): f = PriorityFrame(5, Frame.FLAG_NO_FLAGS, 0x1234567, exclusive=True, stream_dependency=0x7654321, weight=42) @@ -135,7 +135,6 @@ def test_priority_frame_to_bytes(): f = PriorityFrame(5, Frame.FLAG_NO_FLAGS, 0x1234567, stream_dependency=0x0) tutils.raises(ValueError, f.to_bytes) - def test_priority_frame_from_bytes(): f = Frame.from_bytes('000005020001234567876543212a'.decode('hex')) assert isinstance(f, PriorityFrame) @@ -157,6 +156,9 @@ def test_priority_frame_from_bytes(): assert_equal(f.stream_dependency, 0x7654321) assert_equal(f.weight, 21) +def test_priority_frame_human_readable(): + f = PriorityFrame(5, Frame.FLAG_NO_FLAGS, 0x1234567, exclusive=False, stream_dependency=0x7654321, weight=21) + assert f.human_readable() def test_rst_stream_frame_to_bytes(): f = RstStreamFrame(4, Frame.FLAG_NO_FLAGS, 0x1234567, error_code=0x7654321) @@ -165,7 +167,6 @@ def test_rst_stream_frame_to_bytes(): f = RstStreamFrame(4, Frame.FLAG_NO_FLAGS, 0x0) tutils.raises(ValueError, f.to_bytes) - def test_rst_stream_frame_from_bytes(): f = Frame.from_bytes('00000403000123456707654321'.decode('hex')) assert isinstance(f, RstStreamFrame) @@ -175,6 +176,9 @@ def test_rst_stream_frame_from_bytes(): assert_equal(f.stream_id, 0x1234567) assert_equal(f.error_code, 0x07654321) +def test_rst_stream_frame_human_readable(): + f = RstStreamFrame(4, Frame.FLAG_NO_FLAGS, 0x1234567, error_code=0x7654321) + assert f.human_readable() def test_settings_frame_to_bytes(): f = SettingsFrame(0, Frame.FLAG_NO_FLAGS, 0x0) @@ -193,7 +197,6 @@ def test_settings_frame_to_bytes(): f = SettingsFrame(0, Frame.FLAG_NO_FLAGS, 0x1234567) tutils.raises(ValueError, f.to_bytes) - def test_settings_frame_from_bytes(): f = Frame.from_bytes('000000040000000000'.decode('hex')) assert isinstance(f, SettingsFrame) @@ -228,6 +231,12 @@ def test_settings_frame_from_bytes(): assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS], 0x12345678) +def test_settings_frame_human_readable(): + f = SettingsFrame(12, Frame.FLAG_NO_FLAGS, 0x0, settings={}) + assert f.human_readable() + + f = SettingsFrame(12, Frame.FLAG_NO_FLAGS, 0x0, settings={SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) + assert f.human_readable() def test_push_promise_frame_to_bytes(): f = PushPromiseFrame(10, Frame.FLAG_NO_FLAGS, 0x1234567, 0x7654321, 'foobar') @@ -242,7 +251,6 @@ def test_push_promise_frame_to_bytes(): f = PushPromiseFrame(4, Frame.FLAG_NO_FLAGS, 0x1234567, 0x0) tutils.raises(ValueError, f.to_bytes) - def test_push_promise_frame_from_bytes(): f = Frame.from_bytes('00000a05000123456707654321666f6f626172'.decode('hex')) assert isinstance(f, PushPromiseFrame) @@ -260,6 +268,9 @@ def test_push_promise_frame_from_bytes(): assert_equal(f.stream_id, 0x1234567) assert_equal(f.header_block_fragment, 'foobar') +def test_push_promise_frame_human_readable(): + f = PushPromiseFrame(14, HeadersFrame.FLAG_PADDED, 0x1234567, 0x7654321, 'foobar', pad_length=3) + assert f.human_readable() def test_ping_frame_to_bytes(): f = PingFrame(8, PingFrame.FLAG_ACK, 0x0, payload=b'foobar') @@ -271,7 +282,6 @@ def test_ping_frame_to_bytes(): f = PingFrame(8, Frame.FLAG_NO_FLAGS, 0x1234567) tutils.raises(ValueError, f.to_bytes) - def test_ping_frame_from_bytes(): f = Frame.from_bytes('000008060100000000666f6f6261720000'.decode('hex')) assert isinstance(f, PingFrame) @@ -289,6 +299,9 @@ def test_ping_frame_from_bytes(): assert_equal(f.stream_id, 0x0) assert_equal(f.payload, b'foobarde') +def test_ping_frame_human_readable(): + f = PingFrame(8, PingFrame.FLAG_ACK, 0x0, payload=b'foobar') + assert f.human_readable() def test_goaway_frame_to_bytes(): f = GoAwayFrame(8, Frame.FLAG_NO_FLAGS, 0x0, last_stream=0x1234567, error_code=0x87654321, data=b'') @@ -300,7 +313,6 @@ def test_goaway_frame_to_bytes(): f = GoAwayFrame(8, Frame.FLAG_NO_FLAGS, 0x1234567, last_stream=0x1234567, error_code=0x87654321) tutils.raises(ValueError, f.to_bytes) - def test_goaway_frame_from_bytes(): f = Frame.from_bytes('0000080700000000000123456787654321'.decode('hex')) assert isinstance(f, GoAwayFrame) @@ -322,6 +334,9 @@ def test_goaway_frame_from_bytes(): assert_equal(f.error_code, 0x87654321) assert_equal(f.data, b'foobar') +def test_go_away_frame_human_readable(): + f = GoAwayFrame(14, Frame.FLAG_NO_FLAGS, 0x0, last_stream=0x1234567, error_code=0x87654321, data=b'foobar') + assert f.human_readable() def test_window_update_frame_to_bytes(): f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0x1234567) @@ -336,7 +351,6 @@ def test_window_update_frame_to_bytes(): f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0) tutils.raises(ValueError, f.to_bytes) - def test_window_update_frame_from_bytes(): f = Frame.from_bytes('00000408000000000001234567'.decode('hex')) assert isinstance(f, WindowUpdateFrame) @@ -346,6 +360,9 @@ def test_window_update_frame_from_bytes(): assert_equal(f.stream_id, 0x0) assert_equal(f.window_size_increment, 0x1234567) +def test_window_update_frame_human_readable(): + f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x1234567, window_size_increment=0x7654321) + assert f.human_readable() def test_continuation_frame_to_bytes(): f = ContinuationFrame(6, ContinuationFrame.FLAG_END_HEADERS, 0x1234567, 'foobar') @@ -354,7 +371,6 @@ def test_continuation_frame_to_bytes(): f = ContinuationFrame(6, ContinuationFrame.FLAG_END_HEADERS, 0x0, 'foobar') tutils.raises(ValueError, f.to_bytes) - def test_continuation_frame_from_bytes(): f = Frame.from_bytes('000006090401234567666f6f626172'.decode('hex')) assert isinstance(f, ContinuationFrame) @@ -363,3 +379,7 @@ def test_continuation_frame_from_bytes(): assert_equal(f.flags, ContinuationFrame.FLAG_END_HEADERS) assert_equal(f.stream_id, 0x1234567) assert_equal(f.header_block_fragment, 'foobar') + +def test_continuation_frame_human_readable(): + f = ContinuationFrame(6, ContinuationFrame.FLAG_END_HEADERS, 0x1234567, 'foobar') + assert f.human_readable() -- cgit v1.2.3 From 754f929187e3954eb05971e38bcd3358d3a5e3be Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 27 May 2015 17:31:18 +0200 Subject: fix default argument Python evaluates default args during method definition. So you get the same dict each time you call this method. Therefore the dict is the SAME actual object each time. --- netlib/h2/frame.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index 51de7d4d..ed6af200 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -245,8 +245,12 @@ class SettingsFrame(Frame): SETTINGS_MAX_HEADER_LIST_SIZE=0x6, ) - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, settings={}): + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, settings=None): super(SettingsFrame, self).__init__(length, flags, stream_id) + + if settings is None: + settings = {} + self.settings = settings @classmethod -- cgit v1.2.3 From 4c469fdee1b5b01a7e847a75fbbd902dc3bfbd70 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 27 May 2015 17:53:06 +0200 Subject: add hpack to encode and decode headers --- netlib/h2/frame.py | 41 +++++++++--- setup.py | 3 +- test/h2/test_frames.py | 169 ++++++++++++++++++++++++++++++++++++++----------- 3 files changed, 167 insertions(+), 46 deletions(-) diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index ed6af200..179634b0 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -1,4 +1,6 @@ import struct +import io +from hpack.hpack import Encoder, Decoder from .. import utils from functools import reduce @@ -71,6 +73,7 @@ class Frame(object): def __eq__(self, other): return self.to_bytes() == other.to_bytes() + class DataFrame(Frame): TYPE = 0x0 VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] @@ -110,14 +113,18 @@ class DataFrame(Frame): def payload_human_readable(self): return "payload: %s" % str(self.payload) + class HeadersFrame(Frame): TYPE = 0x1 VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED, Frame.FLAG_PRIORITY] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b'', - pad_length=0, exclusive=False, stream_dependency=0x0, weight=0): + def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, headers=None, pad_length=0, exclusive=False, stream_dependency=0x0, weight=0): super(HeadersFrame, self).__init__(length, flags, stream_id) - self.header_block_fragment = header_block_fragment + + if headers is None: + headers = [] + + self.headers = headers self.pad_length = pad_length self.exclusive = exclusive self.stream_dependency = stream_dependency @@ -129,15 +136,18 @@ class HeadersFrame(Frame): if f.flags & self.FLAG_PADDED: f.pad_length = struct.unpack('!B', payload[0])[0] - f.header_block_fragment = payload[1:-f.pad_length] + header_block_fragment = payload[1:-f.pad_length] else: - f.header_block_fragment = payload[0:] + header_block_fragment = payload[0:] if f.flags & self.FLAG_PRIORITY: - f.stream_dependency, f.weight = struct.unpack('!LB', f.header_block_fragment[:5]) + f.stream_dependency, f.weight = struct.unpack('!LB', header_block_fragment[:5]) f.exclusive = bool(f.stream_dependency >> 31) f.stream_dependency &= 0x7FFFFFFF - f.header_block_fragment = f.header_block_fragment[5:] + header_block_fragment = header_block_fragment[5:] + + for header, value in Decoder().decode(header_block_fragment): + f.headers.append((header, value)) return f @@ -152,7 +162,7 @@ class HeadersFrame(Frame): if self.flags & self.FLAG_PRIORITY: b += struct.pack('!LB', (int(self.exclusive) << 31) | self.stream_dependency, self.weight) - b += bytes(self.header_block_fragment) + b += Encoder().encode(self.headers) if self.flags & self.FLAG_PADDED: b += b'\0' * self.pad_length @@ -170,9 +180,15 @@ class HeadersFrame(Frame): if self.flags & self.FLAG_PADDED: s.append("padding: %d" % self.pad_length) - s.append("header_block_fragment: %s" % str(self.header_block_fragment)) + if not self.headers: + s.append("headers: None") + else: + for header, value in self.headers: + s.append("%s: %s" % (header, value)) + return "\n".join(s) + class PriorityFrame(Frame): TYPE = 0x2 VALID_FLAGS = [] @@ -209,6 +225,7 @@ class PriorityFrame(Frame): s.append("weight: %d" % self.weight) return "\n".join(s) + class RstStreamFrame(Frame): TYPE = 0x3 VALID_FLAGS = [] @@ -232,6 +249,7 @@ class RstStreamFrame(Frame): def payload_human_readable(self): return "error code: %#x" % self.error_code + class SettingsFrame(Frame): TYPE = 0x4 VALID_FLAGS = [Frame.FLAG_ACK] @@ -284,6 +302,7 @@ class SettingsFrame(Frame): else: return "\n".join(s) + class PushPromiseFrame(Frame): TYPE = 0x5 VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] @@ -338,6 +357,7 @@ class PushPromiseFrame(Frame): s.append("header_block_fragment: %s" % str(self.header_block_fragment)) return "\n".join(s) + class PingFrame(Frame): TYPE = 0x6 VALID_FLAGS = [Frame.FLAG_ACK] @@ -363,6 +383,7 @@ class PingFrame(Frame): def payload_human_readable(self): return "opaque data: %s" % str(self.payload) + class GoAwayFrame(Frame): TYPE = 0x7 VALID_FLAGS = [] @@ -398,6 +419,7 @@ class GoAwayFrame(Frame): s.append("debug data: %s" % str(self.data)) return "\n".join(s) + class WindowUpdateFrame(Frame): TYPE = 0x8 VALID_FLAGS = [] @@ -424,6 +446,7 @@ class WindowUpdateFrame(Frame): def payload_human_readable(self): return "window size increment: %#x" % self.window_size_increment + class ContinuationFrame(Frame): TYPE = 0x9 VALID_FLAGS = [Frame.FLAG_END_HEADERS] diff --git a/setup.py b/setup.py index 3680889b..450e9822 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,8 @@ setup( "pyasn1>=0.1.7", "pyOpenSSL>=0.15.1", "cryptography>=0.9", - "passlib>=1.6.2" + "passlib>=1.6.2", + "hpack>=1.0.1" ], extras_require={ 'dev': [ diff --git a/test/h2/test_frames.py b/test/h2/test_frames.py index eb6e2a60..eb470dd4 100644 --- a/test/h2/test_frames.py +++ b/test/h2/test_frames.py @@ -5,17 +5,21 @@ from nose.tools import assert_equal # TODO test stream association if valid or not + def test_invalid_flags(): tutils.raises(ValueError, DataFrame, ContinuationFrame.FLAG_END_HEADERS, 0x1234567, 'foobar') + def test_frame_equality(): a = DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567, 'foobar') b = DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567, 'foobar') assert_equal(a, b) + def test_too_large_frames(): DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567) + def test_data_frame_to_bytes(): f = DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567, 'foobar') assert_equal(f.to_bytes().encode('hex'), '000006000101234567666f6f626172') @@ -26,6 +30,7 @@ def test_data_frame_to_bytes(): f = DataFrame(6, Frame.FLAG_NO_FLAGS, 0x0, 'foobar') tutils.raises(ValueError, f.to_bytes) + def test_data_frame_from_bytes(): f = Frame.from_bytes('000006000101234567666f6f626172'.decode('hex')) assert isinstance(f, DataFrame) @@ -43,85 +48,139 @@ def test_data_frame_from_bytes(): assert_equal(f.stream_id, 0x1234567) assert_equal(f.payload, 'foobar') + def test_data_frame_human_readable(): f = DataFrame(11, Frame.FLAG_END_STREAM | Frame.FLAG_PADDED, 0x1234567, 'foobar', pad_length=3) assert f.human_readable() -def test_headers_frame_to_bytes(): - f = HeadersFrame(6, Frame.FLAG_NO_FLAGS, 0x1234567, 'foobar') - assert_equal(f.to_bytes().encode('hex'), '000006010001234567666f6f626172') - f = HeadersFrame(10, HeadersFrame.FLAG_PADDED, 0x1234567, 'foobar', pad_length=3) - assert_equal(f.to_bytes().encode('hex'), '00000a01080123456703666f6f626172000000') - - f = HeadersFrame(10, HeadersFrame.FLAG_PRIORITY, 0x1234567, 'foobar', exclusive=True, stream_dependency=0x7654321, weight=42) - assert_equal(f.to_bytes().encode('hex'), '00000b012001234567876543212a666f6f626172') - - f = HeadersFrame(14, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, 0x1234567, - 'foobar', pad_length=3, exclusive=True, stream_dependency=0x7654321, weight=42) - assert_equal(f.to_bytes().encode('hex'), '00000f01280123456703876543212a666f6f626172000000') - - f = HeadersFrame(14, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, 0x1234567, 'foobar', - pad_length=3, exclusive=False, stream_dependency=0x7654321, weight=42) - assert_equal(f.to_bytes().encode('hex'), '00000f01280123456703076543212a666f6f626172000000') +def test_headers_frame_to_bytes(): + f = HeadersFrame( + 6, + Frame.FLAG_NO_FLAGS, + 0x1234567, + headers=[('host', 'foo.bar')]) + assert_equal(f.to_bytes().encode('hex'), '000007010001234567668594e75e31d9') + + f = HeadersFrame( + 10, + HeadersFrame.FLAG_PADDED, + 0x1234567, + headers=[('host', 'foo.bar')], + pad_length=3) + assert_equal(f.to_bytes().encode('hex'), '00000b01080123456703668594e75e31d9000000') + + f = HeadersFrame( + 10, + HeadersFrame.FLAG_PRIORITY, + 0x1234567, + headers=[('host', 'foo.bar')], + exclusive=True, + stream_dependency=0x7654321, + weight=42) + assert_equal(f.to_bytes().encode('hex'), '00000c012001234567876543212a668594e75e31d9') + + f = HeadersFrame( + 14, + HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, + 0x1234567, + headers=[('host', 'foo.bar')], + pad_length=3, + exclusive=True, + stream_dependency=0x7654321, + weight=42) + assert_equal(f.to_bytes().encode('hex'), '00001001280123456703876543212a668594e75e31d9000000') + + f = HeadersFrame( + 14, + HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, + 0x1234567, + headers=[('host', 'foo.bar')], + pad_length=3, + exclusive=False, + stream_dependency=0x7654321, + weight=42) + assert_equal(f.to_bytes().encode('hex'), '00001001280123456703076543212a668594e75e31d9000000') f = HeadersFrame(6, Frame.FLAG_NO_FLAGS, 0x0, 'foobar') tutils.raises(ValueError, f.to_bytes) + def test_headers_frame_from_bytes(): - f = Frame.from_bytes('000006010001234567666f6f626172'.decode('hex')) + f = Frame.from_bytes('000007010001234567668594e75e31d9'.decode('hex')) assert isinstance(f, HeadersFrame) - assert_equal(f.length, 6) + assert_equal(f.length, 7) assert_equal(f.TYPE, HeadersFrame.TYPE) assert_equal(f.flags, Frame.FLAG_NO_FLAGS) assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, 'foobar') + assert_equal(f.headers, [('host', 'foo.bar')]) - f = Frame.from_bytes('00000a01080123456703666f6f626172000000'.decode('hex')) + f = Frame.from_bytes('00000b01080123456703668594e75e31d9000000'.decode('hex')) assert isinstance(f, HeadersFrame) - assert_equal(f.length, 10) + assert_equal(f.length, 11) assert_equal(f.TYPE, HeadersFrame.TYPE) assert_equal(f.flags, HeadersFrame.FLAG_PADDED) assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, 'foobar') + assert_equal(f.headers, [('host', 'foo.bar')]) - f = Frame.from_bytes('00000b012001234567876543212a666f6f626172'.decode('hex')) + f = Frame.from_bytes('00000c012001234567876543212a668594e75e31d9'.decode('hex')) assert isinstance(f, HeadersFrame) - assert_equal(f.length, 11) + assert_equal(f.length, 12) assert_equal(f.TYPE, HeadersFrame.TYPE) assert_equal(f.flags, HeadersFrame.FLAG_PRIORITY) assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, 'foobar') + assert_equal(f.headers, [('host', 'foo.bar')]) assert_equal(f.exclusive, True) assert_equal(f.stream_dependency, 0x7654321) assert_equal(f.weight, 42) - f = Frame.from_bytes('00000f01280123456703876543212a666f6f626172000000'.decode('hex')) + f = Frame.from_bytes('00001001280123456703876543212a668594e75e31d9000000'.decode('hex')) assert isinstance(f, HeadersFrame) - assert_equal(f.length, 15) + assert_equal(f.length, 16) assert_equal(f.TYPE, HeadersFrame.TYPE) assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, 'foobar') + assert_equal(f.headers, [('host', 'foo.bar')]) assert_equal(f.exclusive, True) assert_equal(f.stream_dependency, 0x7654321) assert_equal(f.weight, 42) - f = Frame.from_bytes('00000f01280123456703076543212a666f6f626172000000'.decode('hex')) + f = Frame.from_bytes('00001001280123456703076543212a668594e75e31d9000000'.decode('hex')) assert isinstance(f, HeadersFrame) - assert_equal(f.length, 15) + assert_equal(f.length, 16) assert_equal(f.TYPE, HeadersFrame.TYPE) assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, 'foobar') + assert_equal(f.headers, [('host', 'foo.bar')]) assert_equal(f.exclusive, False) assert_equal(f.stream_dependency, 0x7654321) assert_equal(f.weight, 42) + def test_headers_frame_human_readable(): - f = HeadersFrame(14, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, 0x1234567, 'foobar', pad_length=3, exclusive=False, stream_dependency=0x7654321, weight=42) + f = HeadersFrame( + 7, + HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, + 0x1234567, + headers=[], + pad_length=3, + exclusive=False, + stream_dependency=0x7654321, + weight=42) + assert f.human_readable() + + f = HeadersFrame( + 14, + HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, + 0x1234567, + headers=[('host', 'foo.bar')], + pad_length=3, + exclusive=False, + stream_dependency=0x7654321, + weight=42) assert f.human_readable() + def test_priority_frame_to_bytes(): f = PriorityFrame(5, Frame.FLAG_NO_FLAGS, 0x1234567, exclusive=True, stream_dependency=0x7654321, weight=42) assert_equal(f.to_bytes().encode('hex'), '000005020001234567876543212a') @@ -135,6 +194,7 @@ def test_priority_frame_to_bytes(): f = PriorityFrame(5, Frame.FLAG_NO_FLAGS, 0x1234567, stream_dependency=0x0) tutils.raises(ValueError, f.to_bytes) + def test_priority_frame_from_bytes(): f = Frame.from_bytes('000005020001234567876543212a'.decode('hex')) assert isinstance(f, PriorityFrame) @@ -156,10 +216,12 @@ def test_priority_frame_from_bytes(): assert_equal(f.stream_dependency, 0x7654321) assert_equal(f.weight, 21) + def test_priority_frame_human_readable(): f = PriorityFrame(5, Frame.FLAG_NO_FLAGS, 0x1234567, exclusive=False, stream_dependency=0x7654321, weight=21) assert f.human_readable() + def test_rst_stream_frame_to_bytes(): f = RstStreamFrame(4, Frame.FLAG_NO_FLAGS, 0x1234567, error_code=0x7654321) assert_equal(f.to_bytes().encode('hex'), '00000403000123456707654321') @@ -167,6 +229,7 @@ def test_rst_stream_frame_to_bytes(): f = RstStreamFrame(4, Frame.FLAG_NO_FLAGS, 0x0) tutils.raises(ValueError, f.to_bytes) + def test_rst_stream_frame_from_bytes(): f = Frame.from_bytes('00000403000123456707654321'.decode('hex')) assert isinstance(f, RstStreamFrame) @@ -176,10 +239,12 @@ def test_rst_stream_frame_from_bytes(): assert_equal(f.stream_id, 0x1234567) assert_equal(f.error_code, 0x07654321) + def test_rst_stream_frame_human_readable(): f = RstStreamFrame(4, Frame.FLAG_NO_FLAGS, 0x1234567, error_code=0x7654321) assert f.human_readable() + def test_settings_frame_to_bytes(): f = SettingsFrame(0, Frame.FLAG_NO_FLAGS, 0x0) assert_equal(f.to_bytes().encode('hex'), '000000040000000000') @@ -187,16 +252,26 @@ def test_settings_frame_to_bytes(): f = SettingsFrame(0, SettingsFrame.FLAG_ACK, 0x0) assert_equal(f.to_bytes().encode('hex'), '000000040100000000') - f = SettingsFrame(6, SettingsFrame.FLAG_ACK, 0x0, settings={SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1}) + f = SettingsFrame( + 6, + SettingsFrame.FLAG_ACK, 0x0, + settings={ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1}) assert_equal(f.to_bytes().encode('hex'), '000006040100000000000200000001') - f = SettingsFrame(12, Frame.FLAG_NO_FLAGS, 0x0, settings={ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) + f = SettingsFrame( + 12, + Frame.FLAG_NO_FLAGS, + 0x0, + settings={ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) assert_equal(f.to_bytes().encode('hex'), '00000c040000000000000200000001000312345678') f = SettingsFrame(0, Frame.FLAG_NO_FLAGS, 0x1234567) tutils.raises(ValueError, f.to_bytes) + def test_settings_frame_from_bytes(): f = Frame.from_bytes('000000040000000000'.decode('hex')) assert isinstance(f, SettingsFrame) @@ -231,13 +306,21 @@ def test_settings_frame_from_bytes(): assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS], 0x12345678) + def test_settings_frame_human_readable(): f = SettingsFrame(12, Frame.FLAG_NO_FLAGS, 0x0, settings={}) assert f.human_readable() - f = SettingsFrame(12, Frame.FLAG_NO_FLAGS, 0x0, settings={SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) + f = SettingsFrame( + 12, + Frame.FLAG_NO_FLAGS, + 0x0, + settings={ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) assert f.human_readable() + def test_push_promise_frame_to_bytes(): f = PushPromiseFrame(10, Frame.FLAG_NO_FLAGS, 0x1234567, 0x7654321, 'foobar') assert_equal(f.to_bytes().encode('hex'), '00000a05000123456707654321666f6f626172') @@ -251,6 +334,7 @@ def test_push_promise_frame_to_bytes(): f = PushPromiseFrame(4, Frame.FLAG_NO_FLAGS, 0x1234567, 0x0) tutils.raises(ValueError, f.to_bytes) + def test_push_promise_frame_from_bytes(): f = Frame.from_bytes('00000a05000123456707654321666f6f626172'.decode('hex')) assert isinstance(f, PushPromiseFrame) @@ -268,10 +352,12 @@ def test_push_promise_frame_from_bytes(): assert_equal(f.stream_id, 0x1234567) assert_equal(f.header_block_fragment, 'foobar') + def test_push_promise_frame_human_readable(): f = PushPromiseFrame(14, HeadersFrame.FLAG_PADDED, 0x1234567, 0x7654321, 'foobar', pad_length=3) assert f.human_readable() + def test_ping_frame_to_bytes(): f = PingFrame(8, PingFrame.FLAG_ACK, 0x0, payload=b'foobar') assert_equal(f.to_bytes().encode('hex'), '000008060100000000666f6f6261720000') @@ -282,6 +368,7 @@ def test_ping_frame_to_bytes(): f = PingFrame(8, Frame.FLAG_NO_FLAGS, 0x1234567) tutils.raises(ValueError, f.to_bytes) + def test_ping_frame_from_bytes(): f = Frame.from_bytes('000008060100000000666f6f6261720000'.decode('hex')) assert isinstance(f, PingFrame) @@ -299,10 +386,12 @@ def test_ping_frame_from_bytes(): assert_equal(f.stream_id, 0x0) assert_equal(f.payload, b'foobarde') + def test_ping_frame_human_readable(): f = PingFrame(8, PingFrame.FLAG_ACK, 0x0, payload=b'foobar') assert f.human_readable() + def test_goaway_frame_to_bytes(): f = GoAwayFrame(8, Frame.FLAG_NO_FLAGS, 0x0, last_stream=0x1234567, error_code=0x87654321, data=b'') assert_equal(f.to_bytes().encode('hex'), '0000080700000000000123456787654321') @@ -313,6 +402,7 @@ def test_goaway_frame_to_bytes(): f = GoAwayFrame(8, Frame.FLAG_NO_FLAGS, 0x1234567, last_stream=0x1234567, error_code=0x87654321) tutils.raises(ValueError, f.to_bytes) + def test_goaway_frame_from_bytes(): f = Frame.from_bytes('0000080700000000000123456787654321'.decode('hex')) assert isinstance(f, GoAwayFrame) @@ -334,10 +424,12 @@ def test_goaway_frame_from_bytes(): assert_equal(f.error_code, 0x87654321) assert_equal(f.data, b'foobar') + def test_go_away_frame_human_readable(): f = GoAwayFrame(14, Frame.FLAG_NO_FLAGS, 0x0, last_stream=0x1234567, error_code=0x87654321, data=b'foobar') assert f.human_readable() + def test_window_update_frame_to_bytes(): f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0x1234567) assert_equal(f.to_bytes().encode('hex'), '00000408000000000001234567') @@ -351,6 +443,7 @@ def test_window_update_frame_to_bytes(): f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0) tutils.raises(ValueError, f.to_bytes) + def test_window_update_frame_from_bytes(): f = Frame.from_bytes('00000408000000000001234567'.decode('hex')) assert isinstance(f, WindowUpdateFrame) @@ -360,10 +453,12 @@ def test_window_update_frame_from_bytes(): assert_equal(f.stream_id, 0x0) assert_equal(f.window_size_increment, 0x1234567) + def test_window_update_frame_human_readable(): f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x1234567, window_size_increment=0x7654321) assert f.human_readable() + def test_continuation_frame_to_bytes(): f = ContinuationFrame(6, ContinuationFrame.FLAG_END_HEADERS, 0x1234567, 'foobar') assert_equal(f.to_bytes().encode('hex'), '000006090401234567666f6f626172') @@ -371,6 +466,7 @@ def test_continuation_frame_to_bytes(): f = ContinuationFrame(6, ContinuationFrame.FLAG_END_HEADERS, 0x0, 'foobar') tutils.raises(ValueError, f.to_bytes) + def test_continuation_frame_from_bytes(): f = Frame.from_bytes('000006090401234567666f6f626172'.decode('hex')) assert isinstance(f, ContinuationFrame) @@ -380,6 +476,7 @@ def test_continuation_frame_from_bytes(): assert_equal(f.stream_id, 0x1234567) assert_equal(f.header_block_fragment, 'foobar') + def test_continuation_frame_human_readable(): f = ContinuationFrame(6, ContinuationFrame.FLAG_END_HEADERS, 0x1234567, 'foobar') assert f.human_readable() -- cgit v1.2.3 From d50b9be0d5dab1772f0edcbfa89542ef9425e7bf Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 27 May 2015 17:53:45 +0200 Subject: add generic frame parsing method --- netlib/h2/frame.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index 179634b0..11687316 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -31,6 +31,23 @@ class Frame(object): self.flags = flags self.stream_id = stream_id + @classmethod + def from_file(self, fp): + """ + read a HTTP/2 frame sent by a server or client + fp is a "file like" object that could be backed by a network + stream or a disk or an in memory stream reader + """ + raw_header = fp.safe_read(9) + + fields = struct.unpack("!HBBBL", raw_header) + length = (fields[0] << 8) + fields[1] + flags = fields[3] + stream_id = fields[4] + + payload = fp.safe_read(length) + return FRAMES[fields[2]].from_bytes(length, flags, stream_id, payload) + @classmethod def from_bytes(self, data): fields = struct.unpack("!HBBBL", data[:9]) -- cgit v1.2.3 From 780836b182cd982b978f16218299f2b77a8ed204 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 28 May 2015 17:46:44 +0200 Subject: add ALPN support to TCP abstraction --- netlib/tcp.py | 35 +++++++++++++++++++++++++++-------- netlib/test.py | 3 ++- test/test_tcp.py | 18 ++++++++++++++++++ 3 files changed, 47 insertions(+), 9 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 49f92e4a..fc2c144e 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -360,7 +360,9 @@ class _Connection(object): def _create_ssl_context(self, method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), - cipher_list=None + cipher_list=None, + alpn_protos=None, + alpn_select=None, ): """ :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD @@ -389,6 +391,17 @@ class _Connection(object): if log_ssl_key: context.set_info_callback(log_ssl_key) + # advertise application layer protocols + if alpn_protos is not None: + context.set_alpn_protos(alpn_protos) + + # select application layer protocol + if alpn_select is not None: + def alpn_select_f(conn, options): + return bytes(alpn_select) + + context.set_alpn_select_callback(alpn_select_f) + return context @@ -413,8 +426,8 @@ class TCPClient(_Connection): self.ssl_established = False self.sni = None - def create_ssl_context(self, cert=None, **sslctx_kwargs): - context = self._create_ssl_context(**sslctx_kwargs) + def create_ssl_context(self, cert=None, alpn_protos=None, **sslctx_kwargs): + context = self._create_ssl_context(alpn_protos=alpn_protos, **sslctx_kwargs) # Client Certs if cert: try: @@ -424,13 +437,13 @@ class TCPClient(_Connection): raise NetLibError("SSL client certificate error: %s" % str(v)) return context - def convert_to_ssl(self, sni=None, **sslctx_kwargs): + def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs): """ cert: Path to a file containing both client cert and private key. options: A bit field consisting of OpenSSL.SSL.OP_* values """ - context = self.create_ssl_context(**sslctx_kwargs) + context = self.create_ssl_context(alpn_protos=alpn_protos, **sslctx_kwargs) self.connection = SSL.Connection(context, self.connection) if sni: self.sni = sni @@ -465,6 +478,9 @@ class TCPClient(_Connection): def gettimeout(self): return self.connection.gettimeout() + def get_alpn_proto_negotiated(self): + return self.connection.get_alpn_proto_negotiated() + class BaseHandler(_Connection): @@ -492,6 +508,7 @@ class BaseHandler(_Connection): request_client_cert=None, chain_file=None, dhparams=None, + alpn_select=None, **sslctx_kwargs): """ cert: A certutils.SSLCert object. @@ -517,7 +534,8 @@ class BaseHandler(_Connection): we may be able to make the proper behaviour the default again, but until then we're conservative. """ - context = self._create_ssl_context(**sslctx_kwargs) + + context = self._create_ssl_context(alpn_select=alpn_select, **sslctx_kwargs) context.use_privatekey(key) context.use_certificate(cert.x509) @@ -542,12 +560,13 @@ class BaseHandler(_Connection): return context - def convert_to_ssl(self, cert, key, **sslctx_kwargs): + def convert_to_ssl(self, cert, key, alpn_select=None, **sslctx_kwargs): """ Convert connection to SSL. For a list of parameters, see BaseHandler._create_ssl_context(...) """ - context = self.create_ssl_context(cert, key, **sslctx_kwargs) + + context = self.create_ssl_context(cert, key, alpn_select=alpn_select, **sslctx_kwargs) self.connection = SSL.Connection(context, self.connection) self.connection.set_accept_state() try: diff --git a/netlib/test.py b/netlib/test.py index b6f94273..63b493a9 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -82,7 +82,8 @@ class TServer(tcp.TCPServer): request_client_cert=self.ssl["request_client_cert"], cipher_list=self.ssl.get("cipher_list", None), dhparams=self.ssl.get("dhparams", None), - chain_file=self.ssl.get("chain_file", None) + chain_file=self.ssl.get("chain_file", None), + alpn_select=self.ssl.get("alpn_select", None) ) h.handle() h.finish() diff --git a/test/test_tcp.py b/test/test_tcp.py index 2bf492fa..ab786d3f 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -389,6 +389,24 @@ class TestTimeOut(test.ServerTestBase): tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) +class TestALPN(test.ServerTestBase): + handler = HangHandler + ssl = dict( + cert=tutils.test_data.path("data/server.crt"), + key=tutils.test_data.path("data/server.key"), + request_client_cert=False, + v3_only=False, + alpn_select="h2" + ) + + def test_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl(alpn_protos=["h2"]) + print "ALPN: %s" % c.get_alpn_proto_negotiated() + assert c.get_alpn_proto_negotiated() == "h2" + + class TestSSLTimeOut(test.ServerTestBase): handler = HangHandler ssl = dict( -- cgit v1.2.3 From e2de49596d0e60e343c71c73e0847b17fb27ac3c Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 28 May 2015 17:46:30 +0200 Subject: add HTTP/2-capable client --- netlib/h2/h2.py | 65 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ test/h2/example.py | 20 +++++++++++++++++ 2 files changed, 85 insertions(+) create mode 100644 test/h2/example.py diff --git a/netlib/h2/h2.py b/netlib/h2/h2.py index 7a85226f..bfe5832b 100644 --- a/netlib/h2/h2.py +++ b/netlib/h2/h2.py @@ -1,3 +1,5 @@ +from .. import utils, odict, tcp +from frame import * # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' @@ -18,3 +20,66 @@ ERROR_CODES = utils.BiDi( INADEQUATE_SECURITY=0xc, HTTP_1_1_REQUIRED=0xd ) + + +class H2Client(tcp.TCPClient): + ALPN_PROTO_H2 = b'h2' + + DEFAULT_SETTINGS = { + SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ^ 16 - 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ^ 14, + SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, + } + + def __init__(self, address, source_address=None): + super(H2Client, self).__init__(address, source_address) + self.settings = self.DEFAULT_SETTINGS.copy() + + def connect(self, send_preface=True): + super(H2Client, self).connect() + self.convert_to_ssl(alpn_protos=[self.ALPN_PROTO_H2]) + + alp = self.get_alpn_proto_negotiated() + if alp != b'h2': + raise NotImplementedError("H2Client can not handle unknown protocol: %s" % alp) + print "-> Successfully negotiated 'h2' application layer protocol." + + if send_preface: + self.wfile.write(bytes(CLIENT_CONNECTION_PREFACE.decode('hex'))) + self.send_frame(SettingsFrame()) + + frame = Frame.from_file(self.rfile) + print frame.human_readable() + assert isinstance(frame, SettingsFrame) + self.apply_settings(frame.settings) + + print "-> Connection Preface completed." + + print "-> H2Client is ready..." + + def send_frame(self, frame): + self.wfile.write(frame.to_bytes()) + self.wfile.flush() + + def read_frame(self): + frame = Frame.from_file(self.rfile) + if isinstance(frame, SettingsFrame): + self.apply_settings(frame.settings) + + return frame + + def apply_settings(self, settings): + for setting, value in settings.items(): + old_value = self.settings[setting] + if not old_value: + old_value = '-' + + self.settings[setting] = value + print "-> Setting changed: %s to %d (was %s)" % + (SettingsFrame.SETTINGS.get_name(setting), value, str(old_value)) + + self.send_frame(SettingsFrame(flags=Frame.FLAG_ACK)) + print "-> New settings acknowledged." diff --git a/test/h2/example.py b/test/h2/example.py new file mode 100644 index 00000000..ca7c6c38 --- /dev/null +++ b/test/h2/example.py @@ -0,0 +1,20 @@ +from netlib import tcp +from netlib.h2.frame import * +from netlib.h2.h2 import * +from hpack.hpack import Encoder, Decoder + +c = H2Client(("127.0.0.1", 443)) +c.connect() + +c.send_frame(HeadersFrame( + flags=(Frame.FLAG_END_HEADERS | Frame.FLAG_END_STREAM), + stream_id=0x1, + headers=[ + (b':method', 'GET'), + (b':path', b'/index.html'), + (b':scheme', b'https'), + (b':authority', b'localhost'), + ])) + +while True: + print c.read_frame().human_readable() -- cgit v1.2.3 From c32d8189faa24cbe016bb3c859f64c816e0871fe Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 29 May 2015 16:59:50 +0200 Subject: cleanup imports --- netlib/h2/frame.py | 1 - test/h2/example.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index 11687316..d4294052 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -1,5 +1,4 @@ import struct -import io from hpack.hpack import Encoder, Decoder from .. import utils diff --git a/test/h2/example.py b/test/h2/example.py index ca7c6c38..fc4f2f10 100644 --- a/test/h2/example.py +++ b/test/h2/example.py @@ -1,7 +1,5 @@ -from netlib import tcp from netlib.h2.frame import * from netlib.h2.h2 import * -from hpack.hpack import Encoder, Decoder c = H2Client(("127.0.0.1", 443)) c.connect() -- cgit v1.2.3 From 629fa8e5528783501e402a7e33ac6199bb38ece6 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 29 May 2015 17:04:12 +0200 Subject: make tests aware of ALPN & OpenSSL 1.0.2 dependency --- test/test_tcp.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/test/test_tcp.py b/test/test_tcp.py index ab786d3f..62617707 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -4,11 +4,14 @@ import time import socket import random import os -from netlib import tcp, certutils, test, certffi import threading import mock -import tutils + from OpenSSL import SSL +import OpenSSL + +from netlib import tcp, certutils, test, certffi +import tutils class EchoHandler(tcp.BaseHandler): @@ -399,12 +402,14 @@ class TestALPN(test.ServerTestBase): alpn_select="h2" ) - def test_alpn(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl(alpn_protos=["h2"]) - print "ALPN: %s" % c.get_alpn_proto_negotiated() - assert c.get_alpn_proto_negotiated() == "h2" + if OpenSSL._util.lib.Cryptography_HAS_ALPN: + + def test_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl(alpn_protos=["h2"]) + print "ALPN: %s" % c.get_alpn_proto_negotiated() + assert c.get_alpn_proto_negotiated() == "h2" class TestSSLTimeOut(test.ServerTestBase): -- cgit v1.2.3 From f76bfabc5d4ce36c56b1d1fd571728ee06f37b78 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 30 May 2015 12:02:58 +1200 Subject: Adjust pep8 parameters, reformat --- netlib/certutils.py | 75 +++++++++++++---- netlib/h2/frame.py | 127 +++++++++++++++++++++++----- netlib/h2/h2.py | 8 +- netlib/http.py | 3 +- netlib/http_uastrings.py | 91 +++++++------------- netlib/tcp.py | 42 +++++++--- netlib/test.py | 4 +- netlib/wsgi.py | 3 +- setup.cfg | 6 +- setup.py | 12 +-- test/h2/test_frames.py | 213 ++++++++++++++++++++++++++++++++++++++--------- test/test_certutils.py | 18 +++- test/test_http_auth.py | 5 +- test/test_socks.py | 7 +- test/test_tcp.py | 17 +++- test/test_websockets.py | 3 +- 16 files changed, 455 insertions(+), 179 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index abf1a28b..ade61bb5 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -96,7 +96,8 @@ def dummy_cert(privkey, cacert, commonname, sans): cert.set_serial_number(int(time.time() * 10000)) if ss: cert.set_version(2) - cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", False, ss)]) + cert.add_extensions( + [OpenSSL.crypto.X509Extension("subjectAltName", False, ss)]) cert.set_pubkey(cacert.get_pubkey()) cert.sign(privkey, "sha256") return SSLCert(cert) @@ -156,7 +157,12 @@ class CertStore(object): Implements an in-memory certificate store. """ - def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams): + def __init__( + self, + default_privatekey, + default_ca, + default_chain_file, + dhparams): self.default_privatekey = default_privatekey self.default_ca = default_ca self.default_chain_file = default_chain_file @@ -176,8 +182,10 @@ class CertStore(object): if bio != OpenSSL.SSL._ffi.NULL: bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free) dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams( - bio, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL - ) + bio, + OpenSSL.SSL._ffi.NULL, + OpenSSL.SSL._ffi.NULL, + OpenSSL.SSL._ffi.NULL) dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free) return dh @@ -189,8 +197,12 @@ class CertStore(object): else: with open(ca_path, "rb") as f: raw = f.read() - ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) - key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) + ca = OpenSSL.crypto.load_certificate( + OpenSSL.crypto.FILETYPE_PEM, + raw) + key = OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, + raw) dh_path = os.path.join(path, basename + "-dhparam.pem") dh = cls.load_dhparam(dh_path) return cls(key, ca, ca_path, dh) @@ -206,16 +218,28 @@ class CertStore(object): key, ca = create_ca(o=o, cn=cn, exp=expiry) # Dump the CA plus private key with open(os.path.join(path, basename + "-ca.pem"), "wb") as f: - f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) + f.write( + OpenSSL.crypto.dump_privatekey( + OpenSSL.crypto.FILETYPE_PEM, + key)) + f.write( + OpenSSL.crypto.dump_certificate( + OpenSSL.crypto.FILETYPE_PEM, + ca)) # Dump the certificate in PEM format with open(os.path.join(path, basename + "-ca-cert.pem"), "wb") as f: - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) + f.write( + OpenSSL.crypto.dump_certificate( + OpenSSL.crypto.FILETYPE_PEM, + ca)) # Create a .cer file with the same contents for Android with open(os.path.join(path, basename + "-ca-cert.cer"), "wb") as f: - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) + f.write( + OpenSSL.crypto.dump_certificate( + OpenSSL.crypto.FILETYPE_PEM, + ca)) # Dump the certificate in PKCS12 format for Windows devices with open(os.path.join(path, basename + "-ca-cert.p12"), "wb") as f: @@ -232,9 +256,14 @@ class CertStore(object): def add_cert_file(self, spec, path): with open(path, "rb") as f: raw = f.read() - cert = SSLCert(OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)) + cert = SSLCert( + OpenSSL.crypto.load_certificate( + OpenSSL.crypto.FILETYPE_PEM, + raw)) try: - privatekey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) + privatekey = OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, + raw) except Exception: privatekey = self.default_privatekey self.add_cert( @@ -284,15 +313,22 @@ class CertStore(object): potential_keys.extend(self.asterisk_forms(s)) potential_keys.append((commonname, tuple(sans))) - name = next(itertools.ifilter(lambda key: key in self.certs, potential_keys), None) + name = next( + itertools.ifilter( + lambda key: key in self.certs, + potential_keys), + None) if name: entry = self.certs[name] else: entry = CertStoreEntry( - cert=dummy_cert(self.default_privatekey, self.default_ca, commonname, sans), + cert=dummy_cert( + self.default_privatekey, + self.default_ca, + commonname, + sans), privatekey=self.default_privatekey, - chain_file=self.default_chain_file - ) + chain_file=self.default_chain_file) self.certs[(commonname, tuple(sans))] = entry return entry.cert, entry.privatekey, entry.chain_file @@ -317,7 +353,8 @@ class _GeneralName(univ.Choice): class _GeneralNames(univ.SequenceOf): componentType = _GeneralName() - sizeSpec = univ.SequenceOf.sizeSpec + constraint.ValueSizeConstraint(1, 1024) + sizeSpec = univ.SequenceOf.sizeSpec + \ + constraint.ValueSizeConstraint(1, 1024) class SSLCert(object): @@ -345,7 +382,9 @@ class SSLCert(object): return klass.from_pem(pem) def to_pem(self): - return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, self.x509) + return OpenSSL.crypto.dump_certificate( + OpenSSL.crypto.FILETYPE_PEM, + self.x509) def digest(self, name): return self.x509.digest(name) diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index d4294052..36456c46 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -94,7 +94,13 @@ class DataFrame(Frame): TYPE = 0x0 VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, payload=b'', pad_length=0): + def __init__( + self, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b'', + pad_length=0): super(DataFrame, self).__init__(length, flags, stream_id) self.payload = payload self.pad_length = pad_length @@ -132,9 +138,22 @@ class DataFrame(Frame): class HeadersFrame(Frame): TYPE = 0x1 - VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED, Frame.FLAG_PRIORITY] - - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, headers=None, pad_length=0, exclusive=False, stream_dependency=0x0, weight=0): + VALID_FLAGS = [ + Frame.FLAG_END_STREAM, + Frame.FLAG_END_HEADERS, + Frame.FLAG_PADDED, + Frame.FLAG_PRIORITY] + + def __init__( + self, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + headers=None, + pad_length=0, + exclusive=False, + stream_dependency=0x0, + weight=0): super(HeadersFrame, self).__init__(length, flags, stream_id) if headers is None: @@ -157,7 +176,9 @@ class HeadersFrame(Frame): header_block_fragment = payload[0:] if f.flags & self.FLAG_PRIORITY: - f.stream_dependency, f.weight = struct.unpack('!LB', header_block_fragment[:5]) + f.stream_dependency, f.weight = struct.unpack( + '!LB', header_block_fragment[ + :5]) f.exclusive = bool(f.stream_dependency >> 31) f.stream_dependency &= 0x7FFFFFFF header_block_fragment = header_block_fragment[5:] @@ -176,7 +197,9 @@ class HeadersFrame(Frame): b += struct.pack('!B', self.pad_length) if self.flags & self.FLAG_PRIORITY: - b += struct.pack('!LB', (int(self.exclusive) << 31) | self.stream_dependency, self.weight) + b += struct.pack('!LB', + (int(self.exclusive) << 31) | self.stream_dependency, + self.weight) b += Encoder().encode(self.headers) @@ -209,7 +232,14 @@ class PriorityFrame(Frame): TYPE = 0x2 VALID_FLAGS = [] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, exclusive=False, stream_dependency=0x0, weight=0): + def __init__( + self, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + exclusive=False, + stream_dependency=0x0, + weight=0): super(PriorityFrame, self).__init__(length, flags, stream_id) self.exclusive = exclusive self.stream_dependency = stream_dependency @@ -227,12 +257,17 @@ class PriorityFrame(Frame): def payload_bytes(self): if self.stream_id == 0x0: - raise ValueError('PRIORITY frames MUST be associated with a stream.') + raise ValueError( + 'PRIORITY frames MUST be associated with a stream.') if self.stream_dependency == 0x0: raise ValueError('stream dependency is invalid.') - return struct.pack('!LB', (int(self.exclusive) << 31) | self.stream_dependency, self.weight) + return struct.pack( + '!LB', + (int( + self.exclusive) << 31) | self.stream_dependency, + self.weight) def payload_human_readable(self): s = [] @@ -246,7 +281,12 @@ class RstStreamFrame(Frame): TYPE = 0x3 VALID_FLAGS = [] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, error_code=0x0): + def __init__( + self, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + error_code=0x0): super(RstStreamFrame, self).__init__(length, flags, stream_id) self.error_code = error_code @@ -258,7 +298,8 @@ class RstStreamFrame(Frame): def payload_bytes(self): if self.stream_id == 0x0: - raise ValueError('RST_STREAM frames MUST be associated with a stream.') + raise ValueError( + 'RST_STREAM frames MUST be associated with a stream.') return struct.pack('!L', self.error_code) @@ -279,7 +320,12 @@ class SettingsFrame(Frame): SETTINGS_MAX_HEADER_LIST_SIZE=0x6, ) - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, settings=None): + def __init__( + self, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings=None): super(SettingsFrame, self).__init__(length, flags, stream_id) if settings is None: @@ -299,7 +345,8 @@ class SettingsFrame(Frame): def payload_bytes(self): if self.stream_id != 0x0: - raise ValueError('SETTINGS frames MUST NOT be associated with a stream.') + raise ValueError( + 'SETTINGS frames MUST NOT be associated with a stream.') b = b'' for identifier, value in self.settings.items(): @@ -323,7 +370,14 @@ class PushPromiseFrame(Frame): TYPE = 0x5 VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, promised_stream=0x0, header_block_fragment=b'', pad_length=0): + def __init__( + self, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + promised_stream=0x0, + header_block_fragment=b'', + pad_length=0): super(PushPromiseFrame, self).__init__(length, flags, stream_id) self.pad_length = pad_length self.promised_stream = promised_stream @@ -346,7 +400,8 @@ class PushPromiseFrame(Frame): def payload_bytes(self): if self.stream_id == 0x0: - raise ValueError('PUSH_PROMISE frames MUST be associated with a stream.') + raise ValueError( + 'PUSH_PROMISE frames MUST be associated with a stream.') if self.promised_stream == 0x0: raise ValueError('Promised stream id not valid.') @@ -378,7 +433,12 @@ class PingFrame(Frame): TYPE = 0x6 VALID_FLAGS = [Frame.FLAG_ACK] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, payload=b''): + def __init__( + self, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b''): super(PingFrame, self).__init__(length, flags, stream_id) self.payload = payload @@ -390,7 +450,8 @@ class PingFrame(Frame): def payload_bytes(self): if self.stream_id != 0x0: - raise ValueError('PING frames MUST NOT be associated with a stream.') + raise ValueError( + 'PING frames MUST NOT be associated with a stream.') b = self.payload[0:8] b += b'\0' * (8 - len(b)) @@ -404,7 +465,14 @@ class GoAwayFrame(Frame): TYPE = 0x7 VALID_FLAGS = [] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, last_stream=0x0, error_code=0x0, data=b''): + def __init__( + self, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x0, + error_code=0x0, + data=b''): super(GoAwayFrame, self).__init__(length, flags, stream_id) self.last_stream = last_stream self.error_code = error_code @@ -422,7 +490,8 @@ class GoAwayFrame(Frame): def payload_bytes(self): if self.stream_id != 0x0: - raise ValueError('GOAWAY frames MUST NOT be associated with a stream.') + raise ValueError( + 'GOAWAY frames MUST NOT be associated with a stream.') b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) b += bytes(self.data) @@ -440,7 +509,12 @@ class WindowUpdateFrame(Frame): TYPE = 0x8 VALID_FLAGS = [] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, window_size_increment=0x0): + def __init__( + self, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + window_size_increment=0x0): super(WindowUpdateFrame, self).__init__(length, flags, stream_id) self.window_size_increment = window_size_increment @@ -455,7 +529,8 @@ class WindowUpdateFrame(Frame): def payload_bytes(self): if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: - raise ValueError('Window Szie Increment MUST be greater than 0 and less than 2^31.') + raise ValueError( + 'Window Szie Increment MUST be greater than 0 and less than 2^31.') return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) @@ -467,7 +542,12 @@ class ContinuationFrame(Frame): TYPE = 0x9 VALID_FLAGS = [Frame.FLAG_END_HEADERS] - def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b''): + def __init__( + self, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b''): super(ContinuationFrame, self).__init__(length, flags, stream_id) self.header_block_fragment = header_block_fragment @@ -479,7 +559,8 @@ class ContinuationFrame(Frame): def payload_bytes(self): if self.stream_id == 0x0: - raise ValueError('CONTINUATION frames MUST be associated with a stream.') + raise ValueError( + 'CONTINUATION frames MUST be associated with a stream.') return self.header_block_fragment diff --git a/netlib/h2/h2.py b/netlib/h2/h2.py index bfe5832b..707b1465 100644 --- a/netlib/h2/h2.py +++ b/netlib/h2/h2.py @@ -44,7 +44,9 @@ class H2Client(tcp.TCPClient): alp = self.get_alpn_proto_negotiated() if alp != b'h2': - raise NotImplementedError("H2Client can not handle unknown protocol: %s" % alp) + raise NotImplementedError( + "H2Client can not handle unknown protocol: %s" % + alp) print "-> Successfully negotiated 'h2' application layer protocol." if send_preface: @@ -79,7 +81,9 @@ class H2Client(tcp.TCPClient): self.settings[setting] = value print "-> Setting changed: %s to %d (was %s)" % - (SettingsFrame.SETTINGS.get_name(setting), value, str(old_value)) + (SettingsFrame.SETTINGS.get_name(setting), + value, + str(old_value)) self.send_frame(SettingsFrame(flags=Frame.FLAG_ACK)) print "-> New settings acknowledged." diff --git a/netlib/http.py b/netlib/http.py index 47658097..a2af9e49 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -124,7 +124,8 @@ def read_chunked(fp, limit, is_request): May raise HttpError. """ # FIXME: Should check if chunked is the final encoding in the headers - # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 3.3 2. + # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 + # 3.3 2. total = 0 code = 400 if is_request else 502 while True: diff --git a/netlib/http_uastrings.py b/netlib/http_uastrings.py index d0d145da..d9869531 100644 --- a/netlib/http_uastrings.py +++ b/netlib/http_uastrings.py @@ -8,66 +8,37 @@ from __future__ import (absolute_import, print_function, division) # A collection of (name, shortcut, string) tuples. UASTRINGS = [ - ( - "android", - "a", - "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02" - ), - - ( - "blackberry", - "l", - "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+" - ), - - ( - "bingbot", - "b", - "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)" - ), - - ( - "chrome", - "c", - "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1" - ), - - ( - "firefox", - "f", - "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1" - ), - - ( - "googlebot", - "g", - "Googlebot/2.1 (+http://www.googlebot.com/bot.html)" - ), - - ( - "ie9", - "i", - "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US))" - ), - - ( - "ipad", - "p", - "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko ) Version/5.1 Mobile/9B176 Safari/7534.48.3" - ), - - ( - "iphone", - "h", - "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5", - ), - - ( - "safari", - "s", - "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10" - ) -] + ("android", + "a", + "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), + ("blackberry", + "l", + "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), + ("bingbot", + "b", + "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), + ("chrome", + "c", + "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), + ("firefox", + "f", + "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), + ("googlebot", + "g", + "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), + ("ie9", + "i", + "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US))"), + ("ipad", + "p", + "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko ) Version/5.1 Mobile/9B176 Safari/7534.48.3"), + ("iphone", + "h", + "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5", + ), + ("safari", + "s", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10")] def get_by_shortcut(s): diff --git a/netlib/tcp.py b/netlib/tcp.py index fc2c144e..a705c95b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -48,7 +48,8 @@ class SSLKeyLogger(object): self.f = None self.lock = threading.Lock() - __name__ = "SSLKeyLogger" # required for functools.wraps, which pyOpenSSL uses. + # required for functools.wraps, which pyOpenSSL uses. + __name__ = "SSLKeyLogger" def __call__(self, connection, where, ret): if where == SSL.SSL_CB_HANDSHAKE_DONE and ret == 1: @@ -61,7 +62,10 @@ class SSLKeyLogger(object): self.f.write("\r\n") client_random = connection.client_random().encode("hex") masterkey = connection.master_key().encode("hex") - self.f.write("CLIENT_RANDOM {} {}\r\n".format(client_random, masterkey)) + self.f.write( + "CLIENT_RANDOM {} {}\r\n".format( + client_random, + masterkey)) self.f.flush() def close(self): @@ -75,7 +79,8 @@ class SSLKeyLogger(object): return SSLKeyLogger(filename) return False -log_ssl_key = SSLKeyLogger.create_logfun(os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE")) +log_ssl_key = SSLKeyLogger.create_logfun( + os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE")) class _FileLike(object): @@ -378,7 +383,8 @@ class _Connection(object): # Workaround for # https://github.com/pyca/pyopenssl/issues/190 # https://github.com/mitmproxy/mitmproxy/issues/472 - context.set_mode(SSL._lib.SSL_MODE_AUTO_RETRY) # Options already set before are not cleared. + # Options already set before are not cleared. + context.set_mode(SSL._lib.SSL_MODE_AUTO_RETRY) # Cipher List if cipher_list: @@ -420,14 +426,17 @@ class TCPClient(_Connection): def __init__(self, address, source_address=None): self.address = Address.wrap(address) - self.source_address = Address.wrap(source_address) if source_address else None + self.source_address = Address.wrap( + source_address) if source_address else None self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.ssl_established = False self.sni = None def create_ssl_context(self, cert=None, alpn_protos=None, **sslctx_kwargs): - context = self._create_ssl_context(alpn_protos=alpn_protos, **sslctx_kwargs) + context = self._create_ssl_context( + alpn_protos=alpn_protos, + **sslctx_kwargs) # Client Certs if cert: try: @@ -443,7 +452,9 @@ class TCPClient(_Connection): options: A bit field consisting of OpenSSL.SSL.OP_* values """ - context = self.create_ssl_context(alpn_protos=alpn_protos, **sslctx_kwargs) + context = self.create_ssl_context( + alpn_protos=alpn_protos, + **sslctx_kwargs) self.connection = SSL.Connection(context, self.connection) if sni: self.sni = sni @@ -469,7 +480,9 @@ class TCPClient(_Connection): self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError) as err: - raise NetLibError('Error connecting to "%s": %s' % (self.address.host, err)) + raise NetLibError( + 'Error connecting to "%s": %s' % + (self.address.host, err)) self.connection = connection def settimeout(self, n): @@ -535,7 +548,9 @@ class BaseHandler(_Connection): until then we're conservative. """ - context = self._create_ssl_context(alpn_select=alpn_select, **sslctx_kwargs) + context = self._create_ssl_context( + alpn_select=alpn_select, + **sslctx_kwargs) context.use_privatekey(key) context.use_certificate(cert.x509) @@ -566,7 +581,11 @@ class BaseHandler(_Connection): For a list of parameters, see BaseHandler._create_ssl_context(...) """ - context = self.create_ssl_context(cert, key, alpn_select=alpn_select, **sslctx_kwargs) + context = self.create_ssl_context( + cert, + key, + alpn_select=alpn_select, + **sslctx_kwargs) self.connection = SSL.Connection(context, self.connection) self.connection.set_accept_state() try: @@ -611,7 +630,8 @@ class TCPServer(object): try: while not self.__shutdown_request: try: - r, w, e = select.select([self.socket], [], [], poll_interval) + r, w, e = select.select( + [self.socket], [], [], poll_interval) except select.error as ex: # pragma: no cover if ex[0] == EINTR: continue diff --git a/netlib/test.py b/netlib/test.py index 63b493a9..14f50157 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -67,7 +67,9 @@ class TServer(tcp.TCPServer): file(self.ssl["cert"], "rb").read() ) raw = file(self.ssl["key"], "rb").read() - key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) + key = OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, + raw) if self.ssl["v3_only"]: method = tcp.SSLv3_METHOD options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 diff --git a/netlib/wsgi.py b/netlib/wsgi.py index f393039a..827cf6f0 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -77,7 +77,8 @@ class WSGIAdaptor(object): } environ.update(extra) if flow.client_conn.address: - environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = flow.client_conn.address() + environ["REMOTE_ADDR"], environ[ + "REMOTE_PORT"] = flow.client_conn.address() for key, value in flow.request.headers.items(): key = 'HTTP_' + key.upper().replace('-', '_') diff --git a/setup.cfg b/setup.cfg index 1ba84a24..bc980d56 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,9 @@ [flake8] -max-line-length = 160 +max-line-length = 80 max-complexity = 15 [pep8] -max-line-length = 160 +max-line-length = 80 max-complexity = 15 +exclude = */contrib/* +ignore = E251,E309 diff --git a/setup.py b/setup.py index 450e9822..b5674d85 100644 --- a/setup.py +++ b/setup.py @@ -34,17 +34,14 @@ setup( "Topic :: Software Development :: Testing", "Topic :: Software Development :: Testing :: Traffic Generation", ], - packages=find_packages(), include_package_data=True, - install_requires=[ "pyasn1>=0.1.7", "pyOpenSSL>=0.15.1", "cryptography>=0.9", "passlib>=1.6.2", - "hpack>=1.0.1" - ], + "hpack>=1.0.1"], extras_require={ 'dev': [ "mock>=1.0.1", @@ -53,7 +50,6 @@ setup( "coveralls>=0.4.1", "autopep8>=1.0.3", "autoflake>=0.6.6", - "pathod>=%s, <%s" % (version.MINORVERSION, version.NEXT_MINORVERSION) - ] - } -) + "pathod>=%s, <%s" % + (version.MINORVERSION, + version.NEXT_MINORVERSION)]}) diff --git a/test/h2/test_frames.py b/test/h2/test_frames.py index eb470dd4..313ef405 100644 --- a/test/h2/test_frames.py +++ b/test/h2/test_frames.py @@ -7,7 +7,12 @@ from nose.tools import assert_equal def test_invalid_flags(): - tutils.raises(ValueError, DataFrame, ContinuationFrame.FLAG_END_HEADERS, 0x1234567, 'foobar') + tutils.raises( + ValueError, + DataFrame, + ContinuationFrame.FLAG_END_HEADERS, + 0x1234567, + 'foobar') def test_frame_equality(): @@ -24,8 +29,15 @@ def test_data_frame_to_bytes(): f = DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567, 'foobar') assert_equal(f.to_bytes().encode('hex'), '000006000101234567666f6f626172') - f = DataFrame(11, Frame.FLAG_END_STREAM | Frame.FLAG_PADDED, 0x1234567, 'foobar', pad_length=3) - assert_equal(f.to_bytes().encode('hex'), '00000a00090123456703666f6f626172000000') + f = DataFrame( + 11, + Frame.FLAG_END_STREAM | Frame.FLAG_PADDED, + 0x1234567, + 'foobar', + pad_length=3) + assert_equal( + f.to_bytes().encode('hex'), + '00000a00090123456703666f6f626172000000') f = DataFrame(6, Frame.FLAG_NO_FLAGS, 0x0, 'foobar') tutils.raises(ValueError, f.to_bytes) @@ -50,7 +62,12 @@ def test_data_frame_from_bytes(): def test_data_frame_human_readable(): - f = DataFrame(11, Frame.FLAG_END_STREAM | Frame.FLAG_PADDED, 0x1234567, 'foobar', pad_length=3) + f = DataFrame( + 11, + Frame.FLAG_END_STREAM | Frame.FLAG_PADDED, + 0x1234567, + 'foobar', + pad_length=3) assert f.human_readable() @@ -68,7 +85,9 @@ def test_headers_frame_to_bytes(): 0x1234567, headers=[('host', 'foo.bar')], pad_length=3) - assert_equal(f.to_bytes().encode('hex'), '00000b01080123456703668594e75e31d9000000') + assert_equal( + f.to_bytes().encode('hex'), + '00000b01080123456703668594e75e31d9000000') f = HeadersFrame( 10, @@ -78,7 +97,9 @@ def test_headers_frame_to_bytes(): exclusive=True, stream_dependency=0x7654321, weight=42) - assert_equal(f.to_bytes().encode('hex'), '00000c012001234567876543212a668594e75e31d9') + assert_equal( + f.to_bytes().encode('hex'), + '00000c012001234567876543212a668594e75e31d9') f = HeadersFrame( 14, @@ -89,7 +110,9 @@ def test_headers_frame_to_bytes(): exclusive=True, stream_dependency=0x7654321, weight=42) - assert_equal(f.to_bytes().encode('hex'), '00001001280123456703876543212a668594e75e31d9000000') + assert_equal( + f.to_bytes().encode('hex'), + '00001001280123456703876543212a668594e75e31d9000000') f = HeadersFrame( 14, @@ -100,7 +123,9 @@ def test_headers_frame_to_bytes(): exclusive=False, stream_dependency=0x7654321, weight=42) - assert_equal(f.to_bytes().encode('hex'), '00001001280123456703076543212a668594e75e31d9000000') + assert_equal( + f.to_bytes().encode('hex'), + '00001001280123456703076543212a668594e75e31d9000000') f = HeadersFrame(6, Frame.FLAG_NO_FLAGS, 0x0, 'foobar') tutils.raises(ValueError, f.to_bytes) @@ -115,7 +140,8 @@ def test_headers_frame_from_bytes(): assert_equal(f.stream_id, 0x1234567) assert_equal(f.headers, [('host', 'foo.bar')]) - f = Frame.from_bytes('00000b01080123456703668594e75e31d9000000'.decode('hex')) + f = Frame.from_bytes( + '00000b01080123456703668594e75e31d9000000'.decode('hex')) assert isinstance(f, HeadersFrame) assert_equal(f.length, 11) assert_equal(f.TYPE, HeadersFrame.TYPE) @@ -123,7 +149,8 @@ def test_headers_frame_from_bytes(): assert_equal(f.stream_id, 0x1234567) assert_equal(f.headers, [('host', 'foo.bar')]) - f = Frame.from_bytes('00000c012001234567876543212a668594e75e31d9'.decode('hex')) + f = Frame.from_bytes( + '00000c012001234567876543212a668594e75e31d9'.decode('hex')) assert isinstance(f, HeadersFrame) assert_equal(f.length, 12) assert_equal(f.TYPE, HeadersFrame.TYPE) @@ -134,7 +161,8 @@ def test_headers_frame_from_bytes(): assert_equal(f.stream_dependency, 0x7654321) assert_equal(f.weight, 42) - f = Frame.from_bytes('00001001280123456703876543212a668594e75e31d9000000'.decode('hex')) + f = Frame.from_bytes( + '00001001280123456703876543212a668594e75e31d9000000'.decode('hex')) assert isinstance(f, HeadersFrame) assert_equal(f.length, 16) assert_equal(f.TYPE, HeadersFrame.TYPE) @@ -145,7 +173,8 @@ def test_headers_frame_from_bytes(): assert_equal(f.stream_dependency, 0x7654321) assert_equal(f.weight, 42) - f = Frame.from_bytes('00001001280123456703076543212a668594e75e31d9000000'.decode('hex')) + f = Frame.from_bytes( + '00001001280123456703076543212a668594e75e31d9000000'.decode('hex')) assert isinstance(f, HeadersFrame) assert_equal(f.length, 16) assert_equal(f.TYPE, HeadersFrame.TYPE) @@ -182,10 +211,22 @@ def test_headers_frame_human_readable(): def test_priority_frame_to_bytes(): - f = PriorityFrame(5, Frame.FLAG_NO_FLAGS, 0x1234567, exclusive=True, stream_dependency=0x7654321, weight=42) + f = PriorityFrame( + 5, + Frame.FLAG_NO_FLAGS, + 0x1234567, + exclusive=True, + stream_dependency=0x7654321, + weight=42) assert_equal(f.to_bytes().encode('hex'), '000005020001234567876543212a') - f = PriorityFrame(5, Frame.FLAG_NO_FLAGS, 0x1234567, exclusive=False, stream_dependency=0x7654321, weight=21) + f = PriorityFrame( + 5, + Frame.FLAG_NO_FLAGS, + 0x1234567, + exclusive=False, + stream_dependency=0x7654321, + weight=21) assert_equal(f.to_bytes().encode('hex'), '0000050200012345670765432115') f = PriorityFrame(5, Frame.FLAG_NO_FLAGS, 0x0, stream_dependency=0x1234567) @@ -218,7 +259,13 @@ def test_priority_frame_from_bytes(): def test_priority_frame_human_readable(): - f = PriorityFrame(5, Frame.FLAG_NO_FLAGS, 0x1234567, exclusive=False, stream_dependency=0x7654321, weight=21) + f = PriorityFrame( + 5, + Frame.FLAG_NO_FLAGS, + 0x1234567, + exclusive=False, + stream_dependency=0x7654321, + weight=21) assert f.human_readable() @@ -266,7 +313,9 @@ def test_settings_frame_to_bytes(): settings={ SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) - assert_equal(f.to_bytes().encode('hex'), '00000c040000000000000200000001000312345678') + assert_equal( + f.to_bytes().encode('hex'), + '00000c040000000000000200000001000312345678') f = SettingsFrame(0, Frame.FLAG_NO_FLAGS, 0x1234567) tutils.raises(ValueError, f.to_bytes) @@ -296,7 +345,8 @@ def test_settings_frame_from_bytes(): assert_equal(len(f.settings), 1) assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) - f = Frame.from_bytes('00000c040000000000000200000001000312345678'.decode('hex')) + f = Frame.from_bytes( + '00000c040000000000000200000001000312345678'.decode('hex')) assert isinstance(f, SettingsFrame) assert_equal(f.length, 12) assert_equal(f.TYPE, SettingsFrame.TYPE) @@ -304,7 +354,10 @@ def test_settings_frame_from_bytes(): assert_equal(f.stream_id, 0x0) assert_equal(len(f.settings), 2) assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) - assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS], 0x12345678) + assert_equal( + f.settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS], + 0x12345678) def test_settings_frame_human_readable(): @@ -322,11 +375,26 @@ def test_settings_frame_human_readable(): def test_push_promise_frame_to_bytes(): - f = PushPromiseFrame(10, Frame.FLAG_NO_FLAGS, 0x1234567, 0x7654321, 'foobar') - assert_equal(f.to_bytes().encode('hex'), '00000a05000123456707654321666f6f626172') + f = PushPromiseFrame( + 10, + Frame.FLAG_NO_FLAGS, + 0x1234567, + 0x7654321, + 'foobar') + assert_equal( + f.to_bytes().encode('hex'), + '00000a05000123456707654321666f6f626172') - f = PushPromiseFrame(14, HeadersFrame.FLAG_PADDED, 0x1234567, 0x7654321, 'foobar', pad_length=3) - assert_equal(f.to_bytes().encode('hex'), '00000e0508012345670307654321666f6f626172000000') + f = PushPromiseFrame( + 14, + HeadersFrame.FLAG_PADDED, + 0x1234567, + 0x7654321, + 'foobar', + pad_length=3) + assert_equal( + f.to_bytes().encode('hex'), + '00000e0508012345670307654321666f6f626172000000') f = PushPromiseFrame(4, Frame.FLAG_NO_FLAGS, 0x0, 0x1234567) tutils.raises(ValueError, f.to_bytes) @@ -344,7 +412,8 @@ def test_push_promise_frame_from_bytes(): assert_equal(f.stream_id, 0x1234567) assert_equal(f.header_block_fragment, 'foobar') - f = Frame.from_bytes('00000e0508012345670307654321666f6f626172000000'.decode('hex')) + f = Frame.from_bytes( + '00000e0508012345670307654321666f6f626172000000'.decode('hex')) assert isinstance(f, PushPromiseFrame) assert_equal(f.length, 14) assert_equal(f.TYPE, PushPromiseFrame.TYPE) @@ -354,16 +423,26 @@ def test_push_promise_frame_from_bytes(): def test_push_promise_frame_human_readable(): - f = PushPromiseFrame(14, HeadersFrame.FLAG_PADDED, 0x1234567, 0x7654321, 'foobar', pad_length=3) + f = PushPromiseFrame( + 14, + HeadersFrame.FLAG_PADDED, + 0x1234567, + 0x7654321, + 'foobar', + pad_length=3) assert f.human_readable() def test_ping_frame_to_bytes(): f = PingFrame(8, PingFrame.FLAG_ACK, 0x0, payload=b'foobar') - assert_equal(f.to_bytes().encode('hex'), '000008060100000000666f6f6261720000') + assert_equal( + f.to_bytes().encode('hex'), + '000008060100000000666f6f6261720000') f = PingFrame(8, Frame.FLAG_NO_FLAGS, 0x0, payload=b'foobardeadbeef') - assert_equal(f.to_bytes().encode('hex'), '000008060000000000666f6f6261726465') + assert_equal( + f.to_bytes().encode('hex'), + '000008060000000000666f6f6261726465') f = PingFrame(8, Frame.FLAG_NO_FLAGS, 0x1234567) tutils.raises(ValueError, f.to_bytes) @@ -393,13 +472,34 @@ def test_ping_frame_human_readable(): def test_goaway_frame_to_bytes(): - f = GoAwayFrame(8, Frame.FLAG_NO_FLAGS, 0x0, last_stream=0x1234567, error_code=0x87654321, data=b'') - assert_equal(f.to_bytes().encode('hex'), '0000080700000000000123456787654321') - - f = GoAwayFrame(14, Frame.FLAG_NO_FLAGS, 0x0, last_stream=0x1234567, error_code=0x87654321, data=b'foobar') - assert_equal(f.to_bytes().encode('hex'), '00000e0700000000000123456787654321666f6f626172') - - f = GoAwayFrame(8, Frame.FLAG_NO_FLAGS, 0x1234567, last_stream=0x1234567, error_code=0x87654321) + f = GoAwayFrame( + 8, + Frame.FLAG_NO_FLAGS, + 0x0, + last_stream=0x1234567, + error_code=0x87654321, + data=b'') + assert_equal( + f.to_bytes().encode('hex'), + '0000080700000000000123456787654321') + + f = GoAwayFrame( + 14, + Frame.FLAG_NO_FLAGS, + 0x0, + last_stream=0x1234567, + error_code=0x87654321, + data=b'foobar') + assert_equal( + f.to_bytes().encode('hex'), + '00000e0700000000000123456787654321666f6f626172') + + f = GoAwayFrame( + 8, + Frame.FLAG_NO_FLAGS, + 0x1234567, + last_stream=0x1234567, + error_code=0x87654321) tutils.raises(ValueError, f.to_bytes) @@ -414,7 +514,8 @@ def test_goaway_frame_from_bytes(): assert_equal(f.error_code, 0x87654321) assert_equal(f.data, b'') - f = Frame.from_bytes('00000e0700000000000123456787654321666f6f626172'.decode('hex')) + f = Frame.from_bytes( + '00000e0700000000000123456787654321666f6f626172'.decode('hex')) assert isinstance(f, GoAwayFrame) assert_equal(f.length, 14) assert_equal(f.TYPE, GoAwayFrame.TYPE) @@ -426,18 +527,36 @@ def test_goaway_frame_from_bytes(): def test_go_away_frame_human_readable(): - f = GoAwayFrame(14, Frame.FLAG_NO_FLAGS, 0x0, last_stream=0x1234567, error_code=0x87654321, data=b'foobar') + f = GoAwayFrame( + 14, + Frame.FLAG_NO_FLAGS, + 0x0, + last_stream=0x1234567, + error_code=0x87654321, + data=b'foobar') assert f.human_readable() def test_window_update_frame_to_bytes(): - f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0x1234567) + f = WindowUpdateFrame( + 4, + Frame.FLAG_NO_FLAGS, + 0x0, + window_size_increment=0x1234567) assert_equal(f.to_bytes().encode('hex'), '00000408000000000001234567') - f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x1234567, window_size_increment=0x7654321) + f = WindowUpdateFrame( + 4, + Frame.FLAG_NO_FLAGS, + 0x1234567, + window_size_increment=0x7654321) assert_equal(f.to_bytes().encode('hex'), '00000408000123456707654321') - f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0xdeadbeef) + f = WindowUpdateFrame( + 4, + Frame.FLAG_NO_FLAGS, + 0x0, + window_size_increment=0xdeadbeef) tutils.raises(ValueError, f.to_bytes) f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0) @@ -455,12 +574,20 @@ def test_window_update_frame_from_bytes(): def test_window_update_frame_human_readable(): - f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x1234567, window_size_increment=0x7654321) + f = WindowUpdateFrame( + 4, + Frame.FLAG_NO_FLAGS, + 0x1234567, + window_size_increment=0x7654321) assert f.human_readable() def test_continuation_frame_to_bytes(): - f = ContinuationFrame(6, ContinuationFrame.FLAG_END_HEADERS, 0x1234567, 'foobar') + f = ContinuationFrame( + 6, + ContinuationFrame.FLAG_END_HEADERS, + 0x1234567, + 'foobar') assert_equal(f.to_bytes().encode('hex'), '000006090401234567666f6f626172') f = ContinuationFrame(6, ContinuationFrame.FLAG_END_HEADERS, 0x0, 'foobar') @@ -478,5 +605,9 @@ def test_continuation_frame_from_bytes(): def test_continuation_frame_human_readable(): - f = ContinuationFrame(6, ContinuationFrame.FLAG_END_HEADERS, 0x1234567, 'foobar') + f = ContinuationFrame( + 6, + ContinuationFrame.FLAG_END_HEADERS, + 0x1234567, + 'foobar') assert f.human_readable() diff --git a/test/test_certutils.py b/test/test_certutils.py index 115cac4d..e079ec40 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -42,7 +42,8 @@ class TestCertStore: ca2 = certutils.CertStore.from_store(d, "test") assert ca2.get_cert("foo", []) - assert ca.default_ca.get_serial_number() == ca2.default_ca.get_serial_number() + assert ca.default_ca.get_serial_number( + ) == ca2.default_ca.get_serial_number() def test_create_tmp(self): with tutils.tmpdir() as d: @@ -78,7 +79,8 @@ class TestCertStore: with tutils.tmpdir() as d: ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test") ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test") - assert not ca1.default_ca.get_serial_number() == ca2.default_ca.get_serial_number() + assert not ca1.default_ca.get_serial_number( + ) == ca2.default_ca.get_serial_number() dc = ca2.get_cert("foo.com", ["sans.example.com"]) dcp = os.path.join(d, "dc") @@ -93,8 +95,16 @@ class TestCertStore: def test_gen_pkey(self): try: with tutils.tmpdir() as d: - ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test") - ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test") + ca1 = certutils.CertStore.from_store( + os.path.join( + d, + "ca1"), + "test") + ca2 = certutils.CertStore.from_store( + os.path.join( + d, + "ca2"), + "test") cert = ca1.get_cert("foo.com", []) assert certffi.get_flags(ca2.gen_pkey(cert[0])) == 1 finally: diff --git a/test/test_http_auth.py b/test/test_http_auth.py index 045fb13e..c842925b 100644 --- a/test/test_http_auth.py +++ b/test/test_http_auth.py @@ -13,7 +13,10 @@ class TestPassManNonAnon: class TestPassManHtpasswd: def test_file_errors(self): - tutils.raises("malformed htpasswd file", http_auth.PassManHtpasswd, tutils.test_data.path("data/server.crt")) + tutils.raises( + "malformed htpasswd file", + http_auth.PassManHtpasswd, + tutils.test_data.path("data/server.crt")) def test_simple(self): pm = http_auth.PassManHtpasswd(tutils.test_data.path("data/htpasswd")) diff --git a/test/test_socks.py b/test/test_socks.py index a596dedf..a9db4706 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -61,7 +61,12 @@ def test_message_ipv6(): # Test ATYP=0x04 (IPV6) ipv6_addr = "2001:db8:85a3:8d3:1319:8a2e:370:7344" - raw = tutils.treader("\x05\x01\x00\x04" + socket.inet_pton(socket.AF_INET6, ipv6_addr) + "\xDE\xAD\xBE\xEF") + raw = tutils.treader( + "\x05\x01\x00\x04" + + socket.inet_pton( + socket.AF_INET6, + ipv6_addr) + + "\xDE\xAD\xBE\xEF") out = StringIO() msg = socks.Message.from_file(raw) assert raw.read(2) == "\xBE\xEF" diff --git a/test/test_tcp.py b/test/test_tcp.py index 62617707..14ba555d 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -75,7 +75,9 @@ class TestServerBind(test.ServerTestBase): for i in range(20): random_port = random.randrange(1024, 65535) try: - c = tcp.TCPClient(("127.0.0.1", self.port), source_address=("127.0.0.1", random_port)) + c = tcp.TCPClient( + ("127.0.0.1", self.port), source_address=( + "127.0.0.1", random_port)) c.connect() assert c.rfile.readline() == str(("127.0.0.1", random_port)) return @@ -196,7 +198,8 @@ class TestSSLClientCert(test.ServerTestBase): def test_clientcert(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(cert=tutils.test_data.path("data/clientcert/client.pem")) + c.convert_to_ssl( + cert=tutils.test_data.path("data/clientcert/client.pem")) assert c.rfile.readline().strip() == "1" def test_clientcert_err(self): @@ -305,7 +308,11 @@ class TestClientCipherListError(test.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - tutils.raises("cipher specification", c.convert_to_ssl, sni="foo.com", cipher_list="bogus") + tutils.raises( + "cipher specification", + c.convert_to_ssl, + sni="foo.com", + cipher_list="bogus") class TestSSLDisconnect(test.ServerTestBase): @@ -666,5 +673,7 @@ class TestSSLKeyLogger(test.ServerTestBase): tcp.log_ssl_key = _logfun def test_create_logfun(self): - assert isinstance(tcp.SSLKeyLogger.create_logfun("test"), tcp.SSLKeyLogger) + assert isinstance( + tcp.SSLKeyLogger.create_logfun("test"), + tcp.SSLKeyLogger) assert not tcp.SSLKeyLogger.create_logfun(False) diff --git a/test/test_websockets.py b/test/test_websockets.py index 38947295..8ed14708 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -63,7 +63,8 @@ class WebSocketsClient(tcp.TCPClient): resp = http.read_response(self.rfile, "get", None) server_nonce = websockets.check_server_handshake(resp.headers) - if not server_nonce == websockets.create_server_nonce(self.client_nonce): + if not server_nonce == websockets.create_server_nonce( + self.client_nonce): self.close() def read_next_message(self): -- cgit v1.2.3 From b395049a853aa378773aebae83468c1b889c2d4e Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 28 May 2015 10:56:00 +0200 Subject: distribute cffi correctly --- netlib/certffi.py | 4 ++-- setup.py | 36 +++++++++++++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/netlib/certffi.py b/netlib/certffi.py index 81dc72e8..451f4493 100644 --- a/netlib/certffi.py +++ b/netlib/certffi.py @@ -1,8 +1,8 @@ from __future__ import (absolute_import, print_function, division) -import cffi +from cffi import FFI import OpenSSL -xffi = cffi.FFI() +xffi = FFI() xffi.cdef(""" struct rsa_meth_st { int flags; diff --git a/setup.py b/setup.py index b5674d85..0051ea77 100644 --- a/setup.py +++ b/setup.py @@ -1,16 +1,39 @@ +from distutils.command.build import build +from setuptools.command.install import install from setuptools import setup, find_packages from codecs import open import os + from netlib import version # Based on https://github.com/pypa/sampleproject/blob/master/setup.py # and https://python-packaging-user-guide.readthedocs.org/ +# and https://caremad.io/2014/11/distributing-a-cffi-project/ here = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(here, 'README.mkd'), encoding='utf-8') as f: long_description = f.read() + +def get_ext_modules(): + from netlib import certffi + return [certffi.xffi.verifier.get_extension()] + + +class CFFIBuild(build): + + def finalize_options(self): + self.distribution.ext_modules = get_ext_modules() + build.finalize_options(self) + + +class CFFIInstall(install): + + def finalize_options(self): + self.distribution.ext_modules = get_ext_modules() + install.finalize_options(self) + setup( name="netlib", version=version.VERSION, @@ -36,12 +59,18 @@ setup( ], packages=find_packages(), include_package_data=True, + zip_safe=False, install_requires=[ + "cffi", "pyasn1>=0.1.7", "pyOpenSSL>=0.15.1", "cryptography>=0.9", "passlib>=1.6.2", "hpack>=1.0.1"], + setup_requires=[ + "cffi", + "pyOpenSSL>=0.15.1", + ], extras_require={ 'dev': [ "mock>=1.0.1", @@ -52,4 +81,9 @@ setup( "autoflake>=0.6.6", "pathod>=%s, <%s" % (version.MINORVERSION, - version.NEXT_MINORVERSION)]}) + version.NEXT_MINORVERSION)]}, + cmdclass={ + "build": CFFIBuild, + "install": CFFIInstall, + }, +) -- cgit v1.2.3 From 4ec181c1403670702c2f163062b92de4dec3d2cc Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 31 May 2015 13:12:01 +1200 Subject: Move version check to netlib, unit test it. --- netlib/version_check.py | 49 ++++++++++++++++++++++++++++++++++++++++++++++ test/test_version_check.py | 22 +++++++++++++++++++++ 2 files changed, 71 insertions(+) create mode 100644 netlib/version_check.py create mode 100644 test/test_version_check.py diff --git a/netlib/version_check.py b/netlib/version_check.py new file mode 100644 index 00000000..09dc23ae --- /dev/null +++ b/netlib/version_check.py @@ -0,0 +1,49 @@ +from __future__ import print_function, absolute_import +import sys +import inspect +import os.path + +import OpenSSL +from . import version + +PYOPENSSL_MIN_VERSION = (0, 15) + + +def version_check( + mitmproxy_version, + pyopenssl_min_version=PYOPENSSL_MIN_VERSION, + fp=sys.stderr): + """ + Having installed a wrong version of pyOpenSSL or netlib is unfortunately a + very common source of error. Check before every start that both versions + are somewhat okay. + """ + # We don't introduce backward-incompatible changes in patch versions. Only + # consider major and minor version. + if version.IVERSION[:2] != mitmproxy_version[:2]: + print( + "You are using mitmproxy %s with netlib %s. " + "Most likely, that won't work - please upgrade!" % ( + mitmproxy_version, version.VERSION + ), + file=fp + ) + sys.exit(1) + v = tuple([int(x) for x in OpenSSL.__version__.split(".")][:2]) + if v < pyopenssl_min_version: + print( + "You are using an outdated version of pyOpenSSL:" + " mitmproxy requires pyOpenSSL %x or greater." % + pyopenssl_min_version, + file=fp + ) + # Some users apparently have multiple versions of pyOpenSSL installed. + # Report which one we got. + pyopenssl_path = os.path.dirname(inspect.getfile(OpenSSL)) + print( + "Your pyOpenSSL %s installation is located at %s" % ( + OpenSSL.__version__, pyopenssl_path + ), + file=fp + ) + sys.exit(1) diff --git a/test/test_version_check.py b/test/test_version_check.py new file mode 100644 index 00000000..bf6ad1f5 --- /dev/null +++ b/test/test_version_check.py @@ -0,0 +1,22 @@ +import cStringIO +import mock +from netlib import version_check, version + + +@mock.patch("sys.exit") +def test_version_check(sexit): + fp = cStringIO.StringIO() + version_check.version_check(version.IVERSION, fp=fp) + assert not sexit.called + + b = (version.IVERSION[0] - 1, version.IVERSION[1]) + version_check.version_check(b, fp=fp) + assert sexit.called + + sexit.reset_mock() + version_check.version_check( + version.IVERSION, + pyopenssl_min_version=(9999,), + fp=fp + ) + assert sexit.called -- cgit v1.2.3 From 73376e605a61fab239213da375a612ed7d3274b5 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 31 May 2015 16:54:14 +1200 Subject: Save first byte timestamp for writers too. --- netlib/tcp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/netlib/tcp.py b/netlib/tcp.py index a705c95b..c8545d4f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -147,6 +147,7 @@ class Writer(_FileLike): May raise NetLibDisconnect """ if v: + self.first_byte_timestamp = self.first_byte_timestamp or time.time() try: if hasattr(self.o, "sendall"): self.add_log(v) -- cgit v1.2.3 From f7bd690e3aba0be05c30a3b9a4d499de8dbd5e06 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 31 May 2015 17:18:55 +1200 Subject: When we see an incomplete read with 0 bytes, it's a disconnect Partially fixes mitmproxy/mitmproxy:#593 --- netlib/tcp.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index c8545d4f..f6179faa 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -225,9 +225,12 @@ class Reader(_FileLike): """ result = self.read(length) if length != -1 and len(result) != length: - raise NetLibIncomplete( - "Expected %s bytes, got %s" % (length, len(result)) - ) + if not result: + raise NetLibDisconnect() + else: + raise NetLibIncomplete( + "Expected %s bytes, got %s" % (length, len(result)) + ) return result -- cgit v1.2.3 From 35856ead075829d5b086e60c60ac20fdfc8560f1 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 31 May 2015 17:24:44 +1200 Subject: websockets: nicer human readable --- netlib/websockets.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/netlib/websockets.py b/netlib/websockets.py index 63dc03f1..bf920897 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -175,7 +175,7 @@ class FrameHeader: def human_readable(self): vals = [ - "wf:", + "ws frame:", OPCODE.get_name(self.opcode, hex(self.opcode)).lower() ] flags = [] @@ -327,8 +327,10 @@ class Frame(object): return cls.from_file(tcp.Reader(io.BytesIO(bytestring))) def human_readable(self): - hdr = self.header.human_readable() - return hdr + "\n" + repr(self.payload) + ret = self.header.human_readable() + if self.payload: + ret = ret + "\nPayload:\n" + utils.cleanBin(self.payload) + return ret def to_bytes(self): """ -- cgit v1.2.3 From 113c5c187f0c37ce0c13c399248f4bf91e3a3149 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 4 Jun 2015 11:14:47 +1200 Subject: Bump version --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/version.py b/netlib/version.py index 3eb0ffc9..bc9a1a57 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 12, 1) +IVERSION = (0, 12, 2) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From 4ca62e0d9bd09aa286cde9bafceff7204304d00c Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 5 Jun 2015 11:42:06 +1200 Subject: tcp: clear_log to clear socket logs --- netlib/tcp.py | 3 +++ test/test_tcp.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/netlib/tcp.py b/netlib/tcp.py index f6179faa..2ebfae96 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -126,6 +126,9 @@ class _FileLike(object): if self.is_logging(): self._log.append(v) + def clear_log(self): + self._log = [] + def reset_timestamps(self): self.first_byte_timestamp = None diff --git a/test/test_tcp.py b/test/test_tcp.py index 14ba555d..362ba0f4 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -553,6 +553,8 @@ class TestFileLike: assert s.get_log() == "" s.read(1) assert s.get_log() == "o" + s.clear_log() + assert s.get_log() == "" s.stop_log() tutils.raises(ValueError, s.get_log) -- cgit v1.2.3 From 2d9b9be1f4fb67d6989b57b68858896d8512293e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 5 Jun 2015 11:50:29 +1200 Subject: Revert "tcp: clear_log to clear socket logs" start_log also clears the log, which is good enough. This reverts commit 4ca62e0d9bd09aa286cde9bafceff7204304d00c. --- netlib/tcp.py | 3 --- test/test_tcp.py | 2 -- 2 files changed, 5 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 2ebfae96..f6179faa 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -126,9 +126,6 @@ class _FileLike(object): if self.is_logging(): self._log.append(v) - def clear_log(self): - self._log = [] - def reset_timestamps(self): self.first_byte_timestamp = None diff --git a/test/test_tcp.py b/test/test_tcp.py index 362ba0f4..14ba555d 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -553,8 +553,6 @@ class TestFileLike: assert s.get_log() == "" s.read(1) assert s.get_log() == "o" - s.clear_log() - assert s.get_log() == "" s.stop_log() tutils.raises(ValueError, s.get_log) -- cgit v1.2.3 From 0269d0fb8b8726f8a84ebe916a553ef435a3a50d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 5 Jun 2015 17:08:22 +1200 Subject: repr for websocket frames --- netlib/websockets.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/netlib/websockets.py b/netlib/websockets.py index bf920897..346adf1b 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -332,6 +332,9 @@ class Frame(object): ret = ret + "\nPayload:\n" + utils.cleanBin(self.payload) return ret + def __repr__(self): + return self.header.human_readable() + def to_bytes(self): """ Serialize the frame to wire format. Returns a string. -- cgit v1.2.3 From 9883509f894dde57c8a71340a69581ac46c44f51 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 5 Jun 2015 12:44:29 +0200 Subject: simplify default ssl params for test servers --- netlib/test.py | 30 +++++++++++++++------- test/test_tcp.py | 76 ++++++++------------------------------------------------ 2 files changed, 32 insertions(+), 74 deletions(-) diff --git a/netlib/test.py b/netlib/test.py index 14f50157..ee8c6685 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -4,6 +4,7 @@ import Queue import cStringIO import OpenSSL from . import tcp, certutils +import tutils class ServerThread(threading.Thread): @@ -55,22 +56,33 @@ class TServer(tcp.TCPServer): dhparams, v3_only """ tcp.TCPServer.__init__(self, addr) - self.ssl, self.q = ssl, q + + if ssl is True: + self.ssl = dict() + elif isinstance(ssl, dict): + self.ssl = ssl + else: + self.ssl = None + + self.q = q self.handler_klass = handler_klass self.last_handler = None def handle_client_connection(self, request, client_address): h = self.handler_klass(request, client_address, self) self.last_handler = h - if self.ssl: - cert = certutils.SSLCert.from_pem( - file(self.ssl["cert"], "rb").read() - ) - raw = file(self.ssl["key"], "rb").read() + if self.ssl is not None: + raw_cert = self.ssl.get( + "cert", + tutils.test_data.path("data/server.crt")) + cert = certutils.SSLCert.from_pem(file(raw_cert, "rb").read()) + raw_key = self.ssl.get( + "key", + tutils.test_data.path("data/server.key")) key = OpenSSL.crypto.load_privatekey( OpenSSL.crypto.FILETYPE_PEM, - raw) - if self.ssl["v3_only"]: + file(raw_key, "rb").read()) + if self.ssl.get("v3_only", False): method = tcp.SSLv3_METHOD options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 else: @@ -81,7 +93,7 @@ class TServer(tcp.TCPServer): method=method, options=options, handle_sni=getattr(h, "handle_sni", None), - request_client_cert=self.ssl["request_client_cert"], + request_client_cert=self.ssl.get("request_client_cert", None), cipher_list=self.ssl.get("cipher_list", None), dhparams=self.ssl.get("dhparams", None), chain_file=self.ssl.get("chain_file", None), diff --git a/test/test_tcp.py b/test/test_tcp.py index 14ba555d..cbe92f3c 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -135,10 +135,6 @@ class TestFinishFail(test.ServerTestBase): class TestServerSSL(test.ServerTestBase): handler = EchoHandler ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False, cipher_list="AES256-SHA", chain_file=tutils.test_data.path("data/server.crt") ) @@ -165,8 +161,6 @@ class TestServerSSL(test.ServerTestBase): class TestSSLv3Only(test.ServerTestBase): handler = EchoHandler ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), request_client_cert=False, v3_only=True ) @@ -188,9 +182,8 @@ class TestSSLClientCert(test.ServerTestBase): def handle(self): self.wfile.write("%s\n" % self.clientcert.serial) self.wfile.flush() + ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), request_client_cert=True, v3_only=False ) @@ -224,12 +217,7 @@ class TestSNI(test.ServerTestBase): self.wfile.write(self.sni) self.wfile.flush() - ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False - ) + ssl = True def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) @@ -242,10 +230,6 @@ class TestSNI(test.ServerTestBase): class TestServerCipherList(test.ServerTestBase): handler = ClientCipherListHandler ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False, cipher_list='RC4-SHA' ) @@ -264,11 +248,8 @@ class TestServerCurrentCipher(test.ServerTestBase): def handle(self): self.wfile.write("%s" % str(self.get_current_cipher())) self.wfile.flush() + ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False, cipher_list='RC4-SHA' ) @@ -282,10 +263,6 @@ class TestServerCurrentCipher(test.ServerTestBase): class TestServerCipherListError(test.ServerTestBase): handler = ClientCipherListHandler ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False, cipher_list='bogus' ) @@ -298,10 +275,6 @@ class TestServerCipherListError(test.ServerTestBase): class TestClientCipherListError(test.ServerTestBase): handler = ClientCipherListHandler ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False, cipher_list='RC4-SHA' ) @@ -321,12 +294,8 @@ class TestSSLDisconnect(test.ServerTestBase): def handle(self): self.finish() - ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False - ) + + ssl = True def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) @@ -341,12 +310,7 @@ class TestSSLDisconnect(test.ServerTestBase): class TestSSLHardDisconnect(test.ServerTestBase): handler = HardDisconnectHandler - ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False - ) + ssl = True def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) @@ -400,13 +364,9 @@ class TestTimeOut(test.ServerTestBase): class TestALPN(test.ServerTestBase): - handler = HangHandler + handler = EchoHandler ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False, - alpn_select="h2" + alpn_select="foobar" ) if OpenSSL._util.lib.Cryptography_HAS_ALPN: @@ -414,19 +374,13 @@ class TestALPN(test.ServerTestBase): def test_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=["h2"]) - print "ALPN: %s" % c.get_alpn_proto_negotiated() - assert c.get_alpn_proto_negotiated() == "h2" + c.convert_to_ssl(alpn_protos=["foobar"]) + assert c.get_alpn_proto_negotiated() == "foobar" class TestSSLTimeOut(test.ServerTestBase): handler = HangHandler - ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False - ) + ssl = True def test_timeout_client(self): c = tcp.TCPClient(("127.0.0.1", self.port)) @@ -439,10 +393,6 @@ class TestSSLTimeOut(test.ServerTestBase): class TestDHParams(test.ServerTestBase): handler = HangHandler ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False, dhparams=certutils.CertStore.load_dhparam( tutils.test_data.path("data/dhparam.pem"), ), @@ -643,10 +593,6 @@ class TestAddress: class TestSSLKeyLogger(test.ServerTestBase): handler = EchoHandler ssl = dict( - cert=tutils.test_data.path("data/server.crt"), - key=tutils.test_data.path("data/server.key"), - request_client_cert=False, - v3_only=False, cipher_list="AES256-SHA" ) -- cgit v1.2.3 From 436291764c4e557155d7e4e87482a4e378a2ccce Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 1 Jun 2015 15:14:31 +0200 Subject: http2: fix default settings --- netlib/h2/h2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netlib/h2/h2.py b/netlib/h2/h2.py index 707b1465..227139a3 100644 --- a/netlib/h2/h2.py +++ b/netlib/h2/h2.py @@ -29,8 +29,8 @@ class H2Client(tcp.TCPClient): SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ^ 16 - 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ^ 14, + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, } -- cgit v1.2.3 From b84001e8f082a9198f56037aa6861a360d5d76cf Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 1 Jun 2015 15:16:00 +0200 Subject: http2: explicitly mention all arguments in tests --- test/h2/test_frames.py | 325 ++++++++++++++++++++++++++++++------------------- 1 file changed, 203 insertions(+), 122 deletions(-) diff --git a/test/h2/test_frames.py b/test/h2/test_frames.py index 313ef405..310336b0 100644 --- a/test/h2/test_frames.py +++ b/test/h2/test_frames.py @@ -3,43 +3,59 @@ import tutils from nose.tools import assert_equal -# TODO test stream association if valid or not - - def test_invalid_flags(): tutils.raises( ValueError, DataFrame, - ContinuationFrame.FLAG_END_HEADERS, - 0x1234567, - 'foobar') + flags=ContinuationFrame.FLAG_END_HEADERS, + stream_id=0x1234567, + payload='foobar') def test_frame_equality(): - a = DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567, 'foobar') - b = DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567, 'foobar') + a = DataFrame( + length=6, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567, + payload='foobar') + b = DataFrame( + length=6, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567, + payload='foobar') assert_equal(a, b) def test_too_large_frames(): - DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567) + DataFrame( + length=6, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567) def test_data_frame_to_bytes(): - f = DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567, 'foobar') + f = DataFrame( + length=6, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567, + payload='foobar') assert_equal(f.to_bytes().encode('hex'), '000006000101234567666f6f626172') f = DataFrame( - 11, - Frame.FLAG_END_STREAM | Frame.FLAG_PADDED, - 0x1234567, - 'foobar', + length=11, + flags=(Frame.FLAG_END_STREAM | Frame.FLAG_PADDED), + stream_id=0x1234567, + payload='foobar', pad_length=3) assert_equal( f.to_bytes().encode('hex'), '00000a00090123456703666f6f626172000000') - f = DataFrame(6, Frame.FLAG_NO_FLAGS, 0x0, 'foobar') + f = DataFrame( + length=6, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload='foobar') tutils.raises(ValueError, f.to_bytes) @@ -63,26 +79,26 @@ def test_data_frame_from_bytes(): def test_data_frame_human_readable(): f = DataFrame( - 11, - Frame.FLAG_END_STREAM | Frame.FLAG_PADDED, - 0x1234567, - 'foobar', + length=11, + flags=(Frame.FLAG_END_STREAM | Frame.FLAG_PADDED), + stream_id=0x1234567, + payload='foobar', pad_length=3) assert f.human_readable() def test_headers_frame_to_bytes(): f = HeadersFrame( - 6, - Frame.FLAG_NO_FLAGS, - 0x1234567, + length=6, + flags=(Frame.FLAG_NO_FLAGS), + stream_id=0x1234567, headers=[('host', 'foo.bar')]) assert_equal(f.to_bytes().encode('hex'), '000007010001234567668594e75e31d9') f = HeadersFrame( - 10, - HeadersFrame.FLAG_PADDED, - 0x1234567, + length=10, + flags=(HeadersFrame.FLAG_PADDED), + stream_id=0x1234567, headers=[('host', 'foo.bar')], pad_length=3) assert_equal( @@ -90,9 +106,9 @@ def test_headers_frame_to_bytes(): '00000b01080123456703668594e75e31d9000000') f = HeadersFrame( - 10, - HeadersFrame.FLAG_PRIORITY, - 0x1234567, + length=10, + flags=(HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, headers=[('host', 'foo.bar')], exclusive=True, stream_dependency=0x7654321, @@ -102,9 +118,9 @@ def test_headers_frame_to_bytes(): '00000c012001234567876543212a668594e75e31d9') f = HeadersFrame( - 14, - HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, - 0x1234567, + length=14, + flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, headers=[('host', 'foo.bar')], pad_length=3, exclusive=True, @@ -115,9 +131,9 @@ def test_headers_frame_to_bytes(): '00001001280123456703876543212a668594e75e31d9000000') f = HeadersFrame( - 14, - HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, - 0x1234567, + length=14, + flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, headers=[('host', 'foo.bar')], pad_length=3, exclusive=False, @@ -127,7 +143,11 @@ def test_headers_frame_to_bytes(): f.to_bytes().encode('hex'), '00001001280123456703076543212a668594e75e31d9000000') - f = HeadersFrame(6, Frame.FLAG_NO_FLAGS, 0x0, 'foobar') + f = HeadersFrame( + length=6, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + headers=[('host', 'foo.bar')]) tutils.raises(ValueError, f.to_bytes) @@ -188,9 +208,9 @@ def test_headers_frame_from_bytes(): def test_headers_frame_human_readable(): f = HeadersFrame( - 7, - HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, - 0x1234567, + length=7, + flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, headers=[], pad_length=3, exclusive=False, @@ -199,9 +219,9 @@ def test_headers_frame_human_readable(): assert f.human_readable() f = HeadersFrame( - 14, - HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, - 0x1234567, + length=14, + flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, headers=[('host', 'foo.bar')], pad_length=3, exclusive=False, @@ -212,27 +232,35 @@ def test_headers_frame_human_readable(): def test_priority_frame_to_bytes(): f = PriorityFrame( - 5, - Frame.FLAG_NO_FLAGS, - 0x1234567, + length=5, + flags=(Frame.FLAG_NO_FLAGS), + stream_id=0x1234567, exclusive=True, stream_dependency=0x7654321, weight=42) assert_equal(f.to_bytes().encode('hex'), '000005020001234567876543212a') f = PriorityFrame( - 5, - Frame.FLAG_NO_FLAGS, - 0x1234567, + length=5, + flags=(Frame.FLAG_NO_FLAGS), + stream_id=0x1234567, exclusive=False, stream_dependency=0x7654321, weight=21) assert_equal(f.to_bytes().encode('hex'), '0000050200012345670765432115') - f = PriorityFrame(5, Frame.FLAG_NO_FLAGS, 0x0, stream_dependency=0x1234567) + f = PriorityFrame( + length=5, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + stream_dependency=0x1234567) tutils.raises(ValueError, f.to_bytes) - f = PriorityFrame(5, Frame.FLAG_NO_FLAGS, 0x1234567, stream_dependency=0x0) + f = PriorityFrame( + length=5, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + stream_dependency=0x0) tutils.raises(ValueError, f.to_bytes) @@ -260,9 +288,9 @@ def test_priority_frame_from_bytes(): def test_priority_frame_human_readable(): f = PriorityFrame( - 5, - Frame.FLAG_NO_FLAGS, - 0x1234567, + length=5, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, exclusive=False, stream_dependency=0x7654321, weight=21) @@ -270,10 +298,17 @@ def test_priority_frame_human_readable(): def test_rst_stream_frame_to_bytes(): - f = RstStreamFrame(4, Frame.FLAG_NO_FLAGS, 0x1234567, error_code=0x7654321) + f = RstStreamFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + error_code=0x7654321) assert_equal(f.to_bytes().encode('hex'), '00000403000123456707654321') - f = RstStreamFrame(4, Frame.FLAG_NO_FLAGS, 0x0) + f = RstStreamFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0) tutils.raises(ValueError, f.to_bytes) @@ -288,28 +323,39 @@ def test_rst_stream_frame_from_bytes(): def test_rst_stream_frame_human_readable(): - f = RstStreamFrame(4, Frame.FLAG_NO_FLAGS, 0x1234567, error_code=0x7654321) + f = RstStreamFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + error_code=0x7654321) assert f.human_readable() def test_settings_frame_to_bytes(): - f = SettingsFrame(0, Frame.FLAG_NO_FLAGS, 0x0) + f = SettingsFrame( + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0) assert_equal(f.to_bytes().encode('hex'), '000000040000000000') - f = SettingsFrame(0, SettingsFrame.FLAG_ACK, 0x0) + f = SettingsFrame( + length=0, + flags=SettingsFrame.FLAG_ACK, + stream_id=0x0) assert_equal(f.to_bytes().encode('hex'), '000000040100000000') f = SettingsFrame( - 6, - SettingsFrame.FLAG_ACK, 0x0, + length=6, + flags=SettingsFrame.FLAG_ACK, + stream_id=0x0, settings={ SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1}) assert_equal(f.to_bytes().encode('hex'), '000006040100000000000200000001') f = SettingsFrame( - 12, - Frame.FLAG_NO_FLAGS, - 0x0, + length=12, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, settings={ SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) @@ -317,7 +363,10 @@ def test_settings_frame_to_bytes(): f.to_bytes().encode('hex'), '00000c040000000000000200000001000312345678') - f = SettingsFrame(0, Frame.FLAG_NO_FLAGS, 0x1234567) + f = SettingsFrame( + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567) tutils.raises(ValueError, f.to_bytes) @@ -361,13 +410,17 @@ def test_settings_frame_from_bytes(): def test_settings_frame_human_readable(): - f = SettingsFrame(12, Frame.FLAG_NO_FLAGS, 0x0, settings={}) + f = SettingsFrame( + length=12, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings={}) assert f.human_readable() f = SettingsFrame( - 12, - Frame.FLAG_NO_FLAGS, - 0x0, + length=12, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, settings={ SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) @@ -376,30 +429,38 @@ def test_settings_frame_human_readable(): def test_push_promise_frame_to_bytes(): f = PushPromiseFrame( - 10, - Frame.FLAG_NO_FLAGS, - 0x1234567, - 0x7654321, - 'foobar') + length=10, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + promised_stream=0x7654321, + header_block_fragment='foobar') assert_equal( f.to_bytes().encode('hex'), '00000a05000123456707654321666f6f626172') f = PushPromiseFrame( - 14, - HeadersFrame.FLAG_PADDED, - 0x1234567, - 0x7654321, - 'foobar', + length=14, + flags=HeadersFrame.FLAG_PADDED, + stream_id=0x1234567, + promised_stream=0x7654321, + header_block_fragment='foobar', pad_length=3) assert_equal( f.to_bytes().encode('hex'), '00000e0508012345670307654321666f6f626172000000') - f = PushPromiseFrame(4, Frame.FLAG_NO_FLAGS, 0x0, 0x1234567) + f = PushPromiseFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + promised_stream=0x1234567) tutils.raises(ValueError, f.to_bytes) - f = PushPromiseFrame(4, Frame.FLAG_NO_FLAGS, 0x1234567, 0x0) + f = PushPromiseFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + promised_stream=0x0) tutils.raises(ValueError, f.to_bytes) @@ -424,27 +485,38 @@ def test_push_promise_frame_from_bytes(): def test_push_promise_frame_human_readable(): f = PushPromiseFrame( - 14, - HeadersFrame.FLAG_PADDED, - 0x1234567, - 0x7654321, - 'foobar', + length=14, + flags=HeadersFrame.FLAG_PADDED, + stream_id=0x1234567, + promised_stream=0x7654321, + header_block_fragment='foobar', pad_length=3) assert f.human_readable() def test_ping_frame_to_bytes(): - f = PingFrame(8, PingFrame.FLAG_ACK, 0x0, payload=b'foobar') + f = PingFrame( + length=8, + flags=PingFrame.FLAG_ACK, + stream_id=0x0, + payload=b'foobar') assert_equal( f.to_bytes().encode('hex'), '000008060100000000666f6f6261720000') - f = PingFrame(8, Frame.FLAG_NO_FLAGS, 0x0, payload=b'foobardeadbeef') + f = PingFrame( + length=8, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b'foobardeadbeef') assert_equal( f.to_bytes().encode('hex'), '000008060000000000666f6f6261726465') - f = PingFrame(8, Frame.FLAG_NO_FLAGS, 0x1234567) + f = PingFrame( + length=8, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567) tutils.raises(ValueError, f.to_bytes) @@ -467,15 +539,19 @@ def test_ping_frame_from_bytes(): def test_ping_frame_human_readable(): - f = PingFrame(8, PingFrame.FLAG_ACK, 0x0, payload=b'foobar') + f = PingFrame( + length=8, + flags=PingFrame.FLAG_ACK, + stream_id=0x0, + payload=b'foobar') assert f.human_readable() def test_goaway_frame_to_bytes(): f = GoAwayFrame( - 8, - Frame.FLAG_NO_FLAGS, - 0x0, + length=8, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, last_stream=0x1234567, error_code=0x87654321, data=b'') @@ -484,9 +560,9 @@ def test_goaway_frame_to_bytes(): '0000080700000000000123456787654321') f = GoAwayFrame( - 14, - Frame.FLAG_NO_FLAGS, - 0x0, + length=14, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, last_stream=0x1234567, error_code=0x87654321, data=b'foobar') @@ -495,16 +571,17 @@ def test_goaway_frame_to_bytes(): '00000e0700000000000123456787654321666f6f626172') f = GoAwayFrame( - 8, - Frame.FLAG_NO_FLAGS, - 0x1234567, + length=8, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, last_stream=0x1234567, error_code=0x87654321) tutils.raises(ValueError, f.to_bytes) def test_goaway_frame_from_bytes(): - f = Frame.from_bytes('0000080700000000000123456787654321'.decode('hex')) + f = Frame.from_bytes( + '0000080700000000000123456787654321'.decode('hex')) assert isinstance(f, GoAwayFrame) assert_equal(f.length, 8) assert_equal(f.TYPE, GoAwayFrame.TYPE) @@ -528,9 +605,9 @@ def test_goaway_frame_from_bytes(): def test_go_away_frame_human_readable(): f = GoAwayFrame( - 14, - Frame.FLAG_NO_FLAGS, - 0x0, + length=14, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, last_stream=0x1234567, error_code=0x87654321, data=b'foobar') @@ -539,23 +616,23 @@ def test_go_away_frame_human_readable(): def test_window_update_frame_to_bytes(): f = WindowUpdateFrame( - 4, - Frame.FLAG_NO_FLAGS, - 0x0, + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, window_size_increment=0x1234567) assert_equal(f.to_bytes().encode('hex'), '00000408000000000001234567') f = WindowUpdateFrame( - 4, - Frame.FLAG_NO_FLAGS, - 0x1234567, + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, window_size_increment=0x7654321) assert_equal(f.to_bytes().encode('hex'), '00000408000123456707654321') f = WindowUpdateFrame( - 4, - Frame.FLAG_NO_FLAGS, - 0x0, + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, window_size_increment=0xdeadbeef) tutils.raises(ValueError, f.to_bytes) @@ -575,22 +652,26 @@ def test_window_update_frame_from_bytes(): def test_window_update_frame_human_readable(): f = WindowUpdateFrame( - 4, - Frame.FLAG_NO_FLAGS, - 0x1234567, + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, window_size_increment=0x7654321) assert f.human_readable() def test_continuation_frame_to_bytes(): f = ContinuationFrame( - 6, - ContinuationFrame.FLAG_END_HEADERS, - 0x1234567, - 'foobar') + length=6, + flags=ContinuationFrame.FLAG_END_HEADERS, + stream_id=0x1234567, + header_block_fragment='foobar') assert_equal(f.to_bytes().encode('hex'), '000006090401234567666f6f626172') - f = ContinuationFrame(6, ContinuationFrame.FLAG_END_HEADERS, 0x0, 'foobar') + f = ContinuationFrame( + length=6, + flags=ContinuationFrame.FLAG_END_HEADERS, + stream_id=0x0, + header_block_fragment='foobar') tutils.raises(ValueError, f.to_bytes) @@ -606,8 +687,8 @@ def test_continuation_frame_from_bytes(): def test_continuation_frame_human_readable(): f = ContinuationFrame( - 6, - ContinuationFrame.FLAG_END_HEADERS, - 0x1234567, - 'foobar') + length=6, + flags=ContinuationFrame.FLAG_END_HEADERS, + stream_id=0x1234567, + header_block_fragment='foobar') assert f.human_readable() -- cgit v1.2.3 From e4c129026fbf4228c13ae64da19a9a85fc7ff2a5 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 1 Jun 2015 15:17:50 +0200 Subject: http2: introduce state for connection objects --- netlib/h2/frame.py | 102 +++++++++++++++++++++++++++++++++-------------------- 1 file changed, 63 insertions(+), 39 deletions(-) diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index 36456c46..174ceebd 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -20,18 +20,28 @@ class Frame(object): FLAG_PADDED = 0x8 FLAG_PRIORITY = 0x20 - def __init__(self, length, flags, stream_id): + def __init__(self, state=None, length=0, flags=FLAG_NO_FLAGS, stream_id=0x0): valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) if flags | valid_flags != valid_flags: raise ValueError('invalid flags detected.') + if state is None: + class State(object): + pass + + state = State() + state.encoder = Encoder() + state.decoder = Decoder() + + self.state = state + self.length = length self.type = self.TYPE self.flags = flags self.stream_id = stream_id @classmethod - def from_file(self, fp): + def from_file(self, fp, state=None): """ read a HTTP/2 frame sent by a server or client fp is a "file like" object that could be backed by a network @@ -45,16 +55,16 @@ class Frame(object): stream_id = fields[4] payload = fp.safe_read(length) - return FRAMES[fields[2]].from_bytes(length, flags, stream_id, payload) + return FRAMES[fields[2]].from_bytes(state, length, flags, stream_id, payload) @classmethod - def from_bytes(self, data): + def from_bytes(self, data, state=None): fields = struct.unpack("!HBBBL", data[:9]) length = (fields[0] << 8) + fields[1] # type is already deducted from class flags = fields[3] stream_id = fields[4] - return FRAMES[fields[2]].from_bytes(length, flags, stream_id, data[9:]) + return FRAMES[fields[2]].from_bytes(state, length, flags, stream_id, data[9:]) def to_bytes(self): payload = self.payload_bytes() @@ -96,18 +106,19 @@ class DataFrame(Frame): def __init__( self, + state=None, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, payload=b'', pad_length=0): - super(DataFrame, self).__init__(length, flags, stream_id) + super(DataFrame, self).__init__(state, length, flags, stream_id) self.payload = payload self.pad_length = pad_length @classmethod - def from_bytes(self, length, flags, stream_id, payload): - f = self(length=length, flags=flags, stream_id=stream_id) + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) if f.flags & self.FLAG_PADDED: f.pad_length = struct.unpack('!B', payload[0])[0] @@ -146,6 +157,7 @@ class HeadersFrame(Frame): def __init__( self, + state=None, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, @@ -154,7 +166,7 @@ class HeadersFrame(Frame): exclusive=False, stream_dependency=0x0, weight=0): - super(HeadersFrame, self).__init__(length, flags, stream_id) + super(HeadersFrame, self).__init__(state, length, flags, stream_id) if headers is None: headers = [] @@ -166,8 +178,8 @@ class HeadersFrame(Frame): self.weight = weight @classmethod - def from_bytes(self, length, flags, stream_id, payload): - f = self(length=length, flags=flags, stream_id=stream_id) + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) if f.flags & self.FLAG_PADDED: f.pad_length = struct.unpack('!B', payload[0])[0] @@ -177,18 +189,22 @@ class HeadersFrame(Frame): if f.flags & self.FLAG_PRIORITY: f.stream_dependency, f.weight = struct.unpack( - '!LB', header_block_fragment[ - :5]) + '!LB', header_block_fragment[:5]) f.exclusive = bool(f.stream_dependency >> 31) f.stream_dependency &= 0x7FFFFFFF header_block_fragment = header_block_fragment[5:] - for header, value in Decoder().decode(header_block_fragment): + for header, value in f.state.decoder.decode(header_block_fragment): f.headers.append((header, value)) return f def payload_bytes(self): + """ + This encodes all headers with HPACK + Do NOT call this method twice - it will change the encoder state! + """ + if self.stream_id == 0x0: raise ValueError('HEADERS frames MUST be associated with a stream.') @@ -201,7 +217,7 @@ class HeadersFrame(Frame): (int(self.exclusive) << 31) | self.stream_dependency, self.weight) - b += Encoder().encode(self.headers) + b += self.state.encoder.encode(self.headers) if self.flags & self.FLAG_PADDED: b += b'\0' * self.pad_length @@ -234,20 +250,21 @@ class PriorityFrame(Frame): def __init__( self, + state=None, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, exclusive=False, stream_dependency=0x0, weight=0): - super(PriorityFrame, self).__init__(length, flags, stream_id) + super(PriorityFrame, self).__init__(state, length, flags, stream_id) self.exclusive = exclusive self.stream_dependency = stream_dependency self.weight = weight @classmethod - def from_bytes(self, length, flags, stream_id, payload): - f = self(length=length, flags=flags, stream_id=stream_id) + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) f.stream_dependency, f.weight = struct.unpack('!LB', payload) f.exclusive = bool(f.stream_dependency >> 31) @@ -283,16 +300,17 @@ class RstStreamFrame(Frame): def __init__( self, + state=None, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, error_code=0x0): - super(RstStreamFrame, self).__init__(length, flags, stream_id) + super(RstStreamFrame, self).__init__(state, length, flags, stream_id) self.error_code = error_code @classmethod - def from_bytes(self, length, flags, stream_id, payload): - f = self(length=length, flags=flags, stream_id=stream_id) + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) f.error_code = struct.unpack('!L', payload)[0] return f @@ -322,11 +340,12 @@ class SettingsFrame(Frame): def __init__( self, + state=None, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, settings=None): - super(SettingsFrame, self).__init__(length, flags, stream_id) + super(SettingsFrame, self).__init__(state, length, flags, stream_id) if settings is None: settings = {} @@ -334,8 +353,8 @@ class SettingsFrame(Frame): self.settings = settings @classmethod - def from_bytes(self, length, flags, stream_id, payload): - f = self(length=length, flags=flags, stream_id=stream_id) + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) for i in xrange(0, len(payload), 6): identifier, value = struct.unpack("!HL", payload[i:i + 6]) @@ -372,20 +391,21 @@ class PushPromiseFrame(Frame): def __init__( self, + state=None, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, promised_stream=0x0, header_block_fragment=b'', pad_length=0): - super(PushPromiseFrame, self).__init__(length, flags, stream_id) + super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) self.pad_length = pad_length self.promised_stream = promised_stream self.header_block_fragment = header_block_fragment @classmethod - def from_bytes(self, length, flags, stream_id, payload): - f = self(length=length, flags=flags, stream_id=stream_id) + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) if f.flags & self.FLAG_PADDED: f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) @@ -435,16 +455,17 @@ class PingFrame(Frame): def __init__( self, + state=None, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, payload=b''): - super(PingFrame, self).__init__(length, flags, stream_id) + super(PingFrame, self).__init__(state, length, flags, stream_id) self.payload = payload @classmethod - def from_bytes(self, length, flags, stream_id, payload): - f = self(length=length, flags=flags, stream_id=stream_id) + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) f.payload = payload return f @@ -467,20 +488,21 @@ class GoAwayFrame(Frame): def __init__( self, + state=None, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, last_stream=0x0, error_code=0x0, data=b''): - super(GoAwayFrame, self).__init__(length, flags, stream_id) + super(GoAwayFrame, self).__init__(state, length, flags, stream_id) self.last_stream = last_stream self.error_code = error_code self.data = data @classmethod - def from_bytes(self, length, flags, stream_id, payload): - f = self(length=length, flags=flags, stream_id=stream_id) + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) f.last_stream &= 0x7FFFFFFF @@ -511,16 +533,17 @@ class WindowUpdateFrame(Frame): def __init__( self, + state=None, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, window_size_increment=0x0): - super(WindowUpdateFrame, self).__init__(length, flags, stream_id) + super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) self.window_size_increment = window_size_increment @classmethod - def from_bytes(self, length, flags, stream_id, payload): - f = self(length=length, flags=flags, stream_id=stream_id) + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) f.window_size_increment = struct.unpack("!L", payload)[0] f.window_size_increment &= 0x7FFFFFFF @@ -544,16 +567,17 @@ class ContinuationFrame(Frame): def __init__( self, + state=None, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b''): - super(ContinuationFrame, self).__init__(length, flags, stream_id) + super(ContinuationFrame, self).__init__(state, length, flags, stream_id) self.header_block_fragment = header_block_fragment @classmethod - def from_bytes(self, length, flags, stream_id, payload): - f = self(length=length, flags=flags, stream_id=stream_id) + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) f.header_block_fragment = payload return f -- cgit v1.2.3 From 5cecbdc1687346bb2bf139c904ffda2b37dc8276 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 1 Jun 2015 12:34:50 +0200 Subject: http2: add basic protocol handling --- netlib/h2/__init__.py | 169 +++++++++++++++++++++++++++++++++++++++++++++++++ netlib/h2/frame.py | 50 ++++++++++++--- netlib/h2/h2.py | 89 -------------------------- test/h2/example.py | 18 ------ test/h2/test_frames.py | 1 + 5 files changed, 212 insertions(+), 115 deletions(-) delete mode 100644 netlib/h2/h2.py delete mode 100644 test/h2/example.py diff --git a/netlib/h2/__init__.py b/netlib/h2/__init__.py index 9b4faa33..054ba91c 100644 --- a/netlib/h2/__init__.py +++ b/netlib/h2/__init__.py @@ -1 +1,170 @@ from __future__ import (absolute_import, print_function, division) +import itertools + +from .. import utils +from .frame import * + + +class HTTP2Protocol(object): + + ERROR_CODES = utils.BiDi( + NO_ERROR=0x0, + PROTOCOL_ERROR=0x1, + INTERNAL_ERROR=0x2, + FLOW_CONTROL_ERROR=0x3, + SETTINGS_TIMEOUT=0x4, + STREAM_CLOSED=0x5, + FRAME_SIZE_ERROR=0x6, + REFUSED_STREAM=0x7, + CANCEL=0x8, + COMPRESSION_ERROR=0x9, + CONNECT_ERROR=0xa, + ENHANCE_YOUR_CALM=0xb, + INADEQUATE_SECURITY=0xc, + HTTP_1_1_REQUIRED=0xd + ) + + # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' + + ALPN_PROTO_H2 = b'h2' + + HTTP2_DEFAULT_SETTINGS = { + SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, + SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, + } + + def __init__(self): + self.http2_settings = self.HTTP2_DEFAULT_SETTINGS.copy() + self.current_stream_id = None + self.encoder = Encoder() + self.decoder = Decoder() + + def check_alpn(self): + alp = self.get_alpn_proto_negotiated() + if alp != self.ALPN_PROTO_H2: + raise NotImplementedError( + "H2Client can not handle unknown ALP: %s" % alp) + print("-> Successfully negotiated 'h2' application layer protocol.") + + def send_connection_preface(self): + self.wfile.write(bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex'))) + self.send_frame(SettingsFrame(state=self)) + + frame = Frame.from_file(self.rfile, self) + assert isinstance(frame, SettingsFrame) + self._apply_settings(frame.settings) + self.read_frame() # read setting ACK frame + + print("-> Connection Preface completed.") + + def next_stream_id(self): + if self.current_stream_id is None: + self.current_stream_id = 1 + else: + self.current_stream_id += 2 + return self.current_stream_id + + def send_frame(self, frame): + raw_bytes = frame.to_bytes() + self.wfile.write(raw_bytes) + self.wfile.flush() + + def read_frame(self): + frame = Frame.from_file(self.rfile, self) + if isinstance(frame, SettingsFrame): + self._apply_settings(frame.settings) + + return frame + + def _apply_settings(self, settings): + for setting, value in settings.items(): + old_value = self.http2_settings[setting] + if not old_value: + old_value = '-' + + self.http2_settings[setting] = value + print("-> Setting changed: %s to %d (was %s)" % ( + SettingsFrame.SETTINGS.get_name(setting), + value, + str(old_value))) + + self.send_frame(SettingsFrame(state=self, flags=Frame.FLAG_ACK)) + print("-> New settings acknowledged.") + + def _create_headers(self, headers, stream_id, end_stream=True): + # TODO: implement max frame size checks and sending in chunks + + flags = Frame.FLAG_END_HEADERS + if end_stream: + flags |= Frame.FLAG_END_STREAM + + bytes = HeadersFrame( + state=self, + flags=flags, + stream_id=stream_id, + headers=headers).to_bytes() + return [bytes] + + def _create_body(self, body, stream_id): + if body is None or len(body) == 0: + return b'' + + # TODO: implement max frame size checks and sending in chunks + # TODO: implement flow-control window + + bytes = DataFrame( + state=self, + flags=Frame.FLAG_END_STREAM, + stream_id=stream_id, + payload=body).to_bytes() + return [bytes] + + def create_request(self, method, path, headers=None, body=None): + if headers is None: + headers = [] + + headers = [ + (b':method', bytes(method)), + (b':path', bytes(path)), + (b':scheme', b'https')] + headers + + stream_id = self.next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(body is None)), + self._create_body(body, stream_id))) + + def read_response(self): + header_block_fragment = b'' + body = b'' + + while True: + frame = self.read_frame() + if isinstance(frame, HeadersFrame): + header_block_fragment += frame.header_block_fragment + if frame.flags | Frame.FLAG_END_HEADERS: + break + else: + print("Unexpected frame received:") + print(frame.human_readable()) + + while True: + frame = self.read_frame() + if isinstance(frame, DataFrame): + body += frame.payload + if frame.flags | Frame.FLAG_END_STREAM: + break + else: + print("Unexpected frame received:") + print(frame.human_readable()) + + headers = {} + for header, value in self.decoder.decode(header_block_fragment): + headers[header] = value + + return headers[':status'], headers, body diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index 174ceebd..137cbb3d 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -20,16 +20,24 @@ class Frame(object): FLAG_PADDED = 0x8 FLAG_PRIORITY = 0x20 - def __init__(self, state=None, length=0, flags=FLAG_NO_FLAGS, stream_id=0x0): + def __init__( + self, + state=None, + length=0, + flags=FLAG_NO_FLAGS, + stream_id=0x0): valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) if flags | valid_flags != valid_flags: raise ValueError('invalid flags detected.') if state is None: + from . import HTTP2Protocol + class State(object): pass state = State() + state.http2_settings = HTTP2Protocol.HTTP2_DEFAULT_SETTINGS.copy() state.encoder = Encoder() state.decoder = Decoder() @@ -40,6 +48,14 @@ class Frame(object): self.flags = flags self.stream_id = stream_id + def _check_frame_size(self, length): + max_length = self.state.http2_settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + if length > max_length: + raise NotImplementedError( + "Frame size exceeded: %d, but only %d allowed." % ( + length, max_length)) + @classmethod def from_file(self, fp, state=None): """ @@ -54,8 +70,15 @@ class Frame(object): flags = fields[3] stream_id = fields[4] + # TODO: check frame size if <= current SETTINGS_MAX_FRAME_SIZE + payload = fp.safe_read(length) - return FRAMES[fields[2]].from_bytes(state, length, flags, stream_id, payload) + return FRAMES[fields[2]].from_bytes( + state, + length, + flags, + stream_id, + payload) @classmethod def from_bytes(self, data, state=None): @@ -64,12 +87,20 @@ class Frame(object): # type is already deducted from class flags = fields[3] stream_id = fields[4] - return FRAMES[fields[2]].from_bytes(state, length, flags, stream_id, data[9:]) + + return FRAMES[fields[2]].from_bytes( + state, + length, + flags, + stream_id, + data[9:]) def to_bytes(self): payload = self.payload_bytes() self.length = len(payload) + self._check_frame_size(self.length) + b = struct.pack('!HB', self.length & 0xFFFF00, self.length & 0x0000FF) b += struct.pack('!B', self.TYPE) b += struct.pack('!B', self.flags) @@ -183,19 +214,20 @@ class HeadersFrame(Frame): if f.flags & self.FLAG_PADDED: f.pad_length = struct.unpack('!B', payload[0])[0] - header_block_fragment = payload[1:-f.pad_length] + f.header_block_fragment = payload[1:-f.pad_length] else: - header_block_fragment = payload[0:] + f.header_block_fragment = payload[0:] if f.flags & self.FLAG_PRIORITY: f.stream_dependency, f.weight = struct.unpack( '!LB', header_block_fragment[:5]) f.exclusive = bool(f.stream_dependency >> 31) f.stream_dependency &= 0x7FFFFFFF - header_block_fragment = header_block_fragment[5:] + f.header_block_fragment = f.header_block_fragment[5:] - for header, value in f.state.decoder.decode(header_block_fragment): - f.headers.append((header, value)) + # TODO only do this if END_HEADERS or something... + # for header, value in f.state.decoder.decode(f.header_block_fragment): + # f.headers.append((header, value)) return f @@ -217,6 +249,8 @@ class HeadersFrame(Frame): (int(self.exclusive) << 31) | self.stream_dependency, self.weight) + # TODO: maybe remove that and only deal with header_block_fragments + # inside frames b += self.state.encoder.encode(self.headers) if self.flags & self.FLAG_PADDED: diff --git a/netlib/h2/h2.py b/netlib/h2/h2.py deleted file mode 100644 index 227139a3..00000000 --- a/netlib/h2/h2.py +++ /dev/null @@ -1,89 +0,0 @@ -from .. import utils, odict, tcp -from frame import * - -# "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" -CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' - -ERROR_CODES = utils.BiDi( - NO_ERROR=0x0, - PROTOCOL_ERROR=0x1, - INTERNAL_ERROR=0x2, - FLOW_CONTROL_ERROR=0x3, - SETTINGS_TIMEOUT=0x4, - STREAM_CLOSED=0x5, - FRAME_SIZE_ERROR=0x6, - REFUSED_STREAM=0x7, - CANCEL=0x8, - COMPRESSION_ERROR=0x9, - CONNECT_ERROR=0xa, - ENHANCE_YOUR_CALM=0xb, - INADEQUATE_SECURITY=0xc, - HTTP_1_1_REQUIRED=0xd -) - - -class H2Client(tcp.TCPClient): - ALPN_PROTO_H2 = b'h2' - - DEFAULT_SETTINGS = { - SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, - SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, - } - - def __init__(self, address, source_address=None): - super(H2Client, self).__init__(address, source_address) - self.settings = self.DEFAULT_SETTINGS.copy() - - def connect(self, send_preface=True): - super(H2Client, self).connect() - self.convert_to_ssl(alpn_protos=[self.ALPN_PROTO_H2]) - - alp = self.get_alpn_proto_negotiated() - if alp != b'h2': - raise NotImplementedError( - "H2Client can not handle unknown protocol: %s" % - alp) - print "-> Successfully negotiated 'h2' application layer protocol." - - if send_preface: - self.wfile.write(bytes(CLIENT_CONNECTION_PREFACE.decode('hex'))) - self.send_frame(SettingsFrame()) - - frame = Frame.from_file(self.rfile) - print frame.human_readable() - assert isinstance(frame, SettingsFrame) - self.apply_settings(frame.settings) - - print "-> Connection Preface completed." - - print "-> H2Client is ready..." - - def send_frame(self, frame): - self.wfile.write(frame.to_bytes()) - self.wfile.flush() - - def read_frame(self): - frame = Frame.from_file(self.rfile) - if isinstance(frame, SettingsFrame): - self.apply_settings(frame.settings) - - return frame - - def apply_settings(self, settings): - for setting, value in settings.items(): - old_value = self.settings[setting] - if not old_value: - old_value = '-' - - self.settings[setting] = value - print "-> Setting changed: %s to %d (was %s)" % - (SettingsFrame.SETTINGS.get_name(setting), - value, - str(old_value)) - - self.send_frame(SettingsFrame(flags=Frame.FLAG_ACK)) - print "-> New settings acknowledged." diff --git a/test/h2/example.py b/test/h2/example.py deleted file mode 100644 index fc4f2f10..00000000 --- a/test/h2/example.py +++ /dev/null @@ -1,18 +0,0 @@ -from netlib.h2.frame import * -from netlib.h2.h2 import * - -c = H2Client(("127.0.0.1", 443)) -c.connect() - -c.send_frame(HeadersFrame( - flags=(Frame.FLAG_END_HEADERS | Frame.FLAG_END_STREAM), - stream_id=0x1, - headers=[ - (b':method', 'GET'), - (b':path', b'/index.html'), - (b':scheme', b'https'), - (b':authority', b'localhost'), - ])) - -while True: - print c.read_frame().human_readable() diff --git a/test/h2/test_frames.py b/test/h2/test_frames.py index 310336b0..babf8069 100644 --- a/test/h2/test_frames.py +++ b/test/h2/test_frames.py @@ -3,6 +3,7 @@ import tutils from nose.tools import assert_equal + def test_invalid_flags(): tutils.raises( ValueError, -- cgit v1.2.3 From 40fa113116a2d3a549bc57c1b1381bbb55c7014b Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 4 Jun 2015 14:11:19 +0200 Subject: http2: change header_block_fragment handling --- netlib/h2/frame.py | 65 +++++++++-------------------- test/h2/test_frames.py | 111 ++++++++++++++++++++++++++----------------------- 2 files changed, 80 insertions(+), 96 deletions(-) diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index 137cbb3d..0755c96c 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -48,13 +48,21 @@ class Frame(object): self.flags = flags self.stream_id = stream_id - def _check_frame_size(self, length): - max_length = self.state.http2_settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] - if length > max_length: + @classmethod + def _check_frame_size(self, length, state): + from . import HTTP2Protocol + + if state: + settings = state.http2_settings + else: + settings = HTTP2Protocol.HTTP2_DEFAULT_SETTINGS + + max_frame_size = settings[SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + + if length > max_frame_size: raise NotImplementedError( "Frame size exceeded: %d, but only %d allowed." % ( - length, max_length)) + length, max_frame_size)) @classmethod def from_file(self, fp, state=None): @@ -70,7 +78,7 @@ class Frame(object): flags = fields[3] stream_id = fields[4] - # TODO: check frame size if <= current SETTINGS_MAX_FRAME_SIZE + self._check_frame_size(length, state) payload = fp.safe_read(length) return FRAMES[fields[2]].from_bytes( @@ -80,26 +88,11 @@ class Frame(object): stream_id, payload) - @classmethod - def from_bytes(self, data, state=None): - fields = struct.unpack("!HBBBL", data[:9]) - length = (fields[0] << 8) + fields[1] - # type is already deducted from class - flags = fields[3] - stream_id = fields[4] - - return FRAMES[fields[2]].from_bytes( - state, - length, - flags, - stream_id, - data[9:]) - def to_bytes(self): payload = self.payload_bytes() self.length = len(payload) - self._check_frame_size(self.length) + self._check_frame_size(self.length, self.state) b = struct.pack('!HB', self.length & 0xFFFF00, self.length & 0x0000FF) b += struct.pack('!B', self.TYPE) @@ -192,17 +185,14 @@ class HeadersFrame(Frame): length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, - headers=None, + header_block_fragment=b'', pad_length=0, exclusive=False, stream_dependency=0x0, weight=0): super(HeadersFrame, self).__init__(state, length, flags, stream_id) - if headers is None: - headers = [] - - self.headers = headers + self.header_block_fragment = header_block_fragment self.pad_length = pad_length self.exclusive = exclusive self.stream_dependency = stream_dependency @@ -220,23 +210,14 @@ class HeadersFrame(Frame): if f.flags & self.FLAG_PRIORITY: f.stream_dependency, f.weight = struct.unpack( - '!LB', header_block_fragment[:5]) + '!LB', f.header_block_fragment[:5]) f.exclusive = bool(f.stream_dependency >> 31) f.stream_dependency &= 0x7FFFFFFF f.header_block_fragment = f.header_block_fragment[5:] - # TODO only do this if END_HEADERS or something... - # for header, value in f.state.decoder.decode(f.header_block_fragment): - # f.headers.append((header, value)) - return f def payload_bytes(self): - """ - This encodes all headers with HPACK - Do NOT call this method twice - it will change the encoder state! - """ - if self.stream_id == 0x0: raise ValueError('HEADERS frames MUST be associated with a stream.') @@ -249,9 +230,7 @@ class HeadersFrame(Frame): (int(self.exclusive) << 31) | self.stream_dependency, self.weight) - # TODO: maybe remove that and only deal with header_block_fragments - # inside frames - b += self.state.encoder.encode(self.headers) + b += self.header_block_fragment if self.flags & self.FLAG_PADDED: b += b'\0' * self.pad_length @@ -269,11 +248,7 @@ class HeadersFrame(Frame): if self.flags & self.FLAG_PADDED: s.append("padding: %d" % self.pad_length) - if not self.headers: - s.append("headers: None") - else: - for header, value in self.headers: - s.append("%s: %s" % (header, value)) + s.append("header_block_fragment: %s" % self.header_block_fragment.encode('hex')) return "\n".join(s) diff --git a/test/h2/test_frames.py b/test/h2/test_frames.py index babf8069..30dc71e8 100644 --- a/test/h2/test_frames.py +++ b/test/h2/test_frames.py @@ -3,6 +3,22 @@ import tutils from nose.tools import assert_equal +class FileAdapter(object): + def __init__(self, data, is_hex=True): + self.position = 0 + if is_hex: + self.data = data.decode('hex') + else: + self.data = data + + def safe_read(self, length): + if self.position + length > len(self.data): + raise ValueError("not enough bytes to read") + + value = self.data[self.position:self.position + length] + self.position += length + return value + def test_invalid_flags(): tutils.raises( @@ -26,14 +42,6 @@ def test_frame_equality(): payload='foobar') assert_equal(a, b) - -def test_too_large_frames(): - DataFrame( - length=6, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567) - - def test_data_frame_to_bytes(): f = DataFrame( length=6, @@ -61,7 +69,7 @@ def test_data_frame_to_bytes(): def test_data_frame_from_bytes(): - f = Frame.from_bytes('000006000101234567666f6f626172'.decode('hex')) + f = Frame.from_file(FileAdapter('000006000101234567666f6f626172')) assert isinstance(f, DataFrame) assert_equal(f.length, 6) assert_equal(f.TYPE, DataFrame.TYPE) @@ -69,7 +77,7 @@ def test_data_frame_from_bytes(): assert_equal(f.stream_id, 0x1234567) assert_equal(f.payload, 'foobar') - f = Frame.from_bytes('00000a00090123456703666f6f626172000000'.decode('hex')) + f = Frame.from_file(FileAdapter('00000a00090123456703666f6f626172000000')) assert isinstance(f, DataFrame) assert_equal(f.length, 10) assert_equal(f.TYPE, DataFrame.TYPE) @@ -93,14 +101,14 @@ def test_headers_frame_to_bytes(): length=6, flags=(Frame.FLAG_NO_FLAGS), stream_id=0x1234567, - headers=[('host', 'foo.bar')]) + header_block_fragment='668594e75e31d9'.decode('hex')) assert_equal(f.to_bytes().encode('hex'), '000007010001234567668594e75e31d9') f = HeadersFrame( length=10, flags=(HeadersFrame.FLAG_PADDED), stream_id=0x1234567, - headers=[('host', 'foo.bar')], + header_block_fragment='668594e75e31d9'.decode('hex'), pad_length=3) assert_equal( f.to_bytes().encode('hex'), @@ -110,7 +118,7 @@ def test_headers_frame_to_bytes(): length=10, flags=(HeadersFrame.FLAG_PRIORITY), stream_id=0x1234567, - headers=[('host', 'foo.bar')], + header_block_fragment='668594e75e31d9'.decode('hex'), exclusive=True, stream_dependency=0x7654321, weight=42) @@ -122,7 +130,7 @@ def test_headers_frame_to_bytes(): length=14, flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), stream_id=0x1234567, - headers=[('host', 'foo.bar')], + header_block_fragment='668594e75e31d9'.decode('hex'), pad_length=3, exclusive=True, stream_dependency=0x7654321, @@ -135,7 +143,7 @@ def test_headers_frame_to_bytes(): length=14, flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), stream_id=0x1234567, - headers=[('host', 'foo.bar')], + header_block_fragment='668594e75e31d9'.decode('hex'), pad_length=3, exclusive=False, stream_dependency=0x7654321, @@ -148,60 +156,61 @@ def test_headers_frame_to_bytes(): length=6, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, - headers=[('host', 'foo.bar')]) + header_block_fragment='668594e75e31d9'.decode('hex')) tutils.raises(ValueError, f.to_bytes) def test_headers_frame_from_bytes(): - f = Frame.from_bytes('000007010001234567668594e75e31d9'.decode('hex')) + f = Frame.from_file(FileAdapter( + '000007010001234567668594e75e31d9')) assert isinstance(f, HeadersFrame) assert_equal(f.length, 7) assert_equal(f.TYPE, HeadersFrame.TYPE) assert_equal(f.flags, Frame.FLAG_NO_FLAGS) assert_equal(f.stream_id, 0x1234567) - assert_equal(f.headers, [('host', 'foo.bar')]) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - f = Frame.from_bytes( - '00000b01080123456703668594e75e31d9000000'.decode('hex')) + f = Frame.from_file(FileAdapter( + '00000b01080123456703668594e75e31d9000000')) assert isinstance(f, HeadersFrame) assert_equal(f.length, 11) assert_equal(f.TYPE, HeadersFrame.TYPE) assert_equal(f.flags, HeadersFrame.FLAG_PADDED) assert_equal(f.stream_id, 0x1234567) - assert_equal(f.headers, [('host', 'foo.bar')]) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - f = Frame.from_bytes( - '00000c012001234567876543212a668594e75e31d9'.decode('hex')) + f = Frame.from_file(FileAdapter( + '00000c012001234567876543212a668594e75e31d9')) assert isinstance(f, HeadersFrame) assert_equal(f.length, 12) assert_equal(f.TYPE, HeadersFrame.TYPE) assert_equal(f.flags, HeadersFrame.FLAG_PRIORITY) assert_equal(f.stream_id, 0x1234567) - assert_equal(f.headers, [('host', 'foo.bar')]) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) assert_equal(f.exclusive, True) assert_equal(f.stream_dependency, 0x7654321) assert_equal(f.weight, 42) - f = Frame.from_bytes( - '00001001280123456703876543212a668594e75e31d9000000'.decode('hex')) + f = Frame.from_file(FileAdapter( + '00001001280123456703876543212a668594e75e31d9000000')) assert isinstance(f, HeadersFrame) assert_equal(f.length, 16) assert_equal(f.TYPE, HeadersFrame.TYPE) assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) assert_equal(f.stream_id, 0x1234567) - assert_equal(f.headers, [('host', 'foo.bar')]) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) assert_equal(f.exclusive, True) assert_equal(f.stream_dependency, 0x7654321) assert_equal(f.weight, 42) - f = Frame.from_bytes( - '00001001280123456703076543212a668594e75e31d9000000'.decode('hex')) + f = Frame.from_file(FileAdapter( + '00001001280123456703076543212a668594e75e31d9000000')) assert isinstance(f, HeadersFrame) assert_equal(f.length, 16) assert_equal(f.TYPE, HeadersFrame.TYPE) assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) assert_equal(f.stream_id, 0x1234567) - assert_equal(f.headers, [('host', 'foo.bar')]) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) assert_equal(f.exclusive, False) assert_equal(f.stream_dependency, 0x7654321) assert_equal(f.weight, 42) @@ -212,7 +221,7 @@ def test_headers_frame_human_readable(): length=7, flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), stream_id=0x1234567, - headers=[], + header_block_fragment=b'', pad_length=3, exclusive=False, stream_dependency=0x7654321, @@ -223,7 +232,7 @@ def test_headers_frame_human_readable(): length=14, flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), stream_id=0x1234567, - headers=[('host', 'foo.bar')], + header_block_fragment='668594e75e31d9'.decode('hex'), pad_length=3, exclusive=False, stream_dependency=0x7654321, @@ -266,7 +275,7 @@ def test_priority_frame_to_bytes(): def test_priority_frame_from_bytes(): - f = Frame.from_bytes('000005020001234567876543212a'.decode('hex')) + f = Frame.from_file(FileAdapter('000005020001234567876543212a')) assert isinstance(f, PriorityFrame) assert_equal(f.length, 5) assert_equal(f.TYPE, PriorityFrame.TYPE) @@ -276,7 +285,7 @@ def test_priority_frame_from_bytes(): assert_equal(f.stream_dependency, 0x7654321) assert_equal(f.weight, 42) - f = Frame.from_bytes('0000050200012345670765432115'.decode('hex')) + f = Frame.from_file(FileAdapter('0000050200012345670765432115')) assert isinstance(f, PriorityFrame) assert_equal(f.length, 5) assert_equal(f.TYPE, PriorityFrame.TYPE) @@ -314,7 +323,7 @@ def test_rst_stream_frame_to_bytes(): def test_rst_stream_frame_from_bytes(): - f = Frame.from_bytes('00000403000123456707654321'.decode('hex')) + f = Frame.from_file(FileAdapter('00000403000123456707654321')) assert isinstance(f, RstStreamFrame) assert_equal(f.length, 4) assert_equal(f.TYPE, RstStreamFrame.TYPE) @@ -372,21 +381,21 @@ def test_settings_frame_to_bytes(): def test_settings_frame_from_bytes(): - f = Frame.from_bytes('000000040000000000'.decode('hex')) + f = Frame.from_file(FileAdapter('000000040000000000')) assert isinstance(f, SettingsFrame) assert_equal(f.length, 0) assert_equal(f.TYPE, SettingsFrame.TYPE) assert_equal(f.flags, Frame.FLAG_NO_FLAGS) assert_equal(f.stream_id, 0x0) - f = Frame.from_bytes('000000040100000000'.decode('hex')) + f = Frame.from_file(FileAdapter('000000040100000000')) assert isinstance(f, SettingsFrame) assert_equal(f.length, 0) assert_equal(f.TYPE, SettingsFrame.TYPE) assert_equal(f.flags, SettingsFrame.FLAG_ACK) assert_equal(f.stream_id, 0x0) - f = Frame.from_bytes('000006040100000000000200000001'.decode('hex')) + f = Frame.from_file(FileAdapter('000006040100000000000200000001')) assert isinstance(f, SettingsFrame) assert_equal(f.length, 6) assert_equal(f.TYPE, SettingsFrame.TYPE) @@ -395,8 +404,8 @@ def test_settings_frame_from_bytes(): assert_equal(len(f.settings), 1) assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) - f = Frame.from_bytes( - '00000c040000000000000200000001000312345678'.decode('hex')) + f = Frame.from_file(FileAdapter( + '00000c040000000000000200000001000312345678')) assert isinstance(f, SettingsFrame) assert_equal(f.length, 12) assert_equal(f.TYPE, SettingsFrame.TYPE) @@ -466,7 +475,7 @@ def test_push_promise_frame_to_bytes(): def test_push_promise_frame_from_bytes(): - f = Frame.from_bytes('00000a05000123456707654321666f6f626172'.decode('hex')) + f = Frame.from_file(FileAdapter('00000a05000123456707654321666f6f626172')) assert isinstance(f, PushPromiseFrame) assert_equal(f.length, 10) assert_equal(f.TYPE, PushPromiseFrame.TYPE) @@ -474,8 +483,8 @@ def test_push_promise_frame_from_bytes(): assert_equal(f.stream_id, 0x1234567) assert_equal(f.header_block_fragment, 'foobar') - f = Frame.from_bytes( - '00000e0508012345670307654321666f6f626172000000'.decode('hex')) + f = Frame.from_file(FileAdapter( + '00000e0508012345670307654321666f6f626172000000')) assert isinstance(f, PushPromiseFrame) assert_equal(f.length, 14) assert_equal(f.TYPE, PushPromiseFrame.TYPE) @@ -522,7 +531,7 @@ def test_ping_frame_to_bytes(): def test_ping_frame_from_bytes(): - f = Frame.from_bytes('000008060100000000666f6f6261720000'.decode('hex')) + f = Frame.from_file(FileAdapter('000008060100000000666f6f6261720000')) assert isinstance(f, PingFrame) assert_equal(f.length, 8) assert_equal(f.TYPE, PingFrame.TYPE) @@ -530,7 +539,7 @@ def test_ping_frame_from_bytes(): assert_equal(f.stream_id, 0x0) assert_equal(f.payload, b'foobar\0\0') - f = Frame.from_bytes('000008060000000000666f6f6261726465'.decode('hex')) + f = Frame.from_file(FileAdapter('000008060000000000666f6f6261726465')) assert isinstance(f, PingFrame) assert_equal(f.length, 8) assert_equal(f.TYPE, PingFrame.TYPE) @@ -581,8 +590,8 @@ def test_goaway_frame_to_bytes(): def test_goaway_frame_from_bytes(): - f = Frame.from_bytes( - '0000080700000000000123456787654321'.decode('hex')) + f = Frame.from_file(FileAdapter( + '0000080700000000000123456787654321')) assert isinstance(f, GoAwayFrame) assert_equal(f.length, 8) assert_equal(f.TYPE, GoAwayFrame.TYPE) @@ -592,8 +601,8 @@ def test_goaway_frame_from_bytes(): assert_equal(f.error_code, 0x87654321) assert_equal(f.data, b'') - f = Frame.from_bytes( - '00000e0700000000000123456787654321666f6f626172'.decode('hex')) + f = Frame.from_file(FileAdapter( + '00000e0700000000000123456787654321666f6f626172')) assert isinstance(f, GoAwayFrame) assert_equal(f.length, 14) assert_equal(f.TYPE, GoAwayFrame.TYPE) @@ -642,7 +651,7 @@ def test_window_update_frame_to_bytes(): def test_window_update_frame_from_bytes(): - f = Frame.from_bytes('00000408000000000001234567'.decode('hex')) + f = Frame.from_file(FileAdapter('00000408000000000001234567')) assert isinstance(f, WindowUpdateFrame) assert_equal(f.length, 4) assert_equal(f.TYPE, WindowUpdateFrame.TYPE) @@ -677,7 +686,7 @@ def test_continuation_frame_to_bytes(): def test_continuation_frame_from_bytes(): - f = Frame.from_bytes('000006090401234567666f6f626172'.decode('hex')) + f = Frame.from_file(FileAdapter('000006090401234567666f6f626172')) assert isinstance(f, ContinuationFrame) assert_equal(f.length, 6) assert_equal(f.TYPE, ContinuationFrame.TYPE) -- cgit v1.2.3 From 623dd850e0ce15630e0950b4de843c0af8046618 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 4 Jun 2015 14:28:09 +0200 Subject: http2: add logging and error handling --- netlib/h2/__init__.py | 28 ++++++++++++++++++---------- netlib/h2/frame.py | 16 ++++++++++++---- test/h2/test_frames.py | 14 ++++++++++++-- 3 files changed, 42 insertions(+), 16 deletions(-) diff --git a/netlib/h2/__init__.py b/netlib/h2/__init__.py index 054ba91c..c06f7a11 100644 --- a/netlib/h2/__init__.py +++ b/netlib/h2/__init__.py @@ -1,8 +1,11 @@ from __future__ import (absolute_import, print_function, division) import itertools +import logging -from .. import utils from .frame import * +from .. import utils + +log = logging.getLogger(__name__) class HTTP2Protocol(object): @@ -49,7 +52,7 @@ class HTTP2Protocol(object): if alp != self.ALPN_PROTO_H2: raise NotImplementedError( "H2Client can not handle unknown ALP: %s" % alp) - print("-> Successfully negotiated 'h2' application layer protocol.") + log.debug("ALP 'h2' successfully negotiated.") def send_connection_preface(self): self.wfile.write(bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex'))) @@ -60,7 +63,7 @@ class HTTP2Protocol(object): self._apply_settings(frame.settings) self.read_frame() # read setting ACK frame - print("-> Connection Preface completed.") + log.debug("Connection Preface completed.") def next_stream_id(self): if self.current_stream_id is None: @@ -88,13 +91,13 @@ class HTTP2Protocol(object): old_value = '-' self.http2_settings[setting] = value - print("-> Setting changed: %s to %d (was %s)" % ( + log.debug("Setting changed: %s to %d (was %s)" % ( SettingsFrame.SETTINGS.get_name(setting), value, str(old_value))) self.send_frame(SettingsFrame(state=self, flags=Frame.FLAG_ACK)) - print("-> New settings acknowledged.") + log.debug("New settings acknowledged.") def _create_headers(self, headers, stream_id, end_stream=True): # TODO: implement max frame size checks and sending in chunks @@ -103,11 +106,13 @@ class HTTP2Protocol(object): if end_stream: flags |= Frame.FLAG_END_STREAM + header_block_fragment = self.encoder.encode(headers) + bytes = HeadersFrame( state=self, flags=flags, stream_id=stream_id, - headers=headers).to_bytes() + header_block_fragment=header_block_fragment).to_bytes() return [bytes] def _create_body(self, body, stream_id): @@ -150,8 +155,8 @@ class HTTP2Protocol(object): if frame.flags | Frame.FLAG_END_HEADERS: break else: - print("Unexpected frame received:") - print(frame.human_readable()) + log.debug("Unexpected frame received:") + log.debug(frame.human_readable()) while True: frame = self.read_frame() @@ -160,11 +165,14 @@ class HTTP2Protocol(object): if frame.flags | Frame.FLAG_END_STREAM: break else: - print("Unexpected frame received:") - print(frame.human_readable()) + log.debug("Unexpected frame received:") + log.debug(frame.human_readable()) headers = {} for header, value in self.decoder.decode(header_block_fragment): headers[header] = value + for header, value in headers.items(): + log.debug("%s: %s" % (header, value)) + return headers[':status'], headers, body diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index 0755c96c..018e822f 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -1,9 +1,14 @@ import struct +import logging +from functools import reduce from hpack.hpack import Encoder, Decoder from .. import utils -from functools import reduce +log = logging.getLogger(__name__) + +class FrameSizeError(Exception): + pass class Frame(object): @@ -57,10 +62,11 @@ class Frame(object): else: settings = HTTP2Protocol.HTTP2_DEFAULT_SETTINGS - max_frame_size = settings[SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + max_frame_size = settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] if length > max_frame_size: - raise NotImplementedError( + raise FrameSizeError( "Frame size exceeded: %d, but only %d allowed." % ( length, max_frame_size)) @@ -248,7 +254,9 @@ class HeadersFrame(Frame): if self.flags & self.FLAG_PADDED: s.append("padding: %d" % self.pad_length) - s.append("header_block_fragment: %s" % self.header_block_fragment.encode('hex')) + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) return "\n".join(s) diff --git a/test/h2/test_frames.py b/test/h2/test_frames.py index 30dc71e8..42a0c1cf 100644 --- a/test/h2/test_frames.py +++ b/test/h2/test_frames.py @@ -1,7 +1,7 @@ -from netlib.h2.frame import * import tutils - from nose.tools import assert_equal +from netlib.h2.frame import * + class FileAdapter(object): def __init__(self, data, is_hex=True): @@ -42,6 +42,16 @@ def test_frame_equality(): payload='foobar') assert_equal(a, b) + +def test_too_large_frames(): + f = DataFrame( + length=9000, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567, + payload='foobar' * 3000) + tutils.raises(FrameSizeError, f.to_bytes) + + def test_data_frame_to_bytes(): f = DataFrame( length=6, -- cgit v1.2.3 From f003f87197a6dffe1b51a82f7dd218121c75e206 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 4 Jun 2015 19:44:48 +0200 Subject: http2: rename module and refactor as strategy --- netlib/h2/__init__.py | 178 -------------- netlib/h2/frame.py | 623 ---------------------------------------------- netlib/http2/__init__.py | 181 ++++++++++++++ netlib/http2/frame.py | 625 +++++++++++++++++++++++++++++++++++++++++++++++ test/h2/test_frames.py | 2 +- 5 files changed, 807 insertions(+), 802 deletions(-) delete mode 100644 netlib/h2/__init__.py delete mode 100644 netlib/h2/frame.py create mode 100644 netlib/http2/__init__.py create mode 100644 netlib/http2/frame.py diff --git a/netlib/h2/__init__.py b/netlib/h2/__init__.py deleted file mode 100644 index c06f7a11..00000000 --- a/netlib/h2/__init__.py +++ /dev/null @@ -1,178 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import itertools -import logging - -from .frame import * -from .. import utils - -log = logging.getLogger(__name__) - - -class HTTP2Protocol(object): - - ERROR_CODES = utils.BiDi( - NO_ERROR=0x0, - PROTOCOL_ERROR=0x1, - INTERNAL_ERROR=0x2, - FLOW_CONTROL_ERROR=0x3, - SETTINGS_TIMEOUT=0x4, - STREAM_CLOSED=0x5, - FRAME_SIZE_ERROR=0x6, - REFUSED_STREAM=0x7, - CANCEL=0x8, - COMPRESSION_ERROR=0x9, - CONNECT_ERROR=0xa, - ENHANCE_YOUR_CALM=0xb, - INADEQUATE_SECURITY=0xc, - HTTP_1_1_REQUIRED=0xd - ) - - # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' - - ALPN_PROTO_H2 = b'h2' - - HTTP2_DEFAULT_SETTINGS = { - SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, - SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, - } - - def __init__(self): - self.http2_settings = self.HTTP2_DEFAULT_SETTINGS.copy() - self.current_stream_id = None - self.encoder = Encoder() - self.decoder = Decoder() - - def check_alpn(self): - alp = self.get_alpn_proto_negotiated() - if alp != self.ALPN_PROTO_H2: - raise NotImplementedError( - "H2Client can not handle unknown ALP: %s" % alp) - log.debug("ALP 'h2' successfully negotiated.") - - def send_connection_preface(self): - self.wfile.write(bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex'))) - self.send_frame(SettingsFrame(state=self)) - - frame = Frame.from_file(self.rfile, self) - assert isinstance(frame, SettingsFrame) - self._apply_settings(frame.settings) - self.read_frame() # read setting ACK frame - - log.debug("Connection Preface completed.") - - def next_stream_id(self): - if self.current_stream_id is None: - self.current_stream_id = 1 - else: - self.current_stream_id += 2 - return self.current_stream_id - - def send_frame(self, frame): - raw_bytes = frame.to_bytes() - self.wfile.write(raw_bytes) - self.wfile.flush() - - def read_frame(self): - frame = Frame.from_file(self.rfile, self) - if isinstance(frame, SettingsFrame): - self._apply_settings(frame.settings) - - return frame - - def _apply_settings(self, settings): - for setting, value in settings.items(): - old_value = self.http2_settings[setting] - if not old_value: - old_value = '-' - - self.http2_settings[setting] = value - log.debug("Setting changed: %s to %d (was %s)" % ( - SettingsFrame.SETTINGS.get_name(setting), - value, - str(old_value))) - - self.send_frame(SettingsFrame(state=self, flags=Frame.FLAG_ACK)) - log.debug("New settings acknowledged.") - - def _create_headers(self, headers, stream_id, end_stream=True): - # TODO: implement max frame size checks and sending in chunks - - flags = Frame.FLAG_END_HEADERS - if end_stream: - flags |= Frame.FLAG_END_STREAM - - header_block_fragment = self.encoder.encode(headers) - - bytes = HeadersFrame( - state=self, - flags=flags, - stream_id=stream_id, - header_block_fragment=header_block_fragment).to_bytes() - return [bytes] - - def _create_body(self, body, stream_id): - if body is None or len(body) == 0: - return b'' - - # TODO: implement max frame size checks and sending in chunks - # TODO: implement flow-control window - - bytes = DataFrame( - state=self, - flags=Frame.FLAG_END_STREAM, - stream_id=stream_id, - payload=body).to_bytes() - return [bytes] - - def create_request(self, method, path, headers=None, body=None): - if headers is None: - headers = [] - - headers = [ - (b':method', bytes(method)), - (b':path', bytes(path)), - (b':scheme', b'https')] + headers - - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id))) - - def read_response(self): - header_block_fragment = b'' - body = b'' - - while True: - frame = self.read_frame() - if isinstance(frame, HeadersFrame): - header_block_fragment += frame.header_block_fragment - if frame.flags | Frame.FLAG_END_HEADERS: - break - else: - log.debug("Unexpected frame received:") - log.debug(frame.human_readable()) - - while True: - frame = self.read_frame() - if isinstance(frame, DataFrame): - body += frame.payload - if frame.flags | Frame.FLAG_END_STREAM: - break - else: - log.debug("Unexpected frame received:") - log.debug(frame.human_readable()) - - headers = {} - for header, value in self.decoder.decode(header_block_fragment): - headers[header] = value - - for header, value in headers.items(): - log.debug("%s: %s" % (header, value)) - - return headers[':status'], headers, body diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py deleted file mode 100644 index 018e822f..00000000 --- a/netlib/h2/frame.py +++ /dev/null @@ -1,623 +0,0 @@ -import struct -import logging -from functools import reduce -from hpack.hpack import Encoder, Decoder - -from .. import utils - -log = logging.getLogger(__name__) - -class FrameSizeError(Exception): - pass - -class Frame(object): - - """ - Baseclass Frame - contains header - payload is defined in subclasses - """ - - FLAG_NO_FLAGS = 0x0 - FLAG_ACK = 0x1 - FLAG_END_STREAM = 0x1 - FLAG_END_HEADERS = 0x4 - FLAG_PADDED = 0x8 - FLAG_PRIORITY = 0x20 - - def __init__( - self, - state=None, - length=0, - flags=FLAG_NO_FLAGS, - stream_id=0x0): - valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) - if flags | valid_flags != valid_flags: - raise ValueError('invalid flags detected.') - - if state is None: - from . import HTTP2Protocol - - class State(object): - pass - - state = State() - state.http2_settings = HTTP2Protocol.HTTP2_DEFAULT_SETTINGS.copy() - state.encoder = Encoder() - state.decoder = Decoder() - - self.state = state - - self.length = length - self.type = self.TYPE - self.flags = flags - self.stream_id = stream_id - - @classmethod - def _check_frame_size(self, length, state): - from . import HTTP2Protocol - - if state: - settings = state.http2_settings - else: - settings = HTTP2Protocol.HTTP2_DEFAULT_SETTINGS - - max_frame_size = settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] - - if length > max_frame_size: - raise FrameSizeError( - "Frame size exceeded: %d, but only %d allowed." % ( - length, max_frame_size)) - - @classmethod - def from_file(self, fp, state=None): - """ - read a HTTP/2 frame sent by a server or client - fp is a "file like" object that could be backed by a network - stream or a disk or an in memory stream reader - """ - raw_header = fp.safe_read(9) - - fields = struct.unpack("!HBBBL", raw_header) - length = (fields[0] << 8) + fields[1] - flags = fields[3] - stream_id = fields[4] - - self._check_frame_size(length, state) - - payload = fp.safe_read(length) - return FRAMES[fields[2]].from_bytes( - state, - length, - flags, - stream_id, - payload) - - def to_bytes(self): - payload = self.payload_bytes() - self.length = len(payload) - - self._check_frame_size(self.length, self.state) - - b = struct.pack('!HB', self.length & 0xFFFF00, self.length & 0x0000FF) - b += struct.pack('!B', self.TYPE) - b += struct.pack('!B', self.flags) - b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) - b += payload - - return b - - def payload_bytes(self): # pragma: no cover - raise NotImplementedError() - - def payload_human_readable(self): # pragma: no cover - raise NotImplementedError() - - def human_readable(self): - return "\n".join([ - "============================================================", - "length: %d bytes" % self.length, - "type: %s (%#x)" % (self.__class__.__name__, self.TYPE), - "flags: %#x" % self.flags, - "stream_id: %#x" % self.stream_id, - "------------------------------------------------------------", - self.payload_human_readable(), - "============================================================", - ]) - - def __eq__(self, other): - return self.to_bytes() == other.to_bytes() - - -class DataFrame(Frame): - TYPE = 0x0 - VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b'', - pad_length=0): - super(DataFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - self.pad_length = pad_length - - @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & self.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.payload = payload[1:-f.pad_length] - else: - f.payload = payload - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('DATA frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += bytes(self.payload) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - return "payload: %s" % str(self.payload) - - -class HeadersFrame(Frame): - TYPE = 0x1 - VALID_FLAGS = [ - Frame.FLAG_END_STREAM, - Frame.FLAG_END_HEADERS, - Frame.FLAG_PADDED, - Frame.FLAG_PRIORITY] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b'', - pad_length=0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(HeadersFrame, self).__init__(state, length, flags, stream_id) - - self.header_block_fragment = header_block_fragment - self.pad_length = pad_length - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & self.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.header_block_fragment = payload[1:-f.pad_length] - else: - f.header_block_fragment = payload[0:] - - if f.flags & self.FLAG_PRIORITY: - f.stream_dependency, f.weight = struct.unpack( - '!LB', f.header_block_fragment[:5]) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - f.header_block_fragment = f.header_block_fragment[5:] - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('HEADERS frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - if self.flags & self.FLAG_PRIORITY: - b += struct.pack('!LB', - (int(self.exclusive) << 31) | self.stream_dependency, - self.weight) - - b += self.header_block_fragment - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PRIORITY: - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PriorityFrame(Frame): - TYPE = 0x2 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(PriorityFrame, self).__init__(state, length, flags, stream_id) - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) - - f.stream_dependency, f.weight = struct.unpack('!LB', payload) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PRIORITY frames MUST be associated with a stream.') - - if self.stream_dependency == 0x0: - raise ValueError('stream dependency is invalid.') - - return struct.pack( - '!LB', - (int( - self.exclusive) << 31) | self.stream_dependency, - self.weight) - - def payload_human_readable(self): - s = [] - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - return "\n".join(s) - - -class RstStreamFrame(Frame): - TYPE = 0x3 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - error_code=0x0): - super(RstStreamFrame, self).__init__(state, length, flags, stream_id) - self.error_code = error_code - - @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) - f.error_code = struct.unpack('!L', payload)[0] - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'RST_STREAM frames MUST be associated with a stream.') - - return struct.pack('!L', self.error_code) - - def payload_human_readable(self): - return "error code: %#x" % self.error_code - - -class SettingsFrame(Frame): - TYPE = 0x4 - VALID_FLAGS = [Frame.FLAG_ACK] - - SETTINGS = utils.BiDi( - SETTINGS_HEADER_TABLE_SIZE=0x1, - SETTINGS_ENABLE_PUSH=0x2, - SETTINGS_MAX_CONCURRENT_STREAMS=0x3, - SETTINGS_INITIAL_WINDOW_SIZE=0x4, - SETTINGS_MAX_FRAME_SIZE=0x5, - SETTINGS_MAX_HEADER_LIST_SIZE=0x6, - ) - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings=None): - super(SettingsFrame, self).__init__(state, length, flags, stream_id) - - if settings is None: - settings = {} - - self.settings = settings - - @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) - - for i in xrange(0, len(payload), 6): - identifier, value = struct.unpack("!HL", payload[i:i + 6]) - f.settings[identifier] = value - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'SETTINGS frames MUST NOT be associated with a stream.') - - b = b'' - for identifier, value in self.settings.items(): - b += struct.pack("!HL", identifier & 0xFF, value) - - return b - - def payload_human_readable(self): - s = [] - - for identifier, value in self.settings.items(): - s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) - - if not s: - return "settings: None" - else: - return "\n".join(s) - - -class PushPromiseFrame(Frame): - TYPE = 0x5 - VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - promised_stream=0x0, - header_block_fragment=b'', - pad_length=0): - super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) - self.pad_length = pad_length - self.promised_stream = promised_stream - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & self.FLAG_PADDED: - f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) - f.header_block_fragment = payload[5:-f.pad_length] - else: - f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) - f.header_block_fragment = payload[4:] - - f.promised_stream &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PUSH_PROMISE frames MUST be associated with a stream.') - - if self.promised_stream == 0x0: - raise ValueError('Promised stream id not valid.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) - b += bytes(self.header_block_fragment) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append("promised stream: %#x" % self.promised_stream) - s.append("header_block_fragment: %s" % str(self.header_block_fragment)) - return "\n".join(s) - - -class PingFrame(Frame): - TYPE = 0x6 - VALID_FLAGS = [Frame.FLAG_ACK] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b''): - super(PingFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - - @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) - f.payload = payload - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'PING frames MUST NOT be associated with a stream.') - - b = self.payload[0:8] - b += b'\0' * (8 - len(b)) - return b - - def payload_human_readable(self): - return "opaque data: %s" % str(self.payload) - - -class GoAwayFrame(Frame): - TYPE = 0x7 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x0, - error_code=0x0, - data=b''): - super(GoAwayFrame, self).__init__(state, length, flags, stream_id) - self.last_stream = last_stream - self.error_code = error_code - self.data = data - - @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) - - f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) - f.last_stream &= 0x7FFFFFFF - f.data = payload[8:] - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'GOAWAY frames MUST NOT be associated with a stream.') - - b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) - b += bytes(self.data) - return b - - def payload_human_readable(self): - s = [] - s.append("last stream: %#x" % self.last_stream) - s.append("error code: %d" % self.error_code) - s.append("debug data: %s" % str(self.data)) - return "\n".join(s) - - -class WindowUpdateFrame(Frame): - TYPE = 0x8 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0x0): - super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) - self.window_size_increment = window_size_increment - - @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) - - f.window_size_increment = struct.unpack("!L", payload)[0] - f.window_size_increment &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: - raise ValueError( - 'Window Szie Increment MUST be greater than 0 and less than 2^31.') - - return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) - - def payload_human_readable(self): - return "window size increment: %#x" % self.window_size_increment - - -class ContinuationFrame(Frame): - TYPE = 0x9 - VALID_FLAGS = [Frame.FLAG_END_HEADERS] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b''): - super(ContinuationFrame, self).__init__(state, length, flags, stream_id) - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) - f.header_block_fragment = payload - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'CONTINUATION frames MUST be associated with a stream.') - - return self.header_block_fragment - - def payload_human_readable(self): - return "header_block_fragment: %s" % str(self.header_block_fragment) - -_FRAME_CLASSES = [ - DataFrame, - HeadersFrame, - PriorityFrame, - RstStreamFrame, - SettingsFrame, - PushPromiseFrame, - PingFrame, - GoAwayFrame, - WindowUpdateFrame, - ContinuationFrame -] -FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} diff --git a/netlib/http2/__init__.py b/netlib/http2/__init__.py new file mode 100644 index 00000000..d6f2c51c --- /dev/null +++ b/netlib/http2/__init__.py @@ -0,0 +1,181 @@ +from __future__ import (absolute_import, print_function, division) +import itertools +import logging + +from .frame import * +from .. import utils + +log = logging.getLogger(__name__) + + +class HTTP2Protocol(object): + + ERROR_CODES = utils.BiDi( + NO_ERROR=0x0, + PROTOCOL_ERROR=0x1, + INTERNAL_ERROR=0x2, + FLOW_CONTROL_ERROR=0x3, + SETTINGS_TIMEOUT=0x4, + STREAM_CLOSED=0x5, + FRAME_SIZE_ERROR=0x6, + REFUSED_STREAM=0x7, + CANCEL=0x8, + COMPRESSION_ERROR=0x9, + CONNECT_ERROR=0xa, + ENHANCE_YOUR_CALM=0xb, + INADEQUATE_SECURITY=0xc, + HTTP_1_1_REQUIRED=0xd + ) + + # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' + + ALPN_PROTO_H2 = b'h2' + + HTTP2_DEFAULT_SETTINGS = { + SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, + SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, + } + + def __init__(self, tcp_client): + self.tcp_client = tcp_client + + self.http2_settings = self.HTTP2_DEFAULT_SETTINGS.copy() + self.current_stream_id = None + self.encoder = Encoder() + self.decoder = Decoder() + + def check_alpn(self): + alp = self.tcp_client.get_alpn_proto_negotiated() + if alp != self.ALPN_PROTO_H2: + raise NotImplementedError( + "H2Client can not handle unknown ALP: %s" % alp) + log.debug("ALP 'h2' successfully negotiated.") + + def send_connection_preface(self): + self.tcp_client.wfile.write( + bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex'))) + self.send_frame(SettingsFrame(state=self)) + + frame = Frame.from_file(self.tcp_client.rfile, self) + assert isinstance(frame, SettingsFrame) + self._apply_settings(frame.settings) + self.read_frame() # read setting ACK frame + + log.debug("Connection Preface completed.") + + def next_stream_id(self): + if self.current_stream_id is None: + self.current_stream_id = 1 + else: + self.current_stream_id += 2 + return self.current_stream_id + + def send_frame(self, frame): + raw_bytes = frame.to_bytes() + self.tcp_client.wfile.write(raw_bytes) + self.tcp_client.wfile.flush() + + def read_frame(self): + frame = Frame.from_file(self.tcp_client.rfile, self) + if isinstance(frame, SettingsFrame): + self._apply_settings(frame.settings) + + return frame + + def _apply_settings(self, settings): + for setting, value in settings.items(): + old_value = self.http2_settings[setting] + if not old_value: + old_value = '-' + + self.http2_settings[setting] = value + log.debug("Setting changed: %s to %d (was %s)" % ( + SettingsFrame.SETTINGS.get_name(setting), + value, + str(old_value))) + + self.send_frame(SettingsFrame(state=self, flags=Frame.FLAG_ACK)) + log.debug("New settings acknowledged.") + + def _create_headers(self, headers, stream_id, end_stream=True): + # TODO: implement max frame size checks and sending in chunks + + flags = Frame.FLAG_END_HEADERS + if end_stream: + flags |= Frame.FLAG_END_STREAM + + header_block_fragment = self.encoder.encode(headers) + + bytes = HeadersFrame( + state=self, + flags=flags, + stream_id=stream_id, + header_block_fragment=header_block_fragment).to_bytes() + return [bytes] + + def _create_body(self, body, stream_id): + if body is None or len(body) == 0: + return b'' + + # TODO: implement max frame size checks and sending in chunks + # TODO: implement flow-control window + + bytes = DataFrame( + state=self, + flags=Frame.FLAG_END_STREAM, + stream_id=stream_id, + payload=body).to_bytes() + return [bytes] + + def create_request(self, method, path, headers=None, body=None): + if headers is None: + headers = [] + + headers = [ + (b':method', bytes(method)), + (b':path', bytes(path)), + (b':scheme', b'https')] + headers + + stream_id = self.next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(body is None)), + self._create_body(body, stream_id))) + + def read_response(self): + header_block_fragment = b'' + body = b'' + + while True: + frame = self.read_frame() + if isinstance(frame, HeadersFrame): + header_block_fragment += frame.header_block_fragment + if frame.flags | Frame.FLAG_END_HEADERS: + break + else: + log.debug("Unexpected frame received:") + log.debug(frame.human_readable()) + + while True: + frame = self.read_frame() + if isinstance(frame, DataFrame): + body += frame.payload + if frame.flags | Frame.FLAG_END_STREAM: + break + else: + log.debug("Unexpected frame received:") + log.debug(frame.human_readable()) + + headers = {} + for header, value in self.decoder.decode(header_block_fragment): + headers[header] = value + + for header, value in headers.items(): + log.debug("%s: %s" % (header, value)) + + return headers[':status'], headers, body diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py new file mode 100644 index 00000000..1497380a --- /dev/null +++ b/netlib/http2/frame.py @@ -0,0 +1,625 @@ +import struct +import logging +from functools import reduce +from hpack.hpack import Encoder, Decoder + +from .. import utils + +log = logging.getLogger(__name__) + + +class FrameSizeError(Exception): + pass + + +class Frame(object): + + """ + Baseclass Frame + contains header + payload is defined in subclasses + """ + + FLAG_NO_FLAGS = 0x0 + FLAG_ACK = 0x1 + FLAG_END_STREAM = 0x1 + FLAG_END_HEADERS = 0x4 + FLAG_PADDED = 0x8 + FLAG_PRIORITY = 0x20 + + def __init__( + self, + state=None, + length=0, + flags=FLAG_NO_FLAGS, + stream_id=0x0): + valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) + if flags | valid_flags != valid_flags: + raise ValueError('invalid flags detected.') + + if state is None: + from . import HTTP2Protocol + + class State(object): + pass + + state = State() + state.http2_settings = HTTP2Protocol.HTTP2_DEFAULT_SETTINGS.copy() + state.encoder = Encoder() + state.decoder = Decoder() + + self.state = state + + self.length = length + self.type = self.TYPE + self.flags = flags + self.stream_id = stream_id + + @classmethod + def _check_frame_size(self, length, state): + from . import HTTP2Protocol + + if state: + settings = state.http2_settings + else: + settings = HTTP2Protocol.HTTP2_DEFAULT_SETTINGS + + max_frame_size = settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + + if length > max_frame_size: + raise FrameSizeError( + "Frame size exceeded: %d, but only %d allowed." % ( + length, max_frame_size)) + + @classmethod + def from_file(self, fp, state=None): + """ + read a HTTP/2 frame sent by a server or client + fp is a "file like" object that could be backed by a network + stream or a disk or an in memory stream reader + """ + raw_header = fp.safe_read(9) + + fields = struct.unpack("!HBBBL", raw_header) + length = (fields[0] << 8) + fields[1] + flags = fields[3] + stream_id = fields[4] + + self._check_frame_size(length, state) + + payload = fp.safe_read(length) + return FRAMES[fields[2]].from_bytes( + state, + length, + flags, + stream_id, + payload) + + def to_bytes(self): + payload = self.payload_bytes() + self.length = len(payload) + + self._check_frame_size(self.length, self.state) + + b = struct.pack('!HB', self.length & 0xFFFF00, self.length & 0x0000FF) + b += struct.pack('!B', self.TYPE) + b += struct.pack('!B', self.flags) + b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) + b += payload + + return b + + def payload_bytes(self): # pragma: no cover + raise NotImplementedError() + + def payload_human_readable(self): # pragma: no cover + raise NotImplementedError() + + def human_readable(self): + return "\n".join([ + "============================================================", + "length: %d bytes" % self.length, + "type: %s (%#x)" % (self.__class__.__name__, self.TYPE), + "flags: %#x" % self.flags, + "stream_id: %#x" % self.stream_id, + "------------------------------------------------------------", + self.payload_human_readable(), + "============================================================", + ]) + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() + + +class DataFrame(Frame): + TYPE = 0x0 + VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b'', + pad_length=0): + super(DataFrame, self).__init__(state, length, flags, stream_id) + self.payload = payload + self.pad_length = pad_length + + @classmethod + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & self.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.payload = payload[1:-f.pad_length] + else: + f.payload = payload + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('DATA frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += bytes(self.payload) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + return "payload: %s" % str(self.payload) + + +class HeadersFrame(Frame): + TYPE = 0x1 + VALID_FLAGS = [ + Frame.FLAG_END_STREAM, + Frame.FLAG_END_HEADERS, + Frame.FLAG_PADDED, + Frame.FLAG_PRIORITY] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b'', + pad_length=0, + exclusive=False, + stream_dependency=0x0, + weight=0): + super(HeadersFrame, self).__init__(state, length, flags, stream_id) + + self.header_block_fragment = header_block_fragment + self.pad_length = pad_length + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & self.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.header_block_fragment = payload[1:-f.pad_length] + else: + f.header_block_fragment = payload[0:] + + if f.flags & self.FLAG_PRIORITY: + f.stream_dependency, f.weight = struct.unpack( + '!LB', f.header_block_fragment[:5]) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + f.header_block_fragment = f.header_block_fragment[5:] + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('HEADERS frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + if self.flags & self.FLAG_PRIORITY: + b += struct.pack('!LB', + (int(self.exclusive) << 31) | self.stream_dependency, + self.weight) + + b += self.header_block_fragment + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PRIORITY: + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + + return "\n".join(s) + + +class PriorityFrame(Frame): + TYPE = 0x2 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + exclusive=False, + stream_dependency=0x0, + weight=0): + super(PriorityFrame, self).__init__(state, length, flags, stream_id) + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) + + f.stream_dependency, f.weight = struct.unpack('!LB', payload) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'PRIORITY frames MUST be associated with a stream.') + + if self.stream_dependency == 0x0: + raise ValueError('stream dependency is invalid.') + + return struct.pack( + '!LB', + (int( + self.exclusive) << 31) | self.stream_dependency, + self.weight) + + def payload_human_readable(self): + s = [] + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + return "\n".join(s) + + +class RstStreamFrame(Frame): + TYPE = 0x3 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + error_code=0x0): + super(RstStreamFrame, self).__init__(state, length, flags, stream_id) + self.error_code = error_code + + @classmethod + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) + f.error_code = struct.unpack('!L', payload)[0] + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'RST_STREAM frames MUST be associated with a stream.') + + return struct.pack('!L', self.error_code) + + def payload_human_readable(self): + return "error code: %#x" % self.error_code + + +class SettingsFrame(Frame): + TYPE = 0x4 + VALID_FLAGS = [Frame.FLAG_ACK] + + SETTINGS = utils.BiDi( + SETTINGS_HEADER_TABLE_SIZE=0x1, + SETTINGS_ENABLE_PUSH=0x2, + SETTINGS_MAX_CONCURRENT_STREAMS=0x3, + SETTINGS_INITIAL_WINDOW_SIZE=0x4, + SETTINGS_MAX_FRAME_SIZE=0x5, + SETTINGS_MAX_HEADER_LIST_SIZE=0x6, + ) + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings=None): + super(SettingsFrame, self).__init__(state, length, flags, stream_id) + + if settings is None: + settings = {} + + self.settings = settings + + @classmethod + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) + + for i in xrange(0, len(payload), 6): + identifier, value = struct.unpack("!HL", payload[i:i + 6]) + f.settings[identifier] = value + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'SETTINGS frames MUST NOT be associated with a stream.') + + b = b'' + for identifier, value in self.settings.items(): + b += struct.pack("!HL", identifier & 0xFF, value) + + return b + + def payload_human_readable(self): + s = [] + + for identifier, value in self.settings.items(): + s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) + + if not s: + return "settings: None" + else: + return "\n".join(s) + + +class PushPromiseFrame(Frame): + TYPE = 0x5 + VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + promised_stream=0x0, + header_block_fragment=b'', + pad_length=0): + super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) + self.pad_length = pad_length + self.promised_stream = promised_stream + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & self.FLAG_PADDED: + f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) + f.header_block_fragment = payload[5:-f.pad_length] + else: + f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) + f.header_block_fragment = payload[4:] + + f.promised_stream &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'PUSH_PROMISE frames MUST be associated with a stream.') + + if self.promised_stream == 0x0: + raise ValueError('Promised stream id not valid.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) + b += bytes(self.header_block_fragment) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append("promised stream: %#x" % self.promised_stream) + s.append("header_block_fragment: %s" % str(self.header_block_fragment)) + return "\n".join(s) + + +class PingFrame(Frame): + TYPE = 0x6 + VALID_FLAGS = [Frame.FLAG_ACK] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b''): + super(PingFrame, self).__init__(state, length, flags, stream_id) + self.payload = payload + + @classmethod + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) + f.payload = payload + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'PING frames MUST NOT be associated with a stream.') + + b = self.payload[0:8] + b += b'\0' * (8 - len(b)) + return b + + def payload_human_readable(self): + return "opaque data: %s" % str(self.payload) + + +class GoAwayFrame(Frame): + TYPE = 0x7 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x0, + error_code=0x0, + data=b''): + super(GoAwayFrame, self).__init__(state, length, flags, stream_id) + self.last_stream = last_stream + self.error_code = error_code + self.data = data + + @classmethod + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) + + f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) + f.last_stream &= 0x7FFFFFFF + f.data = payload[8:] + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'GOAWAY frames MUST NOT be associated with a stream.') + + b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) + b += bytes(self.data) + return b + + def payload_human_readable(self): + s = [] + s.append("last stream: %#x" % self.last_stream) + s.append("error code: %d" % self.error_code) + s.append("debug data: %s" % str(self.data)) + return "\n".join(s) + + +class WindowUpdateFrame(Frame): + TYPE = 0x8 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + window_size_increment=0x0): + super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) + self.window_size_increment = window_size_increment + + @classmethod + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) + + f.window_size_increment = struct.unpack("!L", payload)[0] + f.window_size_increment &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: + raise ValueError( + 'Window Szie Increment MUST be greater than 0 and less than 2^31.') + + return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) + + def payload_human_readable(self): + return "window size increment: %#x" % self.window_size_increment + + +class ContinuationFrame(Frame): + TYPE = 0x9 + VALID_FLAGS = [Frame.FLAG_END_HEADERS] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b''): + super(ContinuationFrame, self).__init__(state, length, flags, stream_id) + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(self, state, length, flags, stream_id, payload): + f = self(state=state, length=length, flags=flags, stream_id=stream_id) + f.header_block_fragment = payload + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'CONTINUATION frames MUST be associated with a stream.') + + return self.header_block_fragment + + def payload_human_readable(self): + return "header_block_fragment: %s" % str(self.header_block_fragment) + +_FRAME_CLASSES = [ + DataFrame, + HeadersFrame, + PriorityFrame, + RstStreamFrame, + SettingsFrame, + PushPromiseFrame, + PingFrame, + GoAwayFrame, + WindowUpdateFrame, + ContinuationFrame +] +FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} diff --git a/test/h2/test_frames.py b/test/h2/test_frames.py index 42a0c1cf..d8a4febc 100644 --- a/test/h2/test_frames.py +++ b/test/h2/test_frames.py @@ -1,6 +1,6 @@ import tutils from nose.tools import assert_equal -from netlib.h2.frame import * +from netlib.http2.frame import * class FileAdapter(object): -- cgit v1.2.3 From fdc908cb9811628435ef02e3168c4d5931c6a3c5 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 5 Jun 2015 13:28:09 +0200 Subject: http2: add protocol tests --- netlib/http2/__init__.py | 25 +- netlib/test.py | 2 +- test/__init__.py | 0 test/h2/__init__.py | 0 test/h2/test_frames.py | 714 -------------------------------------- test/http2/test_frames.py | 714 ++++++++++++++++++++++++++++++++++++++ test/http2/test_http2_protocol.py | 216 ++++++++++++ 7 files changed, 944 insertions(+), 727 deletions(-) create mode 100644 test/__init__.py delete mode 100644 test/h2/__init__.py delete mode 100644 test/h2/test_frames.py create mode 100644 test/http2/test_frames.py create mode 100644 test/http2/test_http2_protocol.py diff --git a/netlib/http2/__init__.py b/netlib/http2/__init__.py index d6f2c51c..2803cccb 100644 --- a/netlib/http2/__init__.py +++ b/netlib/http2/__init__.py @@ -30,7 +30,7 @@ class HTTP2Protocol(object): # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' - ALPN_PROTO_H2 = b'h2' + ALPN_PROTO_H2 = 'h2' HTTP2_DEFAULT_SETTINGS = { SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, @@ -53,18 +53,25 @@ class HTTP2Protocol(object): alp = self.tcp_client.get_alpn_proto_negotiated() if alp != self.ALPN_PROTO_H2: raise NotImplementedError( - "H2Client can not handle unknown ALP: %s" % alp) + "HTTP2Protocol can not handle unknown ALP: %s" % alp) log.debug("ALP 'h2' successfully negotiated.") + return True - def send_connection_preface(self): + def perform_connection_preface(self): self.tcp_client.wfile.write( bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex'))) self.send_frame(SettingsFrame(state=self)) + # read server settings frame frame = Frame.from_file(self.tcp_client.rfile, self) assert isinstance(frame, SettingsFrame) self._apply_settings(frame.settings) - self.read_frame() # read setting ACK frame + + # read setting ACK frame + settings_ack_frame = self.read_frame() + assert isinstance(settings_ack_frame, SettingsFrame) + assert settings_ack_frame.flags & Frame.FLAG_ACK + assert len(settings_ack_frame.settings) == 0 log.debug("Connection Preface completed.") @@ -94,9 +101,9 @@ class HTTP2Protocol(object): old_value = '-' self.http2_settings[setting] = value - log.debug("Setting changed: %s to %d (was %s)" % ( + log.debug("Setting changed: %s to %s (was %s)" % ( SettingsFrame.SETTINGS.get_name(setting), - value, + str(value), str(old_value))) self.send_frame(SettingsFrame(state=self, flags=Frame.FLAG_ACK)) @@ -157,9 +164,6 @@ class HTTP2Protocol(object): header_block_fragment += frame.header_block_fragment if frame.flags | Frame.FLAG_END_HEADERS: break - else: - log.debug("Unexpected frame received:") - log.debug(frame.human_readable()) while True: frame = self.read_frame() @@ -167,9 +171,6 @@ class HTTP2Protocol(object): body += frame.payload if frame.flags | Frame.FLAG_END_STREAM: break - else: - log.debug("Unexpected frame received:") - log.debug(frame.human_readable()) headers = {} for header, value in self.decoder.decode(header_block_fragment): diff --git a/netlib/test.py b/netlib/test.py index ee8c6685..4b0b6bd2 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -4,7 +4,7 @@ import Queue import cStringIO import OpenSSL from . import tcp, certutils -import tutils +from test import tutils class ServerThread(threading.Thread): diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/h2/__init__.py b/test/h2/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/test/h2/test_frames.py b/test/h2/test_frames.py deleted file mode 100644 index d8a4febc..00000000 --- a/test/h2/test_frames.py +++ /dev/null @@ -1,714 +0,0 @@ -import tutils -from nose.tools import assert_equal -from netlib.http2.frame import * - - -class FileAdapter(object): - def __init__(self, data, is_hex=True): - self.position = 0 - if is_hex: - self.data = data.decode('hex') - else: - self.data = data - - def safe_read(self, length): - if self.position + length > len(self.data): - raise ValueError("not enough bytes to read") - - value = self.data[self.position:self.position + length] - self.position += length - return value - - -def test_invalid_flags(): - tutils.raises( - ValueError, - DataFrame, - flags=ContinuationFrame.FLAG_END_HEADERS, - stream_id=0x1234567, - payload='foobar') - - -def test_frame_equality(): - a = DataFrame( - length=6, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567, - payload='foobar') - b = DataFrame( - length=6, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567, - payload='foobar') - assert_equal(a, b) - - -def test_too_large_frames(): - f = DataFrame( - length=9000, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567, - payload='foobar' * 3000) - tutils.raises(FrameSizeError, f.to_bytes) - - -def test_data_frame_to_bytes(): - f = DataFrame( - length=6, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567, - payload='foobar') - assert_equal(f.to_bytes().encode('hex'), '000006000101234567666f6f626172') - - f = DataFrame( - length=11, - flags=(Frame.FLAG_END_STREAM | Frame.FLAG_PADDED), - stream_id=0x1234567, - payload='foobar', - pad_length=3) - assert_equal( - f.to_bytes().encode('hex'), - '00000a00090123456703666f6f626172000000') - - f = DataFrame( - length=6, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload='foobar') - tutils.raises(ValueError, f.to_bytes) - - -def test_data_frame_from_bytes(): - f = Frame.from_file(FileAdapter('000006000101234567666f6f626172')) - assert isinstance(f, DataFrame) - assert_equal(f.length, 6) - assert_equal(f.TYPE, DataFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_END_STREAM) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.payload, 'foobar') - - f = Frame.from_file(FileAdapter('00000a00090123456703666f6f626172000000')) - assert isinstance(f, DataFrame) - assert_equal(f.length, 10) - assert_equal(f.TYPE, DataFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_END_STREAM | Frame.FLAG_PADDED) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.payload, 'foobar') - - -def test_data_frame_human_readable(): - f = DataFrame( - length=11, - flags=(Frame.FLAG_END_STREAM | Frame.FLAG_PADDED), - stream_id=0x1234567, - payload='foobar', - pad_length=3) - assert f.human_readable() - - -def test_headers_frame_to_bytes(): - f = HeadersFrame( - length=6, - flags=(Frame.FLAG_NO_FLAGS), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex')) - assert_equal(f.to_bytes().encode('hex'), '000007010001234567668594e75e31d9') - - f = HeadersFrame( - length=10, - flags=(HeadersFrame.FLAG_PADDED), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - pad_length=3) - assert_equal( - f.to_bytes().encode('hex'), - '00000b01080123456703668594e75e31d9000000') - - f = HeadersFrame( - length=10, - flags=(HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - exclusive=True, - stream_dependency=0x7654321, - weight=42) - assert_equal( - f.to_bytes().encode('hex'), - '00000c012001234567876543212a668594e75e31d9') - - f = HeadersFrame( - length=14, - flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - pad_length=3, - exclusive=True, - stream_dependency=0x7654321, - weight=42) - assert_equal( - f.to_bytes().encode('hex'), - '00001001280123456703876543212a668594e75e31d9000000') - - f = HeadersFrame( - length=14, - flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - pad_length=3, - exclusive=False, - stream_dependency=0x7654321, - weight=42) - assert_equal( - f.to_bytes().encode('hex'), - '00001001280123456703076543212a668594e75e31d9000000') - - f = HeadersFrame( - length=6, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment='668594e75e31d9'.decode('hex')) - tutils.raises(ValueError, f.to_bytes) - - -def test_headers_frame_from_bytes(): - f = Frame.from_file(FileAdapter( - '000007010001234567668594e75e31d9')) - assert isinstance(f, HeadersFrame) - assert_equal(f.length, 7) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - - f = Frame.from_file(FileAdapter( - '00000b01080123456703668594e75e31d9000000')) - assert isinstance(f, HeadersFrame) - assert_equal(f.length, 11) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, HeadersFrame.FLAG_PADDED) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - - f = Frame.from_file(FileAdapter( - '00000c012001234567876543212a668594e75e31d9')) - assert isinstance(f, HeadersFrame) - assert_equal(f.length, 12) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, HeadersFrame.FLAG_PRIORITY) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - assert_equal(f.exclusive, True) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 42) - - f = Frame.from_file(FileAdapter( - '00001001280123456703876543212a668594e75e31d9000000')) - assert isinstance(f, HeadersFrame) - assert_equal(f.length, 16) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - assert_equal(f.exclusive, True) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 42) - - f = Frame.from_file(FileAdapter( - '00001001280123456703076543212a668594e75e31d9000000')) - assert isinstance(f, HeadersFrame) - assert_equal(f.length, 16) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - assert_equal(f.exclusive, False) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 42) - - -def test_headers_frame_human_readable(): - f = HeadersFrame( - length=7, - flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment=b'', - pad_length=3, - exclusive=False, - stream_dependency=0x7654321, - weight=42) - assert f.human_readable() - - f = HeadersFrame( - length=14, - flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - pad_length=3, - exclusive=False, - stream_dependency=0x7654321, - weight=42) - assert f.human_readable() - - -def test_priority_frame_to_bytes(): - f = PriorityFrame( - length=5, - flags=(Frame.FLAG_NO_FLAGS), - stream_id=0x1234567, - exclusive=True, - stream_dependency=0x7654321, - weight=42) - assert_equal(f.to_bytes().encode('hex'), '000005020001234567876543212a') - - f = PriorityFrame( - length=5, - flags=(Frame.FLAG_NO_FLAGS), - stream_id=0x1234567, - exclusive=False, - stream_dependency=0x7654321, - weight=21) - assert_equal(f.to_bytes().encode('hex'), '0000050200012345670765432115') - - f = PriorityFrame( - length=5, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - stream_dependency=0x1234567) - tutils.raises(ValueError, f.to_bytes) - - f = PriorityFrame( - length=5, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - stream_dependency=0x0) - tutils.raises(ValueError, f.to_bytes) - - -def test_priority_frame_from_bytes(): - f = Frame.from_file(FileAdapter('000005020001234567876543212a')) - assert isinstance(f, PriorityFrame) - assert_equal(f.length, 5) - assert_equal(f.TYPE, PriorityFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.exclusive, True) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 42) - - f = Frame.from_file(FileAdapter('0000050200012345670765432115')) - assert isinstance(f, PriorityFrame) - assert_equal(f.length, 5) - assert_equal(f.TYPE, PriorityFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.exclusive, False) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 21) - - -def test_priority_frame_human_readable(): - f = PriorityFrame( - length=5, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - exclusive=False, - stream_dependency=0x7654321, - weight=21) - assert f.human_readable() - - -def test_rst_stream_frame_to_bytes(): - f = RstStreamFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - error_code=0x7654321) - assert_equal(f.to_bytes().encode('hex'), '00000403000123456707654321') - - f = RstStreamFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0) - tutils.raises(ValueError, f.to_bytes) - - -def test_rst_stream_frame_from_bytes(): - f = Frame.from_file(FileAdapter('00000403000123456707654321')) - assert isinstance(f, RstStreamFrame) - assert_equal(f.length, 4) - assert_equal(f.TYPE, RstStreamFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.error_code, 0x07654321) - - -def test_rst_stream_frame_human_readable(): - f = RstStreamFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - error_code=0x7654321) - assert f.human_readable() - - -def test_settings_frame_to_bytes(): - f = SettingsFrame( - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0) - assert_equal(f.to_bytes().encode('hex'), '000000040000000000') - - f = SettingsFrame( - length=0, - flags=SettingsFrame.FLAG_ACK, - stream_id=0x0) - assert_equal(f.to_bytes().encode('hex'), '000000040100000000') - - f = SettingsFrame( - length=6, - flags=SettingsFrame.FLAG_ACK, - stream_id=0x0, - settings={ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1}) - assert_equal(f.to_bytes().encode('hex'), '000006040100000000000200000001') - - f = SettingsFrame( - length=12, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings={ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) - assert_equal( - f.to_bytes().encode('hex'), - '00000c040000000000000200000001000312345678') - - f = SettingsFrame( - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567) - tutils.raises(ValueError, f.to_bytes) - - -def test_settings_frame_from_bytes(): - f = Frame.from_file(FileAdapter('000000040000000000')) - assert isinstance(f, SettingsFrame) - assert_equal(f.length, 0) - assert_equal(f.TYPE, SettingsFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - - f = Frame.from_file(FileAdapter('000000040100000000')) - assert isinstance(f, SettingsFrame) - assert_equal(f.length, 0) - assert_equal(f.TYPE, SettingsFrame.TYPE) - assert_equal(f.flags, SettingsFrame.FLAG_ACK) - assert_equal(f.stream_id, 0x0) - - f = Frame.from_file(FileAdapter('000006040100000000000200000001')) - assert isinstance(f, SettingsFrame) - assert_equal(f.length, 6) - assert_equal(f.TYPE, SettingsFrame.TYPE) - assert_equal(f.flags, SettingsFrame.FLAG_ACK, 0x0) - assert_equal(f.stream_id, 0x0) - assert_equal(len(f.settings), 1) - assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) - - f = Frame.from_file(FileAdapter( - '00000c040000000000000200000001000312345678')) - assert isinstance(f, SettingsFrame) - assert_equal(f.length, 12) - assert_equal(f.TYPE, SettingsFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(len(f.settings), 2) - assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) - assert_equal( - f.settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS], - 0x12345678) - - -def test_settings_frame_human_readable(): - f = SettingsFrame( - length=12, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings={}) - assert f.human_readable() - - f = SettingsFrame( - length=12, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings={ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) - assert f.human_readable() - - -def test_push_promise_frame_to_bytes(): - f = PushPromiseFrame( - length=10, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - promised_stream=0x7654321, - header_block_fragment='foobar') - assert_equal( - f.to_bytes().encode('hex'), - '00000a05000123456707654321666f6f626172') - - f = PushPromiseFrame( - length=14, - flags=HeadersFrame.FLAG_PADDED, - stream_id=0x1234567, - promised_stream=0x7654321, - header_block_fragment='foobar', - pad_length=3) - assert_equal( - f.to_bytes().encode('hex'), - '00000e0508012345670307654321666f6f626172000000') - - f = PushPromiseFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - promised_stream=0x1234567) - tutils.raises(ValueError, f.to_bytes) - - f = PushPromiseFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - promised_stream=0x0) - tutils.raises(ValueError, f.to_bytes) - - -def test_push_promise_frame_from_bytes(): - f = Frame.from_file(FileAdapter('00000a05000123456707654321666f6f626172')) - assert isinstance(f, PushPromiseFrame) - assert_equal(f.length, 10) - assert_equal(f.TYPE, PushPromiseFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, 'foobar') - - f = Frame.from_file(FileAdapter( - '00000e0508012345670307654321666f6f626172000000')) - assert isinstance(f, PushPromiseFrame) - assert_equal(f.length, 14) - assert_equal(f.TYPE, PushPromiseFrame.TYPE) - assert_equal(f.flags, PushPromiseFrame.FLAG_PADDED) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, 'foobar') - - -def test_push_promise_frame_human_readable(): - f = PushPromiseFrame( - length=14, - flags=HeadersFrame.FLAG_PADDED, - stream_id=0x1234567, - promised_stream=0x7654321, - header_block_fragment='foobar', - pad_length=3) - assert f.human_readable() - - -def test_ping_frame_to_bytes(): - f = PingFrame( - length=8, - flags=PingFrame.FLAG_ACK, - stream_id=0x0, - payload=b'foobar') - assert_equal( - f.to_bytes().encode('hex'), - '000008060100000000666f6f6261720000') - - f = PingFrame( - length=8, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b'foobardeadbeef') - assert_equal( - f.to_bytes().encode('hex'), - '000008060000000000666f6f6261726465') - - f = PingFrame( - length=8, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567) - tutils.raises(ValueError, f.to_bytes) - - -def test_ping_frame_from_bytes(): - f = Frame.from_file(FileAdapter('000008060100000000666f6f6261720000')) - assert isinstance(f, PingFrame) - assert_equal(f.length, 8) - assert_equal(f.TYPE, PingFrame.TYPE) - assert_equal(f.flags, PingFrame.FLAG_ACK) - assert_equal(f.stream_id, 0x0) - assert_equal(f.payload, b'foobar\0\0') - - f = Frame.from_file(FileAdapter('000008060000000000666f6f6261726465')) - assert isinstance(f, PingFrame) - assert_equal(f.length, 8) - assert_equal(f.TYPE, PingFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(f.payload, b'foobarde') - - -def test_ping_frame_human_readable(): - f = PingFrame( - length=8, - flags=PingFrame.FLAG_ACK, - stream_id=0x0, - payload=b'foobar') - assert f.human_readable() - - -def test_goaway_frame_to_bytes(): - f = GoAwayFrame( - length=8, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x1234567, - error_code=0x87654321, - data=b'') - assert_equal( - f.to_bytes().encode('hex'), - '0000080700000000000123456787654321') - - f = GoAwayFrame( - length=14, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x1234567, - error_code=0x87654321, - data=b'foobar') - assert_equal( - f.to_bytes().encode('hex'), - '00000e0700000000000123456787654321666f6f626172') - - f = GoAwayFrame( - length=8, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - last_stream=0x1234567, - error_code=0x87654321) - tutils.raises(ValueError, f.to_bytes) - - -def test_goaway_frame_from_bytes(): - f = Frame.from_file(FileAdapter( - '0000080700000000000123456787654321')) - assert isinstance(f, GoAwayFrame) - assert_equal(f.length, 8) - assert_equal(f.TYPE, GoAwayFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(f.last_stream, 0x1234567) - assert_equal(f.error_code, 0x87654321) - assert_equal(f.data, b'') - - f = Frame.from_file(FileAdapter( - '00000e0700000000000123456787654321666f6f626172')) - assert isinstance(f, GoAwayFrame) - assert_equal(f.length, 14) - assert_equal(f.TYPE, GoAwayFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(f.last_stream, 0x1234567) - assert_equal(f.error_code, 0x87654321) - assert_equal(f.data, b'foobar') - - -def test_go_away_frame_human_readable(): - f = GoAwayFrame( - length=14, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x1234567, - error_code=0x87654321, - data=b'foobar') - assert f.human_readable() - - -def test_window_update_frame_to_bytes(): - f = WindowUpdateFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0x1234567) - assert_equal(f.to_bytes().encode('hex'), '00000408000000000001234567') - - f = WindowUpdateFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - window_size_increment=0x7654321) - assert_equal(f.to_bytes().encode('hex'), '00000408000123456707654321') - - f = WindowUpdateFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0xdeadbeef) - tutils.raises(ValueError, f.to_bytes) - - f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0) - tutils.raises(ValueError, f.to_bytes) - - -def test_window_update_frame_from_bytes(): - f = Frame.from_file(FileAdapter('00000408000000000001234567')) - assert isinstance(f, WindowUpdateFrame) - assert_equal(f.length, 4) - assert_equal(f.TYPE, WindowUpdateFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(f.window_size_increment, 0x1234567) - - -def test_window_update_frame_human_readable(): - f = WindowUpdateFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - window_size_increment=0x7654321) - assert f.human_readable() - - -def test_continuation_frame_to_bytes(): - f = ContinuationFrame( - length=6, - flags=ContinuationFrame.FLAG_END_HEADERS, - stream_id=0x1234567, - header_block_fragment='foobar') - assert_equal(f.to_bytes().encode('hex'), '000006090401234567666f6f626172') - - f = ContinuationFrame( - length=6, - flags=ContinuationFrame.FLAG_END_HEADERS, - stream_id=0x0, - header_block_fragment='foobar') - tutils.raises(ValueError, f.to_bytes) - - -def test_continuation_frame_from_bytes(): - f = Frame.from_file(FileAdapter('000006090401234567666f6f626172')) - assert isinstance(f, ContinuationFrame) - assert_equal(f.length, 6) - assert_equal(f.TYPE, ContinuationFrame.TYPE) - assert_equal(f.flags, ContinuationFrame.FLAG_END_HEADERS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, 'foobar') - - -def test_continuation_frame_human_readable(): - f = ContinuationFrame( - length=6, - flags=ContinuationFrame.FLAG_END_HEADERS, - stream_id=0x1234567, - header_block_fragment='foobar') - assert f.human_readable() diff --git a/test/http2/test_frames.py b/test/http2/test_frames.py new file mode 100644 index 00000000..d8f00dec --- /dev/null +++ b/test/http2/test_frames.py @@ -0,0 +1,714 @@ +from test import tutils +from nose.tools import assert_equal +from netlib.http2.frame import * + + +class FileAdapter(object): + def __init__(self, data, is_hex=True): + self.position = 0 + if is_hex: + self.data = data.decode('hex') + else: + self.data = data + + def safe_read(self, length): + if self.position + length > len(self.data): + raise ValueError("not enough bytes to read") + + value = self.data[self.position:self.position + length] + self.position += length + return value + + +def test_invalid_flags(): + tutils.raises( + ValueError, + DataFrame, + flags=ContinuationFrame.FLAG_END_HEADERS, + stream_id=0x1234567, + payload='foobar') + + +def test_frame_equality(): + a = DataFrame( + length=6, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567, + payload='foobar') + b = DataFrame( + length=6, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567, + payload='foobar') + assert_equal(a, b) + + +def test_too_large_frames(): + f = DataFrame( + length=9000, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567, + payload='foobar' * 3000) + tutils.raises(FrameSizeError, f.to_bytes) + + +def test_data_frame_to_bytes(): + f = DataFrame( + length=6, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567, + payload='foobar') + assert_equal(f.to_bytes().encode('hex'), '000006000101234567666f6f626172') + + f = DataFrame( + length=11, + flags=(Frame.FLAG_END_STREAM | Frame.FLAG_PADDED), + stream_id=0x1234567, + payload='foobar', + pad_length=3) + assert_equal( + f.to_bytes().encode('hex'), + '00000a00090123456703666f6f626172000000') + + f = DataFrame( + length=6, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload='foobar') + tutils.raises(ValueError, f.to_bytes) + + +def test_data_frame_from_bytes(): + f = Frame.from_file(FileAdapter('000006000101234567666f6f626172')) + assert isinstance(f, DataFrame) + assert_equal(f.length, 6) + assert_equal(f.TYPE, DataFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_END_STREAM) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.payload, 'foobar') + + f = Frame.from_file(FileAdapter('00000a00090123456703666f6f626172000000')) + assert isinstance(f, DataFrame) + assert_equal(f.length, 10) + assert_equal(f.TYPE, DataFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_END_STREAM | Frame.FLAG_PADDED) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.payload, 'foobar') + + +def test_data_frame_human_readable(): + f = DataFrame( + length=11, + flags=(Frame.FLAG_END_STREAM | Frame.FLAG_PADDED), + stream_id=0x1234567, + payload='foobar', + pad_length=3) + assert f.human_readable() + + +def test_headers_frame_to_bytes(): + f = HeadersFrame( + length=6, + flags=(Frame.FLAG_NO_FLAGS), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex')) + assert_equal(f.to_bytes().encode('hex'), '000007010001234567668594e75e31d9') + + f = HeadersFrame( + length=10, + flags=(HeadersFrame.FLAG_PADDED), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex'), + pad_length=3) + assert_equal( + f.to_bytes().encode('hex'), + '00000b01080123456703668594e75e31d9000000') + + f = HeadersFrame( + length=10, + flags=(HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex'), + exclusive=True, + stream_dependency=0x7654321, + weight=42) + assert_equal( + f.to_bytes().encode('hex'), + '00000c012001234567876543212a668594e75e31d9') + + f = HeadersFrame( + length=14, + flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex'), + pad_length=3, + exclusive=True, + stream_dependency=0x7654321, + weight=42) + assert_equal( + f.to_bytes().encode('hex'), + '00001001280123456703876543212a668594e75e31d9000000') + + f = HeadersFrame( + length=14, + flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex'), + pad_length=3, + exclusive=False, + stream_dependency=0x7654321, + weight=42) + assert_equal( + f.to_bytes().encode('hex'), + '00001001280123456703076543212a668594e75e31d9000000') + + f = HeadersFrame( + length=6, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment='668594e75e31d9'.decode('hex')) + tutils.raises(ValueError, f.to_bytes) + + +def test_headers_frame_from_bytes(): + f = Frame.from_file(FileAdapter( + '000007010001234567668594e75e31d9')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 7) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + + f = Frame.from_file(FileAdapter( + '00000b01080123456703668594e75e31d9000000')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 11) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, HeadersFrame.FLAG_PADDED) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + + f = Frame.from_file(FileAdapter( + '00000c012001234567876543212a668594e75e31d9')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 12) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, HeadersFrame.FLAG_PRIORITY) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + assert_equal(f.exclusive, True) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 42) + + f = Frame.from_file(FileAdapter( + '00001001280123456703876543212a668594e75e31d9000000')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 16) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + assert_equal(f.exclusive, True) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 42) + + f = Frame.from_file(FileAdapter( + '00001001280123456703076543212a668594e75e31d9000000')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 16) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + assert_equal(f.exclusive, False) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 42) + + +def test_headers_frame_human_readable(): + f = HeadersFrame( + length=7, + flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, + header_block_fragment=b'', + pad_length=3, + exclusive=False, + stream_dependency=0x7654321, + weight=42) + assert f.human_readable() + + f = HeadersFrame( + length=14, + flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex'), + pad_length=3, + exclusive=False, + stream_dependency=0x7654321, + weight=42) + assert f.human_readable() + + +def test_priority_frame_to_bytes(): + f = PriorityFrame( + length=5, + flags=(Frame.FLAG_NO_FLAGS), + stream_id=0x1234567, + exclusive=True, + stream_dependency=0x7654321, + weight=42) + assert_equal(f.to_bytes().encode('hex'), '000005020001234567876543212a') + + f = PriorityFrame( + length=5, + flags=(Frame.FLAG_NO_FLAGS), + stream_id=0x1234567, + exclusive=False, + stream_dependency=0x7654321, + weight=21) + assert_equal(f.to_bytes().encode('hex'), '0000050200012345670765432115') + + f = PriorityFrame( + length=5, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + stream_dependency=0x1234567) + tutils.raises(ValueError, f.to_bytes) + + f = PriorityFrame( + length=5, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + stream_dependency=0x0) + tutils.raises(ValueError, f.to_bytes) + + +def test_priority_frame_from_bytes(): + f = Frame.from_file(FileAdapter('000005020001234567876543212a')) + assert isinstance(f, PriorityFrame) + assert_equal(f.length, 5) + assert_equal(f.TYPE, PriorityFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.exclusive, True) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 42) + + f = Frame.from_file(FileAdapter('0000050200012345670765432115')) + assert isinstance(f, PriorityFrame) + assert_equal(f.length, 5) + assert_equal(f.TYPE, PriorityFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.exclusive, False) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 21) + + +def test_priority_frame_human_readable(): + f = PriorityFrame( + length=5, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + exclusive=False, + stream_dependency=0x7654321, + weight=21) + assert f.human_readable() + + +def test_rst_stream_frame_to_bytes(): + f = RstStreamFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + error_code=0x7654321) + assert_equal(f.to_bytes().encode('hex'), '00000403000123456707654321') + + f = RstStreamFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0) + tutils.raises(ValueError, f.to_bytes) + + +def test_rst_stream_frame_from_bytes(): + f = Frame.from_file(FileAdapter('00000403000123456707654321')) + assert isinstance(f, RstStreamFrame) + assert_equal(f.length, 4) + assert_equal(f.TYPE, RstStreamFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.error_code, 0x07654321) + + +def test_rst_stream_frame_human_readable(): + f = RstStreamFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + error_code=0x7654321) + assert f.human_readable() + + +def test_settings_frame_to_bytes(): + f = SettingsFrame( + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0) + assert_equal(f.to_bytes().encode('hex'), '000000040000000000') + + f = SettingsFrame( + length=0, + flags=SettingsFrame.FLAG_ACK, + stream_id=0x0) + assert_equal(f.to_bytes().encode('hex'), '000000040100000000') + + f = SettingsFrame( + length=6, + flags=SettingsFrame.FLAG_ACK, + stream_id=0x0, + settings={ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1}) + assert_equal(f.to_bytes().encode('hex'), '000006040100000000000200000001') + + f = SettingsFrame( + length=12, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings={ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) + assert_equal( + f.to_bytes().encode('hex'), + '00000c040000000000000200000001000312345678') + + f = SettingsFrame( + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567) + tutils.raises(ValueError, f.to_bytes) + + +def test_settings_frame_from_bytes(): + f = Frame.from_file(FileAdapter('000000040000000000')) + assert isinstance(f, SettingsFrame) + assert_equal(f.length, 0) + assert_equal(f.TYPE, SettingsFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + + f = Frame.from_file(FileAdapter('000000040100000000')) + assert isinstance(f, SettingsFrame) + assert_equal(f.length, 0) + assert_equal(f.TYPE, SettingsFrame.TYPE) + assert_equal(f.flags, SettingsFrame.FLAG_ACK) + assert_equal(f.stream_id, 0x0) + + f = Frame.from_file(FileAdapter('000006040100000000000200000001')) + assert isinstance(f, SettingsFrame) + assert_equal(f.length, 6) + assert_equal(f.TYPE, SettingsFrame.TYPE) + assert_equal(f.flags, SettingsFrame.FLAG_ACK, 0x0) + assert_equal(f.stream_id, 0x0) + assert_equal(len(f.settings), 1) + assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) + + f = Frame.from_file(FileAdapter( + '00000c040000000000000200000001000312345678')) + assert isinstance(f, SettingsFrame) + assert_equal(f.length, 12) + assert_equal(f.TYPE, SettingsFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(len(f.settings), 2) + assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) + assert_equal( + f.settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS], + 0x12345678) + + +def test_settings_frame_human_readable(): + f = SettingsFrame( + length=12, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings={}) + assert f.human_readable() + + f = SettingsFrame( + length=12, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings={ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) + assert f.human_readable() + + +def test_push_promise_frame_to_bytes(): + f = PushPromiseFrame( + length=10, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + promised_stream=0x7654321, + header_block_fragment='foobar') + assert_equal( + f.to_bytes().encode('hex'), + '00000a05000123456707654321666f6f626172') + + f = PushPromiseFrame( + length=14, + flags=HeadersFrame.FLAG_PADDED, + stream_id=0x1234567, + promised_stream=0x7654321, + header_block_fragment='foobar', + pad_length=3) + assert_equal( + f.to_bytes().encode('hex'), + '00000e0508012345670307654321666f6f626172000000') + + f = PushPromiseFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + promised_stream=0x1234567) + tutils.raises(ValueError, f.to_bytes) + + f = PushPromiseFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + promised_stream=0x0) + tutils.raises(ValueError, f.to_bytes) + + +def test_push_promise_frame_from_bytes(): + f = Frame.from_file(FileAdapter('00000a05000123456707654321666f6f626172')) + assert isinstance(f, PushPromiseFrame) + assert_equal(f.length, 10) + assert_equal(f.TYPE, PushPromiseFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, 'foobar') + + f = Frame.from_file(FileAdapter( + '00000e0508012345670307654321666f6f626172000000')) + assert isinstance(f, PushPromiseFrame) + assert_equal(f.length, 14) + assert_equal(f.TYPE, PushPromiseFrame.TYPE) + assert_equal(f.flags, PushPromiseFrame.FLAG_PADDED) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, 'foobar') + + +def test_push_promise_frame_human_readable(): + f = PushPromiseFrame( + length=14, + flags=HeadersFrame.FLAG_PADDED, + stream_id=0x1234567, + promised_stream=0x7654321, + header_block_fragment='foobar', + pad_length=3) + assert f.human_readable() + + +def test_ping_frame_to_bytes(): + f = PingFrame( + length=8, + flags=PingFrame.FLAG_ACK, + stream_id=0x0, + payload=b'foobar') + assert_equal( + f.to_bytes().encode('hex'), + '000008060100000000666f6f6261720000') + + f = PingFrame( + length=8, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b'foobardeadbeef') + assert_equal( + f.to_bytes().encode('hex'), + '000008060000000000666f6f6261726465') + + f = PingFrame( + length=8, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567) + tutils.raises(ValueError, f.to_bytes) + + +def test_ping_frame_from_bytes(): + f = Frame.from_file(FileAdapter('000008060100000000666f6f6261720000')) + assert isinstance(f, PingFrame) + assert_equal(f.length, 8) + assert_equal(f.TYPE, PingFrame.TYPE) + assert_equal(f.flags, PingFrame.FLAG_ACK) + assert_equal(f.stream_id, 0x0) + assert_equal(f.payload, b'foobar\0\0') + + f = Frame.from_file(FileAdapter('000008060000000000666f6f6261726465')) + assert isinstance(f, PingFrame) + assert_equal(f.length, 8) + assert_equal(f.TYPE, PingFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(f.payload, b'foobarde') + + +def test_ping_frame_human_readable(): + f = PingFrame( + length=8, + flags=PingFrame.FLAG_ACK, + stream_id=0x0, + payload=b'foobar') + assert f.human_readable() + + +def test_goaway_frame_to_bytes(): + f = GoAwayFrame( + length=8, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x1234567, + error_code=0x87654321, + data=b'') + assert_equal( + f.to_bytes().encode('hex'), + '0000080700000000000123456787654321') + + f = GoAwayFrame( + length=14, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x1234567, + error_code=0x87654321, + data=b'foobar') + assert_equal( + f.to_bytes().encode('hex'), + '00000e0700000000000123456787654321666f6f626172') + + f = GoAwayFrame( + length=8, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + last_stream=0x1234567, + error_code=0x87654321) + tutils.raises(ValueError, f.to_bytes) + + +def test_goaway_frame_from_bytes(): + f = Frame.from_file(FileAdapter( + '0000080700000000000123456787654321')) + assert isinstance(f, GoAwayFrame) + assert_equal(f.length, 8) + assert_equal(f.TYPE, GoAwayFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(f.last_stream, 0x1234567) + assert_equal(f.error_code, 0x87654321) + assert_equal(f.data, b'') + + f = Frame.from_file(FileAdapter( + '00000e0700000000000123456787654321666f6f626172')) + assert isinstance(f, GoAwayFrame) + assert_equal(f.length, 14) + assert_equal(f.TYPE, GoAwayFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(f.last_stream, 0x1234567) + assert_equal(f.error_code, 0x87654321) + assert_equal(f.data, b'foobar') + + +def test_go_away_frame_human_readable(): + f = GoAwayFrame( + length=14, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x1234567, + error_code=0x87654321, + data=b'foobar') + assert f.human_readable() + + +def test_window_update_frame_to_bytes(): + f = WindowUpdateFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + window_size_increment=0x1234567) + assert_equal(f.to_bytes().encode('hex'), '00000408000000000001234567') + + f = WindowUpdateFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + window_size_increment=0x7654321) + assert_equal(f.to_bytes().encode('hex'), '00000408000123456707654321') + + f = WindowUpdateFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + window_size_increment=0xdeadbeef) + tutils.raises(ValueError, f.to_bytes) + + f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0) + tutils.raises(ValueError, f.to_bytes) + + +def test_window_update_frame_from_bytes(): + f = Frame.from_file(FileAdapter('00000408000000000001234567')) + assert isinstance(f, WindowUpdateFrame) + assert_equal(f.length, 4) + assert_equal(f.TYPE, WindowUpdateFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(f.window_size_increment, 0x1234567) + + +def test_window_update_frame_human_readable(): + f = WindowUpdateFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + window_size_increment=0x7654321) + assert f.human_readable() + + +def test_continuation_frame_to_bytes(): + f = ContinuationFrame( + length=6, + flags=ContinuationFrame.FLAG_END_HEADERS, + stream_id=0x1234567, + header_block_fragment='foobar') + assert_equal(f.to_bytes().encode('hex'), '000006090401234567666f6f626172') + + f = ContinuationFrame( + length=6, + flags=ContinuationFrame.FLAG_END_HEADERS, + stream_id=0x0, + header_block_fragment='foobar') + tutils.raises(ValueError, f.to_bytes) + + +def test_continuation_frame_from_bytes(): + f = Frame.from_file(FileAdapter('000006090401234567666f6f626172')) + assert isinstance(f, ContinuationFrame) + assert_equal(f.length, 6) + assert_equal(f.TYPE, ContinuationFrame.TYPE) + assert_equal(f.flags, ContinuationFrame.FLAG_END_HEADERS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, 'foobar') + + +def test_continuation_frame_human_readable(): + f = ContinuationFrame( + length=6, + flags=ContinuationFrame.FLAG_END_HEADERS, + stream_id=0x1234567, + header_block_fragment='foobar') + assert f.human_readable() diff --git a/test/http2/test_http2_protocol.py b/test/http2/test_http2_protocol.py new file mode 100644 index 00000000..6a275430 --- /dev/null +++ b/test/http2/test_http2_protocol.py @@ -0,0 +1,216 @@ + +import OpenSSL + +from netlib import http2 +from netlib import tcp +from netlib import test +from netlib.http2.frame import * +from test import tutils + + +class EchoHandler(tcp.BaseHandler): + sni = None + + def handle(self): + v = self.rfile.readline() + self.wfile.write(v) + self.wfile.flush() + + +class TestCheckALPNMatch(test.ServerTestBase): + handler = EchoHandler + ssl = dict( + alpn_select=http2.HTTP2Protocol.ALPN_PROTO_H2, + ) + + if OpenSSL._util.lib.Cryptography_HAS_ALPN: + + def test_check_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) + protocol = http2.HTTP2Protocol(c) + assert protocol.check_alpn() + + +class TestCheckALPNMismatch(test.ServerTestBase): + handler = EchoHandler + ssl = dict( + alpn_select=None, + ) + + if OpenSSL._util.lib.Cryptography_HAS_ALPN: + + def test_check_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) + protocol = http2.HTTP2Protocol(c) + tutils.raises(NotImplementedError, protocol.check_alpn) + + +class TestPerformConnectionPreface(test.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + # check magic + assert self.rfile.read(24) ==\ + '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') + + # check empty settings frame + assert self.rfile.read(9) ==\ + '000000040000000000'.decode('hex') + + # send empty settings frame + self.wfile.write('000000040000000000'.decode('hex')) + self.wfile.flush() + + # check settings acknowledgement + assert self.rfile.read(9) == \ + '000000040100000000'.decode('hex') + + # send settings acknowledgement + self.wfile.write('000000040100000000'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_perform_connection_preface(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c) + protocol.perform_connection_preface() + + +class TestStreamIds(): + c = tcp.TCPClient(("127.0.0.1", 0)) + protocol = http2.HTTP2Protocol(c) + + def test_stream_ids(self): + assert self.protocol.current_stream_id is None + assert self.protocol.next_stream_id() == 1 + assert self.protocol.current_stream_id == 1 + assert self.protocol.next_stream_id() == 3 + assert self.protocol.current_stream_id == 3 + assert self.protocol.next_stream_id() == 5 + assert self.protocol.current_stream_id == 5 + + +class TestApplySettings(test.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + # check settings acknowledgement + assert self.rfile.read(9) == '000000040100000000'.decode('hex') + self.wfile.write("OK") + self.wfile.flush() + + ssl = True + + def test_apply_settings(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c) + + protocol._apply_settings({ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 'foo', + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 'bar', + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 'deadbeef', + }) + + assert c.rfile.safe_read(2) == "OK" + + assert protocol.http2_settings[ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH] == 'foo' + assert protocol.http2_settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS] == 'bar' + assert protocol.http2_settings[ + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE] == 'deadbeef' + + +class TestCreateHeaders(): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_create_headers(self): + headers = [ + (b':method', b'GET'), + (b':path', b'index.html'), + (b':scheme', b'https'), + (b'foo', b'bar')] + + bytes = http2.HTTP2Protocol(self.c)._create_headers( + headers, 1, end_stream=True) + assert b''.join(bytes) ==\ + '000014010500000001824488355217caf3a69a3f87408294e7838c767f'\ + .decode('hex') + + bytes = http2.HTTP2Protocol(self.c)._create_headers( + headers, 1, end_stream=False) + assert b''.join(bytes) ==\ + '000014010400000001824488355217caf3a69a3f87408294e7838c767f'\ + .decode('hex') + + # TODO: add test for too large header_block_fragments + + +class TestCreateBody(): + c = tcp.TCPClient(("127.0.0.1", 0)) + protocol = http2.HTTP2Protocol(c) + + def test_create_body_empty(self): + bytes = self.protocol._create_body(b'', 1) + assert b''.join(bytes) == ''.decode('hex') + + def test_create_body_single_frame(self): + bytes = self.protocol._create_body('foobar', 1) + assert b''.join(bytes) == '000006000100000001666f6f626172'.decode('hex') + + def test_create_body_multiple_frames(self): + pass + # bytes = self.protocol._create_body('foobar' * 3000, 1) + # TODO: add test for too large frames + + +class TestCreateRequest(): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_create_request_simple(self): + bytes = http2.HTTP2Protocol(self.c).create_request('GET', '/') + assert len(bytes) == 1 + assert bytes[0] == '000003010500000001828487'.decode('hex') + + def test_create_request_with_body(self): + bytes = http2.HTTP2Protocol(self.c).create_request( + 'GET', '/', [(b'foo', b'bar')], 'foobar') + assert len(bytes) == 2 + assert bytes[0] ==\ + '00000b010400000001828487408294e7838c767f'.decode('hex') + assert bytes[1] ==\ + '000006000100000001666f6f626172'.decode('hex') + + +class TestReadResponse(test.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + self.wfile.write( + b'00000801040000000188628594e78c767f'.decode('hex')) + self.wfile.write( + b'000006000100000001666f6f626172'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_read_response(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c) + + status, headers, body = protocol.read_response() + + assert headers == {':status': '200', 'etag': 'foobar'} + assert status == '200' + assert body == b'foobar' -- cgit v1.2.3 From 49043131cc49a602f54c1671ef5637b606c401b7 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 5 Jun 2015 19:39:15 +0200 Subject: increase test coverage --- test/test_tcp.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/test_tcp.py b/test/test_tcp.py index cbe92f3c..f8fc6a28 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -575,6 +575,11 @@ class TestFileLike: s = tcp.Reader(o) tutils.raises(tcp.NetLibDisconnect, s.readline, 10) + def test_reader_incomplete_error(self): + s = cStringIO.StringIO("foobar") + s = tcp.Reader(s) + tutils.raises(tcp.NetLibIncomplete, s.safe_read, 10) + class TestAddress: -- cgit v1.2.3 From e7c84a1ce14ca339184de1cd615727144d50d381 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 5 Jun 2015 20:05:05 +0200 Subject: make travis run all tests --- test/http2/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 test/http2/__init__.py diff --git a/test/http2/__init__.py b/test/http2/__init__.py new file mode 100644 index 00000000..e69de29b -- cgit v1.2.3 From 6c1c6f5f0ad375d5a8f37007e6cf2d6862282de9 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 5 Jun 2015 20:49:03 +0200 Subject: http2: fix EchoHandler test helper --- test/http2/test_http2_protocol.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/http2/test_http2_protocol.py b/test/http2/test_http2_protocol.py index 6a275430..cb46bc68 100644 --- a/test/http2/test_http2_protocol.py +++ b/test/http2/test_http2_protocol.py @@ -12,9 +12,10 @@ class EchoHandler(tcp.BaseHandler): sni = None def handle(self): - v = self.rfile.readline() - self.wfile.write(v) - self.wfile.flush() + while True: + v = self.rfile.safe_read(1) + self.wfile.write(v) + self.wfile.flush() class TestCheckALPNMatch(test.ServerTestBase): -- cgit v1.2.3 From f2db8abbe859266bb28117e1ffa4b0b99d62e321 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 5 Jun 2015 20:52:11 +0200 Subject: use open instead of file --- netlib/test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netlib/test.py b/netlib/test.py index 4b0b6bd2..1e1b5e9d 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -75,13 +75,13 @@ class TServer(tcp.TCPServer): raw_cert = self.ssl.get( "cert", tutils.test_data.path("data/server.crt")) - cert = certutils.SSLCert.from_pem(file(raw_cert, "rb").read()) + cert = certutils.SSLCert.from_pem(open(raw_cert, "rb").read()) raw_key = self.ssl.get( "key", tutils.test_data.path("data/server.key")) key = OpenSSL.crypto.load_privatekey( OpenSSL.crypto.FILETYPE_PEM, - file(raw_key, "rb").read()) + open(raw_key, "rb").read()) if self.ssl.get("v3_only", False): method = tcp.SSLv3_METHOD options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 -- cgit v1.2.3 From e39d8aed6d77b6cf5d57c795c69e735a7c1430fa Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 5 Jun 2015 20:55:32 +0200 Subject: http2: refactor hex to file adapter --- test/http2/test_frames.py | 64 ++++++++++++++++++++--------------------------- 1 file changed, 27 insertions(+), 37 deletions(-) diff --git a/test/http2/test_frames.py b/test/http2/test_frames.py index d8f00dec..76a4b712 100644 --- a/test/http2/test_frames.py +++ b/test/http2/test_frames.py @@ -1,23 +1,13 @@ +import cStringIO from test import tutils from nose.tools import assert_equal +from netlib import tcp from netlib.http2.frame import * -class FileAdapter(object): - def __init__(self, data, is_hex=True): - self.position = 0 - if is_hex: - self.data = data.decode('hex') - else: - self.data = data - - def safe_read(self, length): - if self.position + length > len(self.data): - raise ValueError("not enough bytes to read") - - value = self.data[self.position:self.position + length] - self.position += length - return value +def hex_to_file(data): + data = data.decode('hex') + return tcp.Reader(cStringIO.StringIO(data)) def test_invalid_flags(): @@ -79,7 +69,7 @@ def test_data_frame_to_bytes(): def test_data_frame_from_bytes(): - f = Frame.from_file(FileAdapter('000006000101234567666f6f626172')) + f = Frame.from_file(hex_to_file('000006000101234567666f6f626172')) assert isinstance(f, DataFrame) assert_equal(f.length, 6) assert_equal(f.TYPE, DataFrame.TYPE) @@ -87,7 +77,7 @@ def test_data_frame_from_bytes(): assert_equal(f.stream_id, 0x1234567) assert_equal(f.payload, 'foobar') - f = Frame.from_file(FileAdapter('00000a00090123456703666f6f626172000000')) + f = Frame.from_file(hex_to_file('00000a00090123456703666f6f626172000000')) assert isinstance(f, DataFrame) assert_equal(f.length, 10) assert_equal(f.TYPE, DataFrame.TYPE) @@ -171,7 +161,7 @@ def test_headers_frame_to_bytes(): def test_headers_frame_from_bytes(): - f = Frame.from_file(FileAdapter( + f = Frame.from_file(hex_to_file( '000007010001234567668594e75e31d9')) assert isinstance(f, HeadersFrame) assert_equal(f.length, 7) @@ -180,7 +170,7 @@ def test_headers_frame_from_bytes(): assert_equal(f.stream_id, 0x1234567) assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - f = Frame.from_file(FileAdapter( + f = Frame.from_file(hex_to_file( '00000b01080123456703668594e75e31d9000000')) assert isinstance(f, HeadersFrame) assert_equal(f.length, 11) @@ -189,7 +179,7 @@ def test_headers_frame_from_bytes(): assert_equal(f.stream_id, 0x1234567) assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - f = Frame.from_file(FileAdapter( + f = Frame.from_file(hex_to_file( '00000c012001234567876543212a668594e75e31d9')) assert isinstance(f, HeadersFrame) assert_equal(f.length, 12) @@ -201,7 +191,7 @@ def test_headers_frame_from_bytes(): assert_equal(f.stream_dependency, 0x7654321) assert_equal(f.weight, 42) - f = Frame.from_file(FileAdapter( + f = Frame.from_file(hex_to_file( '00001001280123456703876543212a668594e75e31d9000000')) assert isinstance(f, HeadersFrame) assert_equal(f.length, 16) @@ -213,7 +203,7 @@ def test_headers_frame_from_bytes(): assert_equal(f.stream_dependency, 0x7654321) assert_equal(f.weight, 42) - f = Frame.from_file(FileAdapter( + f = Frame.from_file(hex_to_file( '00001001280123456703076543212a668594e75e31d9000000')) assert isinstance(f, HeadersFrame) assert_equal(f.length, 16) @@ -285,7 +275,7 @@ def test_priority_frame_to_bytes(): def test_priority_frame_from_bytes(): - f = Frame.from_file(FileAdapter('000005020001234567876543212a')) + f = Frame.from_file(hex_to_file('000005020001234567876543212a')) assert isinstance(f, PriorityFrame) assert_equal(f.length, 5) assert_equal(f.TYPE, PriorityFrame.TYPE) @@ -295,7 +285,7 @@ def test_priority_frame_from_bytes(): assert_equal(f.stream_dependency, 0x7654321) assert_equal(f.weight, 42) - f = Frame.from_file(FileAdapter('0000050200012345670765432115')) + f = Frame.from_file(hex_to_file('0000050200012345670765432115')) assert isinstance(f, PriorityFrame) assert_equal(f.length, 5) assert_equal(f.TYPE, PriorityFrame.TYPE) @@ -333,7 +323,7 @@ def test_rst_stream_frame_to_bytes(): def test_rst_stream_frame_from_bytes(): - f = Frame.from_file(FileAdapter('00000403000123456707654321')) + f = Frame.from_file(hex_to_file('00000403000123456707654321')) assert isinstance(f, RstStreamFrame) assert_equal(f.length, 4) assert_equal(f.TYPE, RstStreamFrame.TYPE) @@ -391,21 +381,21 @@ def test_settings_frame_to_bytes(): def test_settings_frame_from_bytes(): - f = Frame.from_file(FileAdapter('000000040000000000')) + f = Frame.from_file(hex_to_file('000000040000000000')) assert isinstance(f, SettingsFrame) assert_equal(f.length, 0) assert_equal(f.TYPE, SettingsFrame.TYPE) assert_equal(f.flags, Frame.FLAG_NO_FLAGS) assert_equal(f.stream_id, 0x0) - f = Frame.from_file(FileAdapter('000000040100000000')) + f = Frame.from_file(hex_to_file('000000040100000000')) assert isinstance(f, SettingsFrame) assert_equal(f.length, 0) assert_equal(f.TYPE, SettingsFrame.TYPE) assert_equal(f.flags, SettingsFrame.FLAG_ACK) assert_equal(f.stream_id, 0x0) - f = Frame.from_file(FileAdapter('000006040100000000000200000001')) + f = Frame.from_file(hex_to_file('000006040100000000000200000001')) assert isinstance(f, SettingsFrame) assert_equal(f.length, 6) assert_equal(f.TYPE, SettingsFrame.TYPE) @@ -414,7 +404,7 @@ def test_settings_frame_from_bytes(): assert_equal(len(f.settings), 1) assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) - f = Frame.from_file(FileAdapter( + f = Frame.from_file(hex_to_file( '00000c040000000000000200000001000312345678')) assert isinstance(f, SettingsFrame) assert_equal(f.length, 12) @@ -485,7 +475,7 @@ def test_push_promise_frame_to_bytes(): def test_push_promise_frame_from_bytes(): - f = Frame.from_file(FileAdapter('00000a05000123456707654321666f6f626172')) + f = Frame.from_file(hex_to_file('00000a05000123456707654321666f6f626172')) assert isinstance(f, PushPromiseFrame) assert_equal(f.length, 10) assert_equal(f.TYPE, PushPromiseFrame.TYPE) @@ -493,7 +483,7 @@ def test_push_promise_frame_from_bytes(): assert_equal(f.stream_id, 0x1234567) assert_equal(f.header_block_fragment, 'foobar') - f = Frame.from_file(FileAdapter( + f = Frame.from_file(hex_to_file( '00000e0508012345670307654321666f6f626172000000')) assert isinstance(f, PushPromiseFrame) assert_equal(f.length, 14) @@ -541,7 +531,7 @@ def test_ping_frame_to_bytes(): def test_ping_frame_from_bytes(): - f = Frame.from_file(FileAdapter('000008060100000000666f6f6261720000')) + f = Frame.from_file(hex_to_file('000008060100000000666f6f6261720000')) assert isinstance(f, PingFrame) assert_equal(f.length, 8) assert_equal(f.TYPE, PingFrame.TYPE) @@ -549,7 +539,7 @@ def test_ping_frame_from_bytes(): assert_equal(f.stream_id, 0x0) assert_equal(f.payload, b'foobar\0\0') - f = Frame.from_file(FileAdapter('000008060000000000666f6f6261726465')) + f = Frame.from_file(hex_to_file('000008060000000000666f6f6261726465')) assert isinstance(f, PingFrame) assert_equal(f.length, 8) assert_equal(f.TYPE, PingFrame.TYPE) @@ -600,7 +590,7 @@ def test_goaway_frame_to_bytes(): def test_goaway_frame_from_bytes(): - f = Frame.from_file(FileAdapter( + f = Frame.from_file(hex_to_file( '0000080700000000000123456787654321')) assert isinstance(f, GoAwayFrame) assert_equal(f.length, 8) @@ -611,7 +601,7 @@ def test_goaway_frame_from_bytes(): assert_equal(f.error_code, 0x87654321) assert_equal(f.data, b'') - f = Frame.from_file(FileAdapter( + f = Frame.from_file(hex_to_file( '00000e0700000000000123456787654321666f6f626172')) assert isinstance(f, GoAwayFrame) assert_equal(f.length, 14) @@ -661,7 +651,7 @@ def test_window_update_frame_to_bytes(): def test_window_update_frame_from_bytes(): - f = Frame.from_file(FileAdapter('00000408000000000001234567')) + f = Frame.from_file(hex_to_file('00000408000000000001234567')) assert isinstance(f, WindowUpdateFrame) assert_equal(f.length, 4) assert_equal(f.TYPE, WindowUpdateFrame.TYPE) @@ -696,7 +686,7 @@ def test_continuation_frame_to_bytes(): def test_continuation_frame_from_bytes(): - f = Frame.from_file(FileAdapter('000006090401234567666f6f626172')) + f = Frame.from_file(hex_to_file('000006090401234567666f6f626172')) assert isinstance(f, ContinuationFrame) assert_equal(f.length, 6) assert_equal(f.TYPE, ContinuationFrame.TYPE) -- cgit v1.2.3 From f2d784896dd18ea7ded9b3a95bedcdceb3325213 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 6 Jun 2015 12:26:48 +1200 Subject: http2: resolve module structure and circular dependencies - Move implementation out of __init__.py to protocol.py (an anti-pattern because it makes the kind of structural refactoring we need hard) - protocol imports frame, frame does not import protocol. To do this, we shift the default settings to frame. If this feels wrong, we can move them to a separate module (defaults.py?.). --- netlib/http2/__init__.py | 183 +---------------------------------------------- netlib/http2/frame.py | 18 +++-- netlib/http2/protocol.py | 174 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 188 insertions(+), 187 deletions(-) create mode 100644 netlib/http2/protocol.py diff --git a/netlib/http2/__init__.py b/netlib/http2/__init__.py index 2803cccb..92897b5d 100644 --- a/netlib/http2/__init__.py +++ b/netlib/http2/__init__.py @@ -1,182 +1,3 @@ -from __future__ import (absolute_import, print_function, division) -import itertools -import logging -from .frame import * -from .. import utils - -log = logging.getLogger(__name__) - - -class HTTP2Protocol(object): - - ERROR_CODES = utils.BiDi( - NO_ERROR=0x0, - PROTOCOL_ERROR=0x1, - INTERNAL_ERROR=0x2, - FLOW_CONTROL_ERROR=0x3, - SETTINGS_TIMEOUT=0x4, - STREAM_CLOSED=0x5, - FRAME_SIZE_ERROR=0x6, - REFUSED_STREAM=0x7, - CANCEL=0x8, - COMPRESSION_ERROR=0x9, - CONNECT_ERROR=0xa, - ENHANCE_YOUR_CALM=0xb, - INADEQUATE_SECURITY=0xc, - HTTP_1_1_REQUIRED=0xd - ) - - # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' - - ALPN_PROTO_H2 = 'h2' - - HTTP2_DEFAULT_SETTINGS = { - SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, - SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, - } - - def __init__(self, tcp_client): - self.tcp_client = tcp_client - - self.http2_settings = self.HTTP2_DEFAULT_SETTINGS.copy() - self.current_stream_id = None - self.encoder = Encoder() - self.decoder = Decoder() - - def check_alpn(self): - alp = self.tcp_client.get_alpn_proto_negotiated() - if alp != self.ALPN_PROTO_H2: - raise NotImplementedError( - "HTTP2Protocol can not handle unknown ALP: %s" % alp) - log.debug("ALP 'h2' successfully negotiated.") - return True - - def perform_connection_preface(self): - self.tcp_client.wfile.write( - bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex'))) - self.send_frame(SettingsFrame(state=self)) - - # read server settings frame - frame = Frame.from_file(self.tcp_client.rfile, self) - assert isinstance(frame, SettingsFrame) - self._apply_settings(frame.settings) - - # read setting ACK frame - settings_ack_frame = self.read_frame() - assert isinstance(settings_ack_frame, SettingsFrame) - assert settings_ack_frame.flags & Frame.FLAG_ACK - assert len(settings_ack_frame.settings) == 0 - - log.debug("Connection Preface completed.") - - def next_stream_id(self): - if self.current_stream_id is None: - self.current_stream_id = 1 - else: - self.current_stream_id += 2 - return self.current_stream_id - - def send_frame(self, frame): - raw_bytes = frame.to_bytes() - self.tcp_client.wfile.write(raw_bytes) - self.tcp_client.wfile.flush() - - def read_frame(self): - frame = Frame.from_file(self.tcp_client.rfile, self) - if isinstance(frame, SettingsFrame): - self._apply_settings(frame.settings) - - return frame - - def _apply_settings(self, settings): - for setting, value in settings.items(): - old_value = self.http2_settings[setting] - if not old_value: - old_value = '-' - - self.http2_settings[setting] = value - log.debug("Setting changed: %s to %s (was %s)" % ( - SettingsFrame.SETTINGS.get_name(setting), - str(value), - str(old_value))) - - self.send_frame(SettingsFrame(state=self, flags=Frame.FLAG_ACK)) - log.debug("New settings acknowledged.") - - def _create_headers(self, headers, stream_id, end_stream=True): - # TODO: implement max frame size checks and sending in chunks - - flags = Frame.FLAG_END_HEADERS - if end_stream: - flags |= Frame.FLAG_END_STREAM - - header_block_fragment = self.encoder.encode(headers) - - bytes = HeadersFrame( - state=self, - flags=flags, - stream_id=stream_id, - header_block_fragment=header_block_fragment).to_bytes() - return [bytes] - - def _create_body(self, body, stream_id): - if body is None or len(body) == 0: - return b'' - - # TODO: implement max frame size checks and sending in chunks - # TODO: implement flow-control window - - bytes = DataFrame( - state=self, - flags=Frame.FLAG_END_STREAM, - stream_id=stream_id, - payload=body).to_bytes() - return [bytes] - - def create_request(self, method, path, headers=None, body=None): - if headers is None: - headers = [] - - headers = [ - (b':method', bytes(method)), - (b':path', bytes(path)), - (b':scheme', b'https')] + headers - - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id))) - - def read_response(self): - header_block_fragment = b'' - body = b'' - - while True: - frame = self.read_frame() - if isinstance(frame, HeadersFrame): - header_block_fragment += frame.header_block_fragment - if frame.flags | Frame.FLAG_END_HEADERS: - break - - while True: - frame = self.read_frame() - if isinstance(frame, DataFrame): - body += frame.payload - if frame.flags | Frame.FLAG_END_STREAM: - break - - headers = {} - for header, value in self.decoder.decode(header_block_fragment): - headers[header] = value - - for header, value in headers.items(): - log.debug("%s: %s" % (header, value)) - - return headers[':status'], headers, body +from frame import * +from protocol import * diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py index 1497380a..fc86c228 100644 --- a/netlib/http2/frame.py +++ b/netlib/http2/frame.py @@ -38,13 +38,11 @@ class Frame(object): raise ValueError('invalid flags detected.') if state is None: - from . import HTTP2Protocol - class State(object): pass state = State() - state.http2_settings = HTTP2Protocol.HTTP2_DEFAULT_SETTINGS.copy() + state.http2_settings = HTTP2_DEFAULT_SETTINGS.copy() state.encoder = Encoder() state.decoder = Decoder() @@ -57,12 +55,10 @@ class Frame(object): @classmethod def _check_frame_size(self, length, state): - from . import HTTP2Protocol - if state: settings = state.http2_settings else: - settings = HTTP2Protocol.HTTP2_DEFAULT_SETTINGS + settings = HTTP2_DEFAULT_SETTINGS.copy() max_frame_size = settings[ SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] @@ -623,3 +619,13 @@ _FRAME_CLASSES = [ ContinuationFrame ] FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} + + +HTTP2_DEFAULT_SETTINGS = { + SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, + SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, +} diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py new file mode 100644 index 00000000..9bab431c --- /dev/null +++ b/netlib/http2/protocol.py @@ -0,0 +1,174 @@ +from __future__ import (absolute_import, print_function, division) +import itertools +import logging + +from hpack.hpack import Encoder, Decoder +from .. import utils +from . import frame + +log = logging.getLogger(__name__) + + +class HTTP2Protocol(object): + + ERROR_CODES = utils.BiDi( + NO_ERROR=0x0, + PROTOCOL_ERROR=0x1, + INTERNAL_ERROR=0x2, + FLOW_CONTROL_ERROR=0x3, + SETTINGS_TIMEOUT=0x4, + STREAM_CLOSED=0x5, + FRAME_SIZE_ERROR=0x6, + REFUSED_STREAM=0x7, + CANCEL=0x8, + COMPRESSION_ERROR=0x9, + CONNECT_ERROR=0xa, + ENHANCE_YOUR_CALM=0xb, + INADEQUATE_SECURITY=0xc, + HTTP_1_1_REQUIRED=0xd + ) + + # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' + + ALPN_PROTO_H2 = 'h2' + + def __init__(self, tcp_client): + self.tcp_client = tcp_client + + self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() + self.current_stream_id = None + self.encoder = Encoder() + self.decoder = Decoder() + + def check_alpn(self): + alp = self.tcp_client.get_alpn_proto_negotiated() + if alp != self.ALPN_PROTO_H2: + raise NotImplementedError( + "HTTP2Protocol can not handle unknown ALP: %s" % alp) + log.debug("ALP 'h2' successfully negotiated.") + return True + + def perform_connection_preface(self): + self.tcp_client.wfile.write( + bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex'))) + self.send_frame(frame.SettingsFrame(state=self)) + + # read server settings frame + frm = frame.Frame.from_file(self.tcp_client.rfile, self) + assert isinstance(frm, frame.SettingsFrame) + self._apply_settings(frm.settings) + + # read setting ACK frame + settings_ack_frame = self.read_frame() + assert isinstance(settings_ack_frame, frame.SettingsFrame) + assert settings_ack_frame.flags & frame.Frame.FLAG_ACK + assert len(settings_ack_frame.settings) == 0 + + log.debug("Connection Preface completed.") + + def next_stream_id(self): + if self.current_stream_id is None: + self.current_stream_id = 1 + else: + self.current_stream_id += 2 + return self.current_stream_id + + def send_frame(self, frame): + raw_bytes = frame.to_bytes() + self.tcp_client.wfile.write(raw_bytes) + self.tcp_client.wfile.flush() + + def read_frame(self): + frm = frame.Frame.from_file(self.tcp_client.rfile, self) + if isinstance(frm, frame.SettingsFrame): + self._apply_settings(frm.settings) + + return frm + + def _apply_settings(self, settings): + for setting, value in settings.items(): + old_value = self.http2_settings[setting] + if not old_value: + old_value = '-' + + self.http2_settings[setting] = value + log.debug("Setting changed: %s to %s (was %s)" % ( + frame.SettingsFrame.SETTINGS.get_name(setting), + str(value), + str(old_value))) + + self.send_frame(frame.SettingsFrame(state=self, flags=frame.Frame.FLAG_ACK)) + log.debug("New settings acknowledged.") + + def _create_headers(self, headers, stream_id, end_stream=True): + # TODO: implement max frame size checks and sending in chunks + + flags = frame.Frame.FLAG_END_HEADERS + if end_stream: + flags |= frame.Frame.FLAG_END_STREAM + + header_block_fragment = self.encoder.encode(headers) + + bytes = frame.HeadersFrame( + state=self, + flags=flags, + stream_id=stream_id, + header_block_fragment=header_block_fragment).to_bytes() + return [bytes] + + def _create_body(self, body, stream_id): + if body is None or len(body) == 0: + return b'' + + # TODO: implement max frame size checks and sending in chunks + # TODO: implement flow-control window + + bytes = frame.DataFrame( + state=self, + flags=frame.Frame.FLAG_END_STREAM, + stream_id=stream_id, + payload=body).to_bytes() + return [bytes] + + def create_request(self, method, path, headers=None, body=None): + if headers is None: + headers = [] + + headers = [ + (b':method', bytes(method)), + (b':path', bytes(path)), + (b':scheme', b'https')] + headers + + stream_id = self.next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(body is None)), + self._create_body(body, stream_id))) + + def read_response(self): + header_block_fragment = b'' + body = b'' + + while True: + frm = self.read_frame() + if isinstance(frm, frame.HeadersFrame): + header_block_fragment += frm.header_block_fragment + if frm.flags | frame.Frame.FLAG_END_HEADERS: + break + + while True: + frm = self.read_frame() + if isinstance(frm, frame.DataFrame): + body += frm.payload + if frm.flags | frame.Frame.FLAG_END_STREAM: + break + + headers = {} + for header, value in self.decoder.decode(header_block_fragment): + headers[header] = value + + for header, value in headers.items(): + log.debug("%s: %s" % (header, value)) + + return headers[':status'], headers, body -- cgit v1.2.3 From 9c48bfb2a53bf3ac3c29408511e3126ada16afd8 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 6 Jun 2015 12:30:53 +1200 Subject: http2: ditch the logging for now The API is well designed: it looks like we can get all the information we need to expose debugging in the caller of the API. --- netlib/http2/frame.py | 3 --- netlib/http2/protocol.py | 13 ------------- 2 files changed, 16 deletions(-) diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py index fc86c228..ac9b8d50 100644 --- a/netlib/http2/frame.py +++ b/netlib/http2/frame.py @@ -1,12 +1,9 @@ import struct -import logging from functools import reduce from hpack.hpack import Encoder, Decoder from .. import utils -log = logging.getLogger(__name__) - class FrameSizeError(Exception): pass diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index 9bab431c..459c2293 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -1,13 +1,10 @@ from __future__ import (absolute_import, print_function, division) import itertools -import logging from hpack.hpack import Encoder, Decoder from .. import utils from . import frame -log = logging.getLogger(__name__) - class HTTP2Protocol(object): @@ -46,7 +43,6 @@ class HTTP2Protocol(object): if alp != self.ALPN_PROTO_H2: raise NotImplementedError( "HTTP2Protocol can not handle unknown ALP: %s" % alp) - log.debug("ALP 'h2' successfully negotiated.") return True def perform_connection_preface(self): @@ -65,7 +61,6 @@ class HTTP2Protocol(object): assert settings_ack_frame.flags & frame.Frame.FLAG_ACK assert len(settings_ack_frame.settings) == 0 - log.debug("Connection Preface completed.") def next_stream_id(self): if self.current_stream_id is None: @@ -93,13 +88,8 @@ class HTTP2Protocol(object): old_value = '-' self.http2_settings[setting] = value - log.debug("Setting changed: %s to %s (was %s)" % ( - frame.SettingsFrame.SETTINGS.get_name(setting), - str(value), - str(old_value))) self.send_frame(frame.SettingsFrame(state=self, flags=frame.Frame.FLAG_ACK)) - log.debug("New settings acknowledged.") def _create_headers(self, headers, stream_id, end_stream=True): # TODO: implement max frame size checks and sending in chunks @@ -168,7 +158,4 @@ class HTTP2Protocol(object): for header, value in self.decoder.decode(header_block_fragment): headers[header] = value - for header, value in headers.items(): - log.debug("%s: %s" % (header, value)) - return headers[':status'], headers, body -- cgit v1.2.3 From 359ef469054b6a80ff8a5a3148a52e864a76fe9b Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 8 Jun 2015 12:21:08 +0200 Subject: fix coding style --- netlib/http2/protocol.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index 459c2293..feac220c 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -61,7 +61,6 @@ class HTTP2Protocol(object): assert settings_ack_frame.flags & frame.Frame.FLAG_ACK assert len(settings_ack_frame.settings) == 0 - def next_stream_id(self): if self.current_stream_id is None: self.current_stream_id = 1 @@ -89,7 +88,10 @@ class HTTP2Protocol(object): self.http2_settings[setting] = value - self.send_frame(frame.SettingsFrame(state=self, flags=frame.Frame.FLAG_ACK)) + self.send_frame( + frame.SettingsFrame( + state=self, + flags=frame.Frame.FLAG_ACK)) def _create_headers(self, headers, stream_id, end_stream=True): # TODO: implement max frame size checks and sending in chunks -- cgit v1.2.3 From ff478b5290fdd9aad9d8ba5b4e48a2f4bf54177c Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 8 Jun 2015 12:34:01 +0200 Subject: ignore eggs directory --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 68d71ab6..c3c6f1cb 100644 --- a/.gitignore +++ b/.gitignore @@ -10,5 +10,6 @@ MANIFEST .idea/ __pycache__ _cffi__* +.eggs/ netlib.egg-info/ pathod/ -- cgit v1.2.3 From 4666d1e7bbf77b470d938d873d1a760283963adf Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 8 Jun 2015 11:29:01 +0200 Subject: improve ALPN support on travis --- .travis.yml | 53 +++++++++++++++++++++++++++++++++++++++++++++++------ netlib/tcp.py | 19 +++++++++++-------- 2 files changed, 58 insertions(+), 14 deletions(-) diff --git a/.travis.yml b/.travis.yml index a1eafcea..83fcc265 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,26 +1,67 @@ language: python + sudo: false + python: - "2.7" - pypy -# command to install dependencies, e.g. pip install -r requirements.txt --use-mirrors + +matrix: + include: + - python: 2.7 + env: OPENSSL=1.0.2 + addons: + apt: + sources: + # Debian sid currently holds OpenSSL 1.0.2 + # change this with future releases! + - debian-sid + packages: + - libssl-dev + - python: pypy + env: OPENSSL=1.0.2 + addons: + apt: + sources: + # Debian sid currently holds OpenSSL 1.0.2 + # change this with future releases! + - debian-sid + packages: + - libssl-dev + install: - "pip install --src . -r requirements.txt" -# command to run tests, e.g. python setup.py test + +before_script: + - "openssl version -a" + script: - "nosetests --with-cov --cov-report term-missing" - "./check_coding_style.sh" + after_success: - coveralls + notifications: irc: channels: - "irc.oftc.net#mitmproxy" on_success: change on_failure: always + +# exclude cryptography from cache +# it depends on libssl-dev version +# which needs to be compiled specifically to each version +before_cache: + - pip uninstall -y cryptography + - rm -rf /home/travis/virtualenv/python2.7.9/lib/python2.7/site-packages/cryptography/ + - rm -rf /home/travis/virtualenv/pypy-2.5.0/site-packages/cryptography/ + - rm /home/travis/virtualenv/python2.7.9/lib/python2.7/site-packages/pip/_vendor/requests/packages/urllib3/contrib/pyopenssl.py + - rm /home/travis/virtualenv/pypy-2.5.0/site-packages/pip/_vendor/requests/packages/urllib3/contrib/pyopenssl.py + cache: directories: - - /home/travis/virtualenv/python2.7.9/lib/python2.7/site-packages - - /home/travis/virtualenv/python2.7.9/bin - - /home/travis/virtualenv/pypy-2.5.0/site-packages - - /home/travis/virtualenv/pypy-2.5.0/bin + - /home/travis/virtualenv/python2.7.9/lib/python2.7/site-packages + - /home/travis/virtualenv/python2.7.9/bin + - /home/travis/virtualenv/pypy-2.5.0/site-packages + - /home/travis/virtualenv/pypy-2.5.0/bin diff --git a/netlib/tcp.py b/netlib/tcp.py index f6179faa..fc2ce115 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -6,6 +6,8 @@ import sys import threading import time import traceback + +import OpenSSL from OpenSSL import SSL from . import certutils @@ -401,16 +403,17 @@ class _Connection(object): if log_ssl_key: context.set_info_callback(log_ssl_key) - # advertise application layer protocols - if alpn_protos is not None: - context.set_alpn_protos(alpn_protos) + if OpenSSL._util.lib.Cryptography_HAS_ALPN: + # advertise application layer protocols + if alpn_protos is not None: + context.set_alpn_protos(alpn_protos) - # select application layer protocol - if alpn_select is not None: - def alpn_select_f(conn, options): - return bytes(alpn_select) + # select application layer protocol + if alpn_select is not None: + def alpn_select_f(conn, options): + return bytes(alpn_select) - context.set_alpn_select_callback(alpn_select_f) + context.set_alpn_select_callback(alpn_select_f) return context -- cgit v1.2.3 From abbe88c8ce4f19de33723ac0828cd24b8ec5f38b Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 8 Jun 2015 13:25:42 +0200 Subject: fix non-ALPN supported OpenSSL-related tests --- netlib/tcp.py | 5 ++++- test/test_tcp.py | 8 +++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index fc2ce115..09c43ffc 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -499,7 +499,10 @@ class TCPClient(_Connection): return self.connection.gettimeout() def get_alpn_proto_negotiated(self): - return self.connection.get_alpn_proto_negotiated() + if OpenSSL._util.lib.Cryptography_HAS_ALPN: + return self.connection.get_alpn_proto_negotiated() + else: + return None class BaseHandler(_Connection): diff --git a/test/test_tcp.py b/test/test_tcp.py index f8fc6a28..d5506556 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -370,13 +370,19 @@ class TestALPN(test.ServerTestBase): ) if OpenSSL._util.lib.Cryptography_HAS_ALPN: - def test_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl(alpn_protos=["foobar"]) assert c.get_alpn_proto_negotiated() == "foobar" + else: + def test_none_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl(alpn_protos=["foobar"]) + assert c.get_alpn_proto_negotiated() == None + class TestSSLTimeOut(test.ServerTestBase): handler = HangHandler -- cgit v1.2.3 From fdbb3b76cf8cd7caaa644dc31e48521096ed5349 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 8 Jun 2015 16:54:19 +0200 Subject: http2: add warning if raw data looks like HTTP/1 --- netlib/http2/frame.py | 4 ++++ netlib/tcp.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py index ac9b8d50..4a305d82 100644 --- a/netlib/http2/frame.py +++ b/netlib/http2/frame.py @@ -1,3 +1,4 @@ +import sys import struct from functools import reduce from hpack.hpack import Encoder, Decoder @@ -79,6 +80,9 @@ class Frame(object): flags = fields[3] stream_id = fields[4] + if raw_header[:4] == b'HTTP': # pragma no cover + print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" + self._check_frame_size(length, state) payload = fp.safe_read(length) diff --git a/netlib/tcp.py b/netlib/tcp.py index 09c43ffc..62545244 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -501,7 +501,7 @@ class TCPClient(_Connection): def get_alpn_proto_negotiated(self): if OpenSSL._util.lib.Cryptography_HAS_ALPN: return self.connection.get_alpn_proto_negotiated() - else: + else: # pragma no cover return None -- cgit v1.2.3 From 0595585974dd889a10e05cade06f5534c85d7401 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 8 Jun 2015 17:00:03 +0200 Subject: fix coding style --- netlib/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 62545244..9a980035 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -501,7 +501,7 @@ class TCPClient(_Connection): def get_alpn_proto_negotiated(self): if OpenSSL._util.lib.Cryptography_HAS_ALPN: return self.connection.get_alpn_proto_negotiated() - else: # pragma no cover + else: # pragma no cover return None -- cgit v1.2.3 From eeaed93a83fbe14762e263e9f25b5361088daa15 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 11 Jun 2015 15:37:17 +0200 Subject: improve ALPN integration --- netlib/tcp.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 9a980035..98b17c50 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -404,16 +404,17 @@ class _Connection(object): context.set_info_callback(log_ssl_key) if OpenSSL._util.lib.Cryptography_HAS_ALPN: - # advertise application layer protocols if alpn_protos is not None: + # advertise application layer protocols context.set_alpn_protos(alpn_protos) - - # select application layer protocol - if alpn_select is not None: - def alpn_select_f(conn, options): - return bytes(alpn_select) - - context.set_alpn_select_callback(alpn_select_f) + elif alpn_select is not None: + # select application layer protocol + def alpn_select_callback(conn, options): + if alpn_select in options: + return bytes(alpn_select) + else: + return options[0] + context.set_alpn_select_callback(alpn_select_callback) return context @@ -612,6 +613,12 @@ class BaseHandler(_Connection): def settimeout(self, n): self.connection.settimeout(n) + def get_alpn_proto_negotiated(self): + if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: + return self.connection.get_alpn_proto_negotiated() + else: # pragma no cover + return None + class TCPServer(object): request_queue_size = 20 -- cgit v1.2.3 From 8ea157775debeccfa0f2fab3aa7e009d13ce4391 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 11 Jun 2015 15:38:32 +0200 Subject: http2: general improvements --- netlib/http2/protocol.py | 63 ++++++++++++++++++++++++++------------- test/http2/test_http2_protocol.py | 41 +++++++++++++++++++++---- 2 files changed, 78 insertions(+), 26 deletions(-) diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index feac220c..4b69764f 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -26,12 +26,13 @@ class HTTP2Protocol(object): ) # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a' + CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') ALPN_PROTO_H2 = 'h2' - def __init__(self, tcp_client): - self.tcp_client = tcp_client + def __init__(self, tcp_handler, is_server=False): + self.tcp_handler = tcp_handler + self.is_server = is_server self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() self.current_stream_id = None @@ -39,28 +40,39 @@ class HTTP2Protocol(object): self.decoder = Decoder() def check_alpn(self): - alp = self.tcp_client.get_alpn_proto_negotiated() + alp = self.tcp_handler.get_alpn_proto_negotiated() if alp != self.ALPN_PROTO_H2: raise NotImplementedError( "HTTP2Protocol can not handle unknown ALP: %s" % alp) return True - def perform_connection_preface(self): - self.tcp_client.wfile.write( - bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex'))) - self.send_frame(frame.SettingsFrame(state=self)) - - # read server settings frame - frm = frame.Frame.from_file(self.tcp_client.rfile, self) + def _receive_settings(self): + frm = frame.Frame.from_file(self.tcp_handler.rfile, self) assert isinstance(frm, frame.SettingsFrame) self._apply_settings(frm.settings) - # read setting ACK frame + def _read_settings_ack(self): settings_ack_frame = self.read_frame() assert isinstance(settings_ack_frame, frame.SettingsFrame) assert settings_ack_frame.flags & frame.Frame.FLAG_ACK assert len(settings_ack_frame.settings) == 0 + def perform_server_connection_preface(self): + magic_length = len(self.CLIENT_CONNECTION_PREFACE) + magic = self.tcp_handler.rfile.safe_read(magic_length) + assert magic == self.CLIENT_CONNECTION_PREFACE + + self.send_frame(frame.SettingsFrame(state=self)) + self._receive_settings() + self._read_settings_ack() + + def perform_client_connection_preface(self): + self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) + + self.send_frame(frame.SettingsFrame(state=self)) + self._receive_settings() + self._read_settings_ack() + def next_stream_id(self): if self.current_stream_id is None: self.current_stream_id = 1 @@ -70,11 +82,11 @@ class HTTP2Protocol(object): def send_frame(self, frame): raw_bytes = frame.to_bytes() - self.tcp_client.wfile.write(raw_bytes) - self.tcp_client.wfile.flush() + self.tcp_handler.wfile.write(raw_bytes) + self.tcp_handler.wfile.flush() def read_frame(self): - frm = frame.Frame.from_file(self.tcp_client.rfile, self) + frm = frame.Frame.from_file(self.tcp_handler.rfile, self) if isinstance(frm, frame.SettingsFrame): self._apply_settings(frm.settings) @@ -139,25 +151,36 @@ class HTTP2Protocol(object): self._create_body(body, stream_id))) def read_response(self): + headers, body = self._receive_transmission() + return headers[':status'], headers, body + + def read_request(self): + return self._receive_transmission() + + def _receive_transmission(self): + body_expected = True + header_block_fragment = b'' body = b'' while True: frm = self.read_frame() - if isinstance(frm, frame.HeadersFrame): + if isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame): header_block_fragment += frm.header_block_fragment - if frm.flags | frame.Frame.FLAG_END_HEADERS: + if frm.flags & frame.Frame.FLAG_END_HEADERS: + if frm.flags & frame.Frame.FLAG_END_STREAM: + body_expected = False break - while True: + while body_expected: frm = self.read_frame() if isinstance(frm, frame.DataFrame): body += frm.payload - if frm.flags | frame.Frame.FLAG_END_STREAM: + if frm.flags & frame.Frame.FLAG_END_STREAM: break headers = {} for header, value in self.decoder.decode(header_block_fragment): headers[header] = value - return headers[':status'], headers, body + return headers, body diff --git a/test/http2/test_http2_protocol.py b/test/http2/test_http2_protocol.py index cb46bc68..1591edd8 100644 --- a/test/http2/test_http2_protocol.py +++ b/test/http2/test_http2_protocol.py @@ -50,7 +50,39 @@ class TestCheckALPNMismatch(test.ServerTestBase): tutils.raises(NotImplementedError, protocol.check_alpn) -class TestPerformConnectionPreface(test.ServerTestBase): +class TestPerformServerConnectionPreface(test.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + # send magic + self.wfile.write(\ + '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex')) + self.wfile.flush() + + # send empty settings frame + self.wfile.write('000000040000000000'.decode('hex')) + self.wfile.flush() + + # check empty settings frame + assert self.rfile.read(9) ==\ + '000000040000000000'.decode('hex') + + # check settings acknowledgement + assert self.rfile.read(9) == \ + '000000040100000000'.decode('hex') + + # send settings acknowledgement + self.wfile.write('000000040100000000'.decode('hex')) + self.wfile.flush() + + def test_perform_server_connection_preface(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + protocol = http2.HTTP2Protocol(c) + protocol.perform_server_connection_preface() + + +class TestPerformClientConnectionPreface(test.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): @@ -74,14 +106,11 @@ class TestPerformConnectionPreface(test.ServerTestBase): self.wfile.write('000000040100000000'.decode('hex')) self.wfile.flush() - ssl = True - - def test_perform_connection_preface(self): + def test_perform_client_connection_preface(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl() protocol = http2.HTTP2Protocol(c) - protocol.perform_connection_preface() + protocol.perform_client_connection_preface() class TestStreamIds(): -- cgit v1.2.3 From a901bc3032747faf00adf82c3187d38213c070ca Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 12 Jun 2015 14:41:54 +0200 Subject: http2: add response creation --- netlib/http2/protocol.py | 56 ++++++++++++++++++++++++++++----------- test/http2/test_http2_protocol.py | 2 +- 2 files changed, 42 insertions(+), 16 deletions(-) diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index 4b69764f..56aee490 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -26,7 +26,8 @@ class HTTP2Protocol(object): ) # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') + CLIENT_CONNECTION_PREFACE =\ + '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') ALPN_PROTO_H2 = 'h2' @@ -38,6 +39,7 @@ class HTTP2Protocol(object): self.current_stream_id = None self.encoder = Encoder() self.decoder = Decoder() + self.connection_preface_performed = False def check_alpn(self): alp = self.tcp_handler.get_alpn_proto_negotiated() @@ -57,25 +59,36 @@ class HTTP2Protocol(object): assert settings_ack_frame.flags & frame.Frame.FLAG_ACK assert len(settings_ack_frame.settings) == 0 - def perform_server_connection_preface(self): - magic_length = len(self.CLIENT_CONNECTION_PREFACE) - magic = self.tcp_handler.rfile.safe_read(magic_length) - assert magic == self.CLIENT_CONNECTION_PREFACE + def perform_server_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True - self.send_frame(frame.SettingsFrame(state=self)) - self._receive_settings() - self._read_settings_ack() + magic_length = len(self.CLIENT_CONNECTION_PREFACE) + magic = self.tcp_handler.rfile.safe_read(magic_length) + assert magic == self.CLIENT_CONNECTION_PREFACE - def perform_client_connection_preface(self): - self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) + self.send_frame(frame.SettingsFrame(state=self)) + self._receive_settings() + self._read_settings_ack() - self.send_frame(frame.SettingsFrame(state=self)) - self._receive_settings() - self._read_settings_ack() + def perform_client_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) + + self.send_frame(frame.SettingsFrame(state=self)) + self._receive_settings() + self._read_settings_ack() def next_stream_id(self): if self.current_stream_id is None: - self.current_stream_id = 1 + if self.is_server: + # servers must use even stream ids + self.current_stream_id = 2 + else: + # clients must use odd stream ids + self.current_stream_id = 1 else: self.current_stream_id += 2 return self.current_stream_id @@ -165,7 +178,8 @@ class HTTP2Protocol(object): while True: frm = self.read_frame() - if isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame): + if isinstance(frm, frame.HeadersFrame)\ + or isinstance(frm, frame.ContinuationFrame): header_block_fragment += frm.header_block_fragment if frm.flags & frame.Frame.FLAG_END_HEADERS: if frm.flags & frame.Frame.FLAG_END_STREAM: @@ -184,3 +198,15 @@ class HTTP2Protocol(object): headers[header] = value return headers, body + + def create_response(self, code, headers=None, body=None): + if headers is None: + headers = [] + + headers = [(b':status', bytes(str(code)))] + headers + + stream_id = self.next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(body is None)), + self._create_body(body, stream_id))) diff --git a/test/http2/test_http2_protocol.py b/test/http2/test_http2_protocol.py index 1591edd8..76a0ffe9 100644 --- a/test/http2/test_http2_protocol.py +++ b/test/http2/test_http2_protocol.py @@ -55,7 +55,7 @@ class TestPerformServerConnectionPreface(test.ServerTestBase): def handle(self): # send magic - self.wfile.write(\ + self.wfile.write( '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex')) self.wfile.flush() -- cgit v1.2.3 From 5fab755a05f2ddd1b3e8e446e10fdcbded894e70 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 12 Jun 2015 15:21:23 +0200 Subject: add more tests --- netlib/tcp.py | 8 ++-- test/http2/test_http2_protocol.py | 87 +++++++++++++++++++++++++++++++++++++-- test/test_tcp.py | 5 +++ 3 files changed, 92 insertions(+), 8 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 98b17c50..eb8a523f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -412,7 +412,7 @@ class _Connection(object): def alpn_select_callback(conn, options): if alpn_select in options: return bytes(alpn_select) - else: + else: # pragma no cover return options[0] context.set_alpn_select_callback(alpn_select_callback) @@ -500,9 +500,9 @@ class TCPClient(_Connection): return self.connection.gettimeout() def get_alpn_proto_negotiated(self): - if OpenSSL._util.lib.Cryptography_HAS_ALPN: + if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() - else: # pragma no cover + else: return None @@ -616,7 +616,7 @@ class BaseHandler(_Connection): def get_alpn_proto_negotiated(self): if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() - else: # pragma no cover + else: return None diff --git a/test/http2/test_http2_protocol.py b/test/http2/test_http2_protocol.py index 76a0ffe9..ebd2c9a7 100644 --- a/test/http2/test_http2_protocol.py +++ b/test/http2/test_http2_protocol.py @@ -1,4 +1,3 @@ - import OpenSSL from netlib import http2 @@ -113,11 +112,11 @@ class TestPerformClientConnectionPreface(test.ServerTestBase): protocol.perform_client_connection_preface() -class TestStreamIds(): +class TestClientStreamIds(): c = tcp.TCPClient(("127.0.0.1", 0)) protocol = http2.HTTP2Protocol(c) - def test_stream_ids(self): + def test_client_stream_ids(self): assert self.protocol.current_stream_id is None assert self.protocol.next_stream_id() == 1 assert self.protocol.current_stream_id == 1 @@ -127,6 +126,20 @@ class TestStreamIds(): assert self.protocol.current_stream_id == 5 +class TestServerStreamIds(): + c = tcp.TCPClient(("127.0.0.1", 0)) + protocol = http2.HTTP2Protocol(c, is_server=True) + + def test_server_stream_ids(self): + assert self.protocol.current_stream_id is None + assert self.protocol.next_stream_id() == 2 + assert self.protocol.current_stream_id == 2 + assert self.protocol.next_stream_id() == 4 + assert self.protocol.current_stream_id == 4 + assert self.protocol.next_stream_id() == 6 + assert self.protocol.current_stream_id == 6 + + class TestApplySettings(test.ServerTestBase): class handler(tcp.BaseHandler): @@ -242,5 +255,71 @@ class TestReadResponse(test.ServerTestBase): status, headers, body = protocol.read_response() assert headers == {':status': '200', 'etag': 'foobar'} - assert status == '200' + assert status == "200" + assert body == b'foobar' + + +class TestReadEmptyResponse(test.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + self.wfile.write( + b'00000801050000000188628594e78c767f'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_read_empty_response(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c) + + status, headers, body = protocol.read_response() + + assert headers == {':status': '200', 'etag': 'foobar'} + assert status == "200" + assert body == b'' + + +class TestReadRequest(test.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + self.wfile.write( + b'000003010400000001828487'.decode('hex')) + self.wfile.write( + b'000006000100000001666f6f626172'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_read_request(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c, is_server=True) + + headers, body = protocol.read_request() + + assert headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} assert body == b'foobar' + + +class TestCreateResponse(): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_create_request_simple(self): + bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200) + assert len(bytes) == 1 + assert bytes[0] ==\ + '00000101050000000288'.decode('hex') + + def test_create_request_with_body(self): + bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response( + 200, [(b'foo', b'bar')], 'foobar') + assert len(bytes) == 2 + assert bytes[0] ==\ + '00000901040000000288408294e7838c767f'.decode('hex') + assert bytes[1] ==\ + '000006000100000002666f6f626172'.decode('hex') diff --git a/test/test_tcp.py b/test/test_tcp.py index d5506556..8aa34d2b 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -376,6 +376,11 @@ class TestALPN(test.ServerTestBase): c.convert_to_ssl(alpn_protos=["foobar"]) assert c.get_alpn_proto_negotiated() == "foobar" + def test_no_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + assert c.get_alpn_proto_negotiated() == None + else: def test_none_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) -- cgit v1.2.3 From 9c6d237d02290c2388f19ec8f215827d4f921e4b Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 12 Jun 2015 16:03:01 +0200 Subject: add new TLS methods --- netlib/tcp.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index eb8a523f..74fe70d4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -19,6 +19,9 @@ SSLv2_METHOD = SSL.SSLv2_METHOD SSLv3_METHOD = SSL.SSLv3_METHOD SSLv23_METHOD = SSL.SSLv23_METHOD TLSv1_METHOD = SSL.TLSv1_METHOD +TLSv1_1_METHOD = SSL.TLSv1_1_METHOD +TLSv1_2_METHOD = SSL.TLSv1_2_METHOD + OP_NO_SSLv2 = SSL.OP_NO_SSLv2 OP_NO_SSLv3 = SSL.OP_NO_SSLv3 @@ -376,7 +379,7 @@ class _Connection(object): alpn_select=None, ): """ - :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD + :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, TLSv1_1_METHOD, or TLSv1_2_METHOD :param options: A bit field consisting of OpenSSL.SSL.OP_* values :param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html :rtype : SSL.Context -- cgit v1.2.3 From 8d71a5b4aba8248b97918b11b12275bbf5197337 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 14 Jun 2015 19:17:34 +0200 Subject: http2: add authority header --- netlib/http2/protocol.py | 6 +++++- test/http2/test_http2_protocol.py | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index 56aee490..1e722dfb 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -152,10 +152,13 @@ class HTTP2Protocol(object): if headers is None: headers = [] + authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host headers = [ (b':method', bytes(method)), (b':path', bytes(path)), - (b':scheme', b'https')] + headers + (b':scheme', b'https'), + (b':authority', authority), + ] + headers stream_id = self.next_stream_id() @@ -192,6 +195,7 @@ class HTTP2Protocol(object): body += frm.payload if frm.flags & frame.Frame.FLAG_END_STREAM: break + # TODO: implement window update & flow headers = {} for header, value in self.decoder.decode(header_block_fragment): diff --git a/test/http2/test_http2_protocol.py b/test/http2/test_http2_protocol.py index ebd2c9a7..34c69fa9 100644 --- a/test/http2/test_http2_protocol.py +++ b/test/http2/test_http2_protocol.py @@ -222,14 +222,14 @@ class TestCreateRequest(): def test_create_request_simple(self): bytes = http2.HTTP2Protocol(self.c).create_request('GET', '/') assert len(bytes) == 1 - assert bytes[0] == '000003010500000001828487'.decode('hex') + assert bytes[0] == '00000c0105000000018284874187089d5c0b8170ff'.decode('hex') def test_create_request_with_body(self): bytes = http2.HTTP2Protocol(self.c).create_request( 'GET', '/', [(b'foo', b'bar')], 'foobar') assert len(bytes) == 2 assert bytes[0] ==\ - '00000b010400000001828487408294e7838c767f'.decode('hex') + '0000140104000000018284874187089d5c0b8170ff408294e7838c767f'.decode('hex') assert bytes[1] ==\ '000006000100000001666f6f626172'.decode('hex') -- cgit v1.2.3 From 0d137eac6f4c00a72d3aa4d11fce7d1ea15f0f21 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 14 Jun 2015 19:50:35 +0200 Subject: simplify ALPN --- netlib/tcp.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 74fe70d4..897e3e65 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -535,7 +535,6 @@ class BaseHandler(_Connection): request_client_cert=None, chain_file=None, dhparams=None, - alpn_select=None, **sslctx_kwargs): """ cert: A certutils.SSLCert object. @@ -562,9 +561,7 @@ class BaseHandler(_Connection): until then we're conservative. """ - context = self._create_ssl_context( - alpn_select=alpn_select, - **sslctx_kwargs) + context = self._create_ssl_context(**sslctx_kwargs) context.use_privatekey(key) context.use_certificate(cert.x509) @@ -589,7 +586,7 @@ class BaseHandler(_Connection): return context - def convert_to_ssl(self, cert, key, alpn_select=None, **sslctx_kwargs): + def convert_to_ssl(self, cert, key, **sslctx_kwargs): """ Convert connection to SSL. For a list of parameters, see BaseHandler._create_ssl_context(...) @@ -598,7 +595,6 @@ class BaseHandler(_Connection): context = self.create_ssl_context( cert, key, - alpn_select=alpn_select, **sslctx_kwargs) self.connection = SSL.Connection(context, self.connection) self.connection.set_accept_state() -- cgit v1.2.3 From 08f988e9f65d8628657cf2018fd36ab82a4d0789 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 15 Jun 2015 11:58:24 +0200 Subject: improve meta code --- check_coding_style.sh | 4 ++-- setup.cfg | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/check_coding_style.sh b/check_coding_style.sh index 5b38e003..a1c94e03 100755 --- a/check_coding_style.sh +++ b/check_coding_style.sh @@ -5,7 +5,7 @@ if [[ -n "$(git status -s)" ]]; then echo "autopep8 yielded the following changes:" git status -s git --no-pager diff - exit 1 + exit 0 # don't be so strict about coding style errors fi autoflake -i -r --remove-all-unused-imports --remove-unused-variables . @@ -13,7 +13,7 @@ if [[ -n "$(git status -s)" ]]; then echo "autoflake yielded the following changes:" git status -s git --no-pager diff - exit 1 + exit 0 # don't be so strict about coding style errors fi echo "Coding style seems to be ok." diff --git a/setup.cfg b/setup.cfg index bc980d56..4207020e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -4,6 +4,5 @@ max-complexity = 15 [pep8] max-line-length = 80 -max-complexity = 15 exclude = */contrib/* ignore = E251,E309 -- cgit v1.2.3 From fe764cde5229046b8447062971c61fac745d2d58 Mon Sep 17 00:00:00 2001 From: Kyle Morton Date: Mon, 15 Jun 2015 10:16:44 -0700 Subject: Adding support for upstream certificate validation when using SSL/TLS with an instance of TCPClient. --- netlib/tcp.py | 23 +++++++++++++++++++++ test/data/not-server.crt | 15 ++++++++++++++ test/test_tcp.py | 53 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+) create mode 100644 test/data/not-server.crt diff --git a/netlib/tcp.py b/netlib/tcp.py index 9a980035..ca948514 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -21,6 +21,7 @@ SSLv23_METHOD = SSL.SSLv23_METHOD TLSv1_METHOD = SSL.TLSv1_METHOD OP_NO_SSLv2 = SSL.OP_NO_SSLv2 OP_NO_SSLv3 = SSL.OP_NO_SSLv3 +VERIFY_NONE = SSL.VERIFY_NONE class NetLibError(Exception): @@ -371,6 +372,9 @@ class _Connection(object): def _create_ssl_context(self, method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), + verify_options=VERIFY_NONE, + ca_path=None, + ca_pemfile=None, cipher_list=None, alpn_protos=None, alpn_select=None, @@ -378,6 +382,9 @@ class _Connection(object): """ :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD :param options: A bit field consisting of OpenSSL.SSL.OP_* values + :param verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values + :param ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool + :param ca_pemfile: Path to a PEM formatted trusted CA certificate :param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html :rtype : SSL.Context """ @@ -386,6 +393,19 @@ class _Connection(object): if options is not None: context.set_options(options) + # Verify Options (NONE/PEER/PEER|FAIL_IF_... and trusted CAs) + if verify_options is not None and verify_options is not VERIFY_NONE: + def verify_cert(conn, cert, errno, err_depth, is_cert_verified): + if is_cert_verified: + return True + raise NetLibError( + "Upstream certificate validation failed at depth: %s with error number: %s" % + (err_depth, errno)) + + context.set_verify(verify_options, verify_cert) + if ca_path is not None or ca_pemfile is not None: + context.load_verify_locations(ca_pemfile, ca_path) + # Workaround for # https://github.com/pyca/pyopenssl/issues/190 # https://github.com/mitmproxy/mitmproxy/issues/472 @@ -458,6 +478,9 @@ class TCPClient(_Connection): cert: Path to a file containing both client cert and private key. options: A bit field consisting of OpenSSL.SSL.OP_* values + verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values + ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool + ca_pemfile: Path to a PEM formatted trusted CA certificate """ context = self.create_ssl_context( alpn_protos=alpn_protos, diff --git a/test/data/not-server.crt b/test/data/not-server.crt new file mode 100644 index 00000000..08c015c2 --- /dev/null +++ b/test/data/not-server.crt @@ -0,0 +1,15 @@ +-----BEGIN CERTIFICATE----- +MIICRTCCAa4CCQD/j4qq1h3iCjANBgkqhkiG9w0BAQUFADBnMQswCQYDVQQGEwJV +UzELMAkGA1UECBMCQ0ExETAPBgNVBAcTCFNvbWVDaXR5MRcwFQYDVQQKEw5Ob3RU +aGVSaWdodE9yZzELMAkGA1UECxMCTkExEjAQBgNVBAMTCU5vdFNlcnZlcjAeFw0x +NTA2MTMwMTE2MDZaFw0yNTA2MTAwMTE2MDZaMGcxCzAJBgNVBAYTAlVTMQswCQYD +VQQIEwJDQTERMA8GA1UEBxMIU29tZUNpdHkxFzAVBgNVBAoTDk5vdFRoZVJpZ2h0 +T3JnMQswCQYDVQQLEwJOQTESMBAGA1UEAxMJTm90U2VydmVyMIGfMA0GCSqGSIb3 +DQEBAQUAA4GNADCBiQKBgQDPkJlXAOCMKF0R7aDn5QJ7HtrJgOUDk/LpbhKhRZZR +dRGnJ4/HQxYYHh9k/4yZamYcvQPUxvFJt7UJUocf+84LUcIusUk7GvJMgsMVtFMq +7UKNXBN5tl3oOtoFDWGMZ8ksaIxS6oW3V/9v2WgU23PfvwE0EZqy+QhMLZZP5GOH +RwIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAJI6UtMKdCS2ghjqhAek2W1rt9u+Wuvx +776WYm5VyrJEtBDc/axLh0OteXzy/A31JrYe15fnVWIeFbDF0Ief9/Ezv6Jn+Pk8 +DErw5IHk2B399O4K3L3Eig06piu7uf3vE4l8ZanY02ZEnw7DyL6kmG9lX98VGenF +uXPfu3yxKbR4 +-----END CERTIFICATE----- diff --git a/test/test_tcp.py b/test/test_tcp.py index d5506556..081c83a7 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -171,6 +171,59 @@ class TestSSLv3Only(test.ServerTestBase): tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com") +class TestSSLUpstreamCertVerification(test.ServerTestBase): + handler = EchoHandler + + ssl = dict( + cert=tutils.test_data.path("data/server.crt") + ) + + def test_mode_default(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + + c.convert_to_ssl() + + testval = "echo!\n" + c.wfile.write(testval) + c.wfile.flush() + assert c.rfile.readline() == testval + + def test_mode_none(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + + c.convert_to_ssl(verify_options=SSL.VERIFY_NONE) + + testval = "echo!\n" + c.wfile.write(testval) + c.wfile.flush() + assert c.rfile.readline() == testval + + def test_mode_strict_w_bad_cert(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + + tutils.raises( + tcp.NetLibError, + c.convert_to_ssl, + verify_options=SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT, + ca_pemfile=tutils.test_data.path("data/not-server.crt")) + + def test_mode_strict_w_cert(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + + c.convert_to_ssl( + verify_options=SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT, + ca_pemfile=tutils.test_data.path("data/server.crt")) + + testval = "echo!\n" + c.wfile.write(testval) + c.wfile.flush() + assert c.rfile.readline() == testval + + class TestSSLClientCert(test.ServerTestBase): class handler(tcp.BaseHandler): -- cgit v1.2.3 From 9089226d661793e2eb5d3cf6f1bbe916578e5b7b Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 16 Jun 2015 02:31:47 +0200 Subject: explicitly state that we only support 2.7 --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 0051ea77..dc19a870 100644 --- a/setup.py +++ b/setup.py @@ -49,6 +49,7 @@ setup( "Operating System :: POSIX", "Programming Language :: Python", "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.7", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", "Topic :: Internet", -- cgit v1.2.3 From d8db9330a01cfab2603c8ad465c0ba00e8310994 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 16 Jun 2015 02:52:07 +0200 Subject: update badges --- README.mkd | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/README.mkd b/README.mkd index 79e7f803..33d50b29 100644 --- a/README.mkd +++ b/README.mkd @@ -1,8 +1,9 @@ -[![Build Status](https://travis-ci.org/mitmproxy/netlib.svg?branch=master)](https://travis-ci.org/mitmproxy/netlib) -[![Coverage Status](https://coveralls.io/repos/mitmproxy/netlib/badge.svg?branch=master)](https://coveralls.io/r/mitmproxy/netlib) -[![Latest Version](https://pypip.in/version/netlib/badge.svg?style=flat)](https://pypi.python.org/pypi/netlib) -[![Supported Python versions](https://pypip.in/py_versions/netlib/badge.svg?style=flat)](https://pypi.python.org/pypi/netlib) -[![Supported Python implementations](https://pypip.in/implementation/netlib/badge.svg?style=flat)](https://pypi.python.org/pypi/netlib) +[![Build Status](https://img.shields.io/travis/mitmproxy/netlib/master.svg)](https://travis-ci.org/mitmproxy/netlib) +[![Coverage Status](https://img.shields.io/coveralls/mitmproxy/netlib/master.svg)](https://coveralls.io/r/mitmproxy/netlib) +[![Downloads](https://img.shields.io/pypi/dm/netlib.svg?color=orange)](https://pypi.python.org/pypi/netlib) +[![Latest Version](https://img.shields.io/pypi/v/netlib.svg)](https://pypi.python.org/pypi/netlib) +[![Supported Python versions](https://img.shields.io/pypi/pyversions/netlib.svg)](https://pypi.python.org/pypi/netlib) +[![Supported Python implementations](https://img.shields.io/pypi/implementation/netlib.svg)](https://pypi.python.org/pypi/netlib) Netlib is a collection of network utility classes, used by the pathod and mitmproxy projects. It differs from other projects in some fundamental @@ -15,4 +16,4 @@ Requirements ------------ * [Python](http://www.python.org) 2.7.x. -* Third-party packages listed in [setup.py](https://github.com/mitmproxy/netlib/blob/master/setup.py) \ No newline at end of file +* Third-party packages listed in [setup.py](https://github.com/mitmproxy/netlib/blob/master/setup.py) -- cgit v1.2.3 From 1f0c55a942ef1e36d21e2d8006a1585ad4cf2700 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 16 Jun 2015 03:30:34 +0200 Subject: add hacking section --- README.mkd | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.mkd b/README.mkd index 33d50b29..7039e203 100644 --- a/README.mkd +++ b/README.mkd @@ -17,3 +17,8 @@ Requirements * [Python](http://www.python.org) 2.7.x. * Third-party packages listed in [setup.py](https://github.com/mitmproxy/netlib/blob/master/setup.py) + +Hacking +------- + +If you'd like to work on netlib, check out the instructions in mitmproxy's [README](https://github.com/mitmproxy/mitmproxy#hacking). -- cgit v1.2.3 From 12702b9a01fb6baf4d675d6f974c140581982843 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 15 Jun 2015 13:15:06 +0200 Subject: http2: improve frame output --- netlib/http2/frame.py | 11 +++------ netlib/http2/protocol.py | 61 ++++++++++++++++++++++++++++-------------------- 2 files changed, 39 insertions(+), 33 deletions(-) diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py index 4a305d82..3e285cba 100644 --- a/netlib/http2/frame.py +++ b/netlib/http2/frame.py @@ -113,16 +113,11 @@ class Frame(object): def payload_human_readable(self): # pragma: no cover raise NotImplementedError() - def human_readable(self): + def human_readable(self, direction="-"): return "\n".join([ - "============================================================", - "length: %d bytes" % self.length, - "type: %s (%#x)" % (self.__class__.__name__, self.TYPE), - "flags: %#x" % self.flags, - "stream_id: %#x" % self.stream_id, - "------------------------------------------------------------", + "%s: %s | length: %d | flags: %#x | stream_id: %d" % (direction, self.__class__.__name__, self.length, self.flags, self.stream_id), self.payload_human_readable(), - "============================================================", + "===============================================================", ]) def __eq__(self, other): diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index 1e722dfb..7bf68602 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -48,13 +48,12 @@ class HTTP2Protocol(object): "HTTP2Protocol can not handle unknown ALP: %s" % alp) return True - def _receive_settings(self): - frm = frame.Frame.from_file(self.tcp_handler.rfile, self) + def _receive_settings(self, hide=False): + frm = self.read_frame(hide) assert isinstance(frm, frame.SettingsFrame) - self._apply_settings(frm.settings) - def _read_settings_ack(self): - settings_ack_frame = self.read_frame() + def _read_settings_ack(self, hide=False): + settings_ack_frame = self.read_frame(hide) assert isinstance(settings_ack_frame, frame.SettingsFrame) assert settings_ack_frame.flags & frame.Frame.FLAG_ACK assert len(settings_ack_frame.settings) == 0 @@ -67,9 +66,8 @@ class HTTP2Protocol(object): magic = self.tcp_handler.rfile.safe_read(magic_length) assert magic == self.CLIENT_CONNECTION_PREFACE - self.send_frame(frame.SettingsFrame(state=self)) - self._receive_settings() - self._read_settings_ack() + self.send_frame(frame.SettingsFrame(state=self), hide=True) + self._receive_settings(hide=True) def perform_client_connection_preface(self, force=False): if force or not self.connection_preface_performed: @@ -77,9 +75,8 @@ class HTTP2Protocol(object): self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) - self.send_frame(frame.SettingsFrame(state=self)) - self._receive_settings() - self._read_settings_ack() + self.send_frame(frame.SettingsFrame(state=self), hide=True) + self._receive_settings(hide=True) def next_stream_id(self): if self.current_stream_id is None: @@ -93,30 +90,35 @@ class HTTP2Protocol(object): self.current_stream_id += 2 return self.current_stream_id - def send_frame(self, frame): - raw_bytes = frame.to_bytes() + def send_frame(self, frm, hide=False): + raw_bytes = frm.to_bytes() self.tcp_handler.wfile.write(raw_bytes) self.tcp_handler.wfile.flush() + if not hide and self.tcp_handler.http2_framedump: + print(frm.human_readable(">>")) - def read_frame(self): + def read_frame(self, hide=False): frm = frame.Frame.from_file(self.tcp_handler.rfile, self) - if isinstance(frm, frame.SettingsFrame): - self._apply_settings(frm.settings) + if not hide and self.tcp_handler.http2_framedump: + print(frm.human_readable("<<")) + if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: + self._apply_settings(frm.settings, hide) return frm - def _apply_settings(self, settings): + def _apply_settings(self, settings, hide=False): for setting, value in settings.items(): old_value = self.http2_settings[setting] if not old_value: old_value = '-' - self.http2_settings[setting] = value self.send_frame( frame.SettingsFrame( state=self, - flags=frame.Frame.FLAG_ACK)) + flags=frame.Frame.FLAG_ACK), + hide) + self._read_settings_ack(hide) def _create_headers(self, headers, stream_id, end_stream=True): # TODO: implement max frame size checks and sending in chunks @@ -127,12 +129,16 @@ class HTTP2Protocol(object): header_block_fragment = self.encoder.encode(headers) - bytes = frame.HeadersFrame( + frm = frame.HeadersFrame( state=self, flags=flags, stream_id=stream_id, - header_block_fragment=header_block_fragment).to_bytes() - return [bytes] + header_block_fragment=header_block_fragment) + + if self.tcp_handler.http2_framedump: + print(frm.human_readable(">>")) + + return [frm.to_bytes()] def _create_body(self, body, stream_id): if body is None or len(body) == 0: @@ -141,12 +147,17 @@ class HTTP2Protocol(object): # TODO: implement max frame size checks and sending in chunks # TODO: implement flow-control window - bytes = frame.DataFrame( + frm = frame.DataFrame( state=self, flags=frame.Frame.FLAG_END_STREAM, stream_id=stream_id, - payload=body).to_bytes() - return [bytes] + payload=body) + + if self.tcp_handler.http2_framedump: + print(frm.human_readable(">>")) + + return [frm.to_bytes()] + def create_request(self, method, path, headers=None, body=None): if headers is None: -- cgit v1.2.3 From 79ff43993018209a76a2a7cff995e912eb20d4c3 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 15 Jun 2015 09:47:43 +0200 Subject: add elliptic curve during TLS handshake --- netlib/tcp.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 953cef6e..2e847d83 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -22,11 +22,6 @@ TLSv1_METHOD = SSL.TLSv1_METHOD TLSv1_1_METHOD = SSL.TLSv1_1_METHOD TLSv1_2_METHOD = SSL.TLSv1_2_METHOD -OP_NO_SSLv2 = SSL.OP_NO_SSLv2 -OP_NO_SSLv3 = SSL.OP_NO_SSLv3 -VERIFY_NONE = SSL.VERIFY_NONE - - class NetLibError(Exception): pass @@ -374,8 +369,8 @@ class _Connection(object): def _create_ssl_context(self, method=SSLv23_METHOD, - options=(OP_NO_SSLv2 | OP_NO_SSLv3), - verify_options=VERIFY_NONE, + options=(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_CIPHER_SERVER_PREFERENCE | SSL.OP_NO_COMPRESSION), + verify_options=SSL.VERIFY_NONE, ca_path=None, ca_pemfile=None, cipher_list=None, @@ -397,7 +392,7 @@ class _Connection(object): context.set_options(options) # Verify Options (NONE/PEER/PEER|FAIL_IF_... and trusted CAs) - if verify_options is not None and verify_options is not VERIFY_NONE: + if verify_options is not None and verify_options is not SSL.VERIFY_NONE: def verify_cert(conn, cert, errno, err_depth, is_cert_verified): if is_cert_verified: return True @@ -426,6 +421,8 @@ class _Connection(object): if log_ssl_key: context.set_info_callback(log_ssl_key) + context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1')) + if OpenSSL._util.lib.Cryptography_HAS_ALPN: if alpn_protos is not None: # advertise application layer protocols -- cgit v1.2.3 From e3db241a2fa47a38fcb85532ed52eeecf1a7b965 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 15 Jun 2015 13:43:23 +0200 Subject: http2: improve frame output --- netlib/http2/protocol.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index 7bf68602..24fcb712 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -31,7 +31,7 @@ class HTTP2Protocol(object): ALPN_PROTO_H2 = 'h2' - def __init__(self, tcp_handler, is_server=False): + def __init__(self, tcp_handler, is_server=False, dump_frames=False): self.tcp_handler = tcp_handler self.is_server = is_server @@ -40,6 +40,7 @@ class HTTP2Protocol(object): self.encoder = Encoder() self.decoder = Decoder() self.connection_preface_performed = False + self.dump_frames = dump_frames def check_alpn(self): alp = self.tcp_handler.get_alpn_proto_negotiated() @@ -94,12 +95,12 @@ class HTTP2Protocol(object): raw_bytes = frm.to_bytes() self.tcp_handler.wfile.write(raw_bytes) self.tcp_handler.wfile.flush() - if not hide and self.tcp_handler.http2_framedump: + if not hide and self.dump_frames: print(frm.human_readable(">>")) def read_frame(self, hide=False): frm = frame.Frame.from_file(self.tcp_handler.rfile, self) - if not hide and self.tcp_handler.http2_framedump: + if not hide and self.dump_frames: print(frm.human_readable("<<")) if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: self._apply_settings(frm.settings, hide) @@ -135,7 +136,7 @@ class HTTP2Protocol(object): stream_id=stream_id, header_block_fragment=header_block_fragment) - if self.tcp_handler.http2_framedump: + if self.dump_frames: print(frm.human_readable(">>")) return [frm.to_bytes()] @@ -153,7 +154,7 @@ class HTTP2Protocol(object): stream_id=stream_id, payload=body) - if self.tcp_handler.http2_framedump: + if self.dump_frames: print(frm.human_readable(">>")) return [frm.to_bytes()] -- cgit v1.2.3 From d0a9d3cdda6d1f784a23ea4bd9efd3134e292628 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 15 Jun 2015 14:21:34 +0200 Subject: http2: only first headers frame as END_STREAM flag --- netlib/http2/protocol.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index 24fcb712..682b7863 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -196,9 +196,9 @@ class HTTP2Protocol(object): if isinstance(frm, frame.HeadersFrame)\ or isinstance(frm, frame.ContinuationFrame): header_block_fragment += frm.header_block_fragment + if frm.flags & frame.Frame.FLAG_END_STREAM: + body_expected = False if frm.flags & frame.Frame.FLAG_END_HEADERS: - if frm.flags & frame.Frame.FLAG_END_STREAM: - body_expected = False break while body_expected: -- cgit v1.2.3 From 1c124421e34d310c6e0577f20b595413d639a5c3 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 15 Jun 2015 15:31:58 +0200 Subject: http2: fix header_block_fragments and length --- netlib/http2/frame.py | 13 +++++++++++-- netlib/http2/protocol.py | 23 +++++++++++++++-------- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py index 3e285cba..98ced904 100644 --- a/netlib/http2/frame.py +++ b/netlib/http2/frame.py @@ -114,6 +114,8 @@ class Frame(object): raise NotImplementedError() def human_readable(self, direction="-"): + self.length = len(self.payload_bytes()) + return "\n".join([ "%s: %s | length: %d | flags: %#x | stream_id: %d" % (direction, self.__class__.__name__, self.length, self.flags, self.stream_id), self.payload_human_readable(), @@ -456,7 +458,10 @@ class PushPromiseFrame(Frame): s.append("padding: %d" % self.pad_length) s.append("promised stream: %#x" % self.promised_stream) - s.append("header_block_fragment: %s" % str(self.header_block_fragment)) + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + return "\n".join(s) @@ -600,7 +605,11 @@ class ContinuationFrame(Frame): return self.header_block_fragment def payload_human_readable(self): - return "header_block_fragment: %s" % str(self.header_block_fragment) + s = [] + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + return "\n".join(s) _FRAME_CLASSES = [ DataFrame, diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index 682b7863..f17f998f 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -50,14 +50,18 @@ class HTTP2Protocol(object): return True def _receive_settings(self, hide=False): - frm = self.read_frame(hide) - assert isinstance(frm, frame.SettingsFrame) + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + break def _read_settings_ack(self, hide=False): - settings_ack_frame = self.read_frame(hide) - assert isinstance(settings_ack_frame, frame.SettingsFrame) - assert settings_ack_frame.flags & frame.Frame.FLAG_ACK - assert len(settings_ack_frame.settings) == 0 + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + assert settings_ack_frame.flags & frame.Frame.FLAG_ACK + assert len(settings_ack_frame.settings) == 0 + break def perform_server_connection_preface(self, force=False): if force or not self.connection_preface_performed: @@ -119,7 +123,7 @@ class HTTP2Protocol(object): state=self, flags=frame.Frame.FLAG_ACK), hide) - self._read_settings_ack(hide) + # self._read_settings_ack(hide) def _create_headers(self, headers, stream_id, end_stream=True): # TODO: implement max frame size checks and sending in chunks @@ -219,10 +223,13 @@ class HTTP2Protocol(object): if headers is None: headers = [] + body='foobar' + headers = [(b':status', bytes(str(code)))] + headers stream_id = self.next_stream_id() return list(itertools.chain( self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id))) + self._create_body(body, stream_id), + )) -- cgit v1.2.3 From 20c136e070cee0e93e870bf32199cb36b1b85275 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 15 Jun 2015 15:51:40 +0200 Subject: http2: return stream_id from request for response --- netlib/http2/protocol.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index f17f998f..a77edd9b 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -183,7 +183,7 @@ class HTTP2Protocol(object): self._create_body(body, stream_id))) def read_response(self): - headers, body = self._receive_transmission() + stream_id, headers, body = self._receive_transmission() return headers[':status'], headers, body def read_request(self): @@ -192,6 +192,7 @@ class HTTP2Protocol(object): def _receive_transmission(self): body_expected = True + stream_id = 0 header_block_fragment = b'' body = b'' @@ -199,6 +200,7 @@ class HTTP2Protocol(object): frm = self.read_frame() if isinstance(frm, frame.HeadersFrame)\ or isinstance(frm, frame.ContinuationFrame): + stream_id = frm.stream_id header_block_fragment += frm.header_block_fragment if frm.flags & frame.Frame.FLAG_END_STREAM: body_expected = False @@ -217,9 +219,9 @@ class HTTP2Protocol(object): for header, value in self.decoder.decode(header_block_fragment): headers[header] = value - return headers, body + return stream_id, headers, body - def create_response(self, code, headers=None, body=None): + def create_response(self, code, stream_id=None, headers=None, body=None): if headers is None: headers = [] @@ -227,7 +229,8 @@ class HTTP2Protocol(object): headers = [(b':status', bytes(str(code)))] + headers - stream_id = self.next_stream_id() + if not stream_id: + stream_id = self.next_stream_id() return list(itertools.chain( self._create_headers(headers, stream_id, end_stream=(body is None)), -- cgit v1.2.3 From abb37a3ef52ab9a0f68dc46e4a8ca165e365139b Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 15 Jun 2015 17:31:08 +0200 Subject: http2: improve test suite --- netlib/http2/protocol.py | 16 +++++++-------- netlib/tcp.py | 9 +++++---- test/http2/test_http2_protocol.py | 13 +++++++------ test/test_tcp.py | 41 +++++++++++++++++++++++++++++++-------- 4 files changed, 53 insertions(+), 26 deletions(-) diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index a77edd9b..8191090c 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -55,7 +55,7 @@ class HTTP2Protocol(object): if isinstance(frm, frame.SettingsFrame): break - def _read_settings_ack(self, hide=False): + def _read_settings_ack(self, hide=False): # pragma no cover while True: frm = self.read_frame(hide) if isinstance(frm, frame.SettingsFrame): @@ -99,12 +99,12 @@ class HTTP2Protocol(object): raw_bytes = frm.to_bytes() self.tcp_handler.wfile.write(raw_bytes) self.tcp_handler.wfile.flush() - if not hide and self.dump_frames: + if not hide and self.dump_frames: # pragma no cover print(frm.human_readable(">>")) def read_frame(self, hide=False): frm = frame.Frame.from_file(self.tcp_handler.rfile, self) - if not hide and self.dump_frames: + if not hide and self.dump_frames: # pragma no cover print(frm.human_readable("<<")) if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: self._apply_settings(frm.settings, hide) @@ -123,7 +123,9 @@ class HTTP2Protocol(object): state=self, flags=frame.Frame.FLAG_ACK), hide) - # self._read_settings_ack(hide) + + # be liberal in what we expect from the other end + # to be more strict use: self._read_settings_ack(hide) def _create_headers(self, headers, stream_id, end_stream=True): # TODO: implement max frame size checks and sending in chunks @@ -140,7 +142,7 @@ class HTTP2Protocol(object): stream_id=stream_id, header_block_fragment=header_block_fragment) - if self.dump_frames: + if self.dump_frames: # pragma no cover print(frm.human_readable(">>")) return [frm.to_bytes()] @@ -158,7 +160,7 @@ class HTTP2Protocol(object): stream_id=stream_id, payload=body) - if self.dump_frames: + if self.dump_frames: # pragma no cover print(frm.human_readable(">>")) return [frm.to_bytes()] @@ -225,8 +227,6 @@ class HTTP2Protocol(object): if headers is None: headers = [] - body='foobar' - headers = [(b':status', bytes(str(code)))] + headers if not stream_id: diff --git a/netlib/tcp.py b/netlib/tcp.py index 2e847d83..cafc3ed9 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -414,6 +414,9 @@ class _Connection(object): if cipher_list: try: context.set_cipher_list(cipher_list) + + # TODO: maybe change this to with newer pyOpenSSL APIs + context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1')) except SSL.Error as v: raise NetLibError("SSL cipher specification error: %s" % str(v)) @@ -421,8 +424,6 @@ class _Connection(object): if log_ssl_key: context.set_info_callback(log_ssl_key) - context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1')) - if OpenSSL._util.lib.Cryptography_HAS_ALPN: if alpn_protos is not None: # advertise application layer protocols @@ -526,7 +527,7 @@ class TCPClient(_Connection): if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() else: - return None + return "" class BaseHandler(_Connection): @@ -636,7 +637,7 @@ class BaseHandler(_Connection): if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() else: - return None + return "" class TCPServer(object): diff --git a/test/http2/test_http2_protocol.py b/test/http2/test_http2_protocol.py index 34c69fa9..231b35e0 100644 --- a/test/http2/test_http2_protocol.py +++ b/test/http2/test_http2_protocol.py @@ -300,8 +300,9 @@ class TestReadRequest(test.ServerTestBase): c.convert_to_ssl() protocol = http2.HTTP2Protocol(c, is_server=True) - headers, body = protocol.read_request() + stream_id, headers, body = protocol.read_request() + assert stream_id assert headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} assert body == b'foobar' @@ -309,17 +310,17 @@ class TestReadRequest(test.ServerTestBase): class TestCreateResponse(): c = tcp.TCPClient(("127.0.0.1", 0)) - def test_create_request_simple(self): + def test_create_response_simple(self): bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200) assert len(bytes) == 1 assert bytes[0] ==\ '00000101050000000288'.decode('hex') - def test_create_request_with_body(self): + def test_create_response_with_body(self): bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response( - 200, [(b'foo', b'bar')], 'foobar') + 200, 1, [(b'foo', b'bar')], 'foobar') assert len(bytes) == 2 assert bytes[0] ==\ - '00000901040000000288408294e7838c767f'.decode('hex') + '00000901040000000188408294e7838c767f'.decode('hex') assert bytes[1] ==\ - '000006000100000002666f6f626172'.decode('hex') + '000006000100000001666f6f626172'.decode('hex') diff --git a/test/test_tcp.py b/test/test_tcp.py index 0cecaaa2..122c1f0f 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -41,6 +41,18 @@ class HangHandler(tcp.BaseHandler): time.sleep(1) +class ALPNHandler(tcp.BaseHandler): + sni = None + + def handle(self): + alp = self.get_alpn_proto_negotiated() + if alp: + self.wfile.write("%s" % alp) + else: + self.wfile.write("NONE") + self.wfile.flush() + + class TestServer(test.ServerTestBase): handler = EchoHandler @@ -416,30 +428,43 @@ class TestTimeOut(test.ServerTestBase): tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) -class TestALPN(test.ServerTestBase): - handler = EchoHandler +class TestALPNClient(test.ServerTestBase): + handler = ALPNHandler ssl = dict( - alpn_select="foobar" + alpn_select="bar" ) if OpenSSL._util.lib.Cryptography_HAS_ALPN: def test_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=["foobar"]) - assert c.get_alpn_proto_negotiated() == "foobar" + c.convert_to_ssl(alpn_protos=["foo", "bar", "fasel"]) + assert c.get_alpn_proto_negotiated() == "bar" + assert c.rfile.readline().strip() == "bar" def test_no_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - assert c.get_alpn_proto_negotiated() == None + c.convert_to_ssl() + assert c.get_alpn_proto_negotiated() == "" + assert c.rfile.readline().strip() == "NONE" else: def test_none_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=["foobar"]) - assert c.get_alpn_proto_negotiated() == None + c.convert_to_ssl(alpn_protos=["foo", "bar", "fasel"]) + assert c.get_alpn_proto_negotiated() == "" + assert c.rfile.readline() == "NONE" + +class TestNoSSLNoALPNClient(test.ServerTestBase): + handler = ALPNHandler + + def test_no_ssl_no_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + assert c.get_alpn_proto_negotiated() == "" + assert c.rfile.readline().strip() == "NONE" class TestSSLTimeOut(test.ServerTestBase): -- cgit v1.2.3 From eb823a04a19de7fd9e15d225064ae4581f0b85bf Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 15 Jun 2015 23:36:14 +0200 Subject: http2: improve :authority header --- netlib/http2/protocol.py | 3 +++ test/http2/test_http2_protocol.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index 8191090c..ac89bac4 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -171,6 +171,9 @@ class HTTP2Protocol(object): headers = [] authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host + if self.tcp_handler.address.port != 443: + authority += ":%d" % self.tcp_handler.address.port + headers = [ (b':method', bytes(method)), (b':path', bytes(path)), diff --git a/test/http2/test_http2_protocol.py b/test/http2/test_http2_protocol.py index 231b35e0..9b49acd3 100644 --- a/test/http2/test_http2_protocol.py +++ b/test/http2/test_http2_protocol.py @@ -222,14 +222,14 @@ class TestCreateRequest(): def test_create_request_simple(self): bytes = http2.HTTP2Protocol(self.c).create_request('GET', '/') assert len(bytes) == 1 - assert bytes[0] == '00000c0105000000018284874187089d5c0b8170ff'.decode('hex') + assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') def test_create_request_with_body(self): bytes = http2.HTTP2Protocol(self.c).create_request( 'GET', '/', [(b'foo', b'bar')], 'foobar') assert len(bytes) == 2 assert bytes[0] ==\ - '0000140104000000018284874187089d5c0b8170ff408294e7838c767f'.decode('hex') + '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') assert bytes[1] ==\ '000006000100000001666f6f626172'.decode('hex') -- cgit v1.2.3 From c9c93af453ec332b660f70402b78ae8f269280f0 Mon Sep 17 00:00:00 2001 From: Kyle Morton Date: Tue, 16 Jun 2015 11:11:10 -0700 Subject: Adding certifi as default CA bundle. --- netlib/tcp.py | 6 +++--- setup.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index ca948514..b523bea4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -7,6 +7,7 @@ import threading import time import traceback +import certifi import OpenSSL from OpenSSL import SSL @@ -373,7 +374,7 @@ class _Connection(object): method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), verify_options=VERIFY_NONE, - ca_path=None, + ca_path=certifi.where(), ca_pemfile=None, cipher_list=None, alpn_protos=None, @@ -403,8 +404,7 @@ class _Connection(object): (err_depth, errno)) context.set_verify(verify_options, verify_cert) - if ca_path is not None or ca_pemfile is not None: - context.load_verify_locations(ca_pemfile, ca_path) + context.load_verify_locations(ca_pemfile, ca_path) # Workaround for # https://github.com/pyca/pyopenssl/issues/190 diff --git a/setup.py b/setup.py index 0051ea77..aa27cd90 100644 --- a/setup.py +++ b/setup.py @@ -66,7 +66,8 @@ setup( "pyOpenSSL>=0.15.1", "cryptography>=0.9", "passlib>=1.6.2", - "hpack>=1.0.1"], + "hpack>=1.0.1", + "certifi"], setup_requires=[ "cffi", "pyOpenSSL>=0.15.1", -- cgit v1.2.3 From ff20e64537ad25aa988f212b0473bdb5e696611b Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 16 Jun 2015 02:37:46 +0200 Subject: add landscape configuration --- .landscape.yml | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .landscape.yml diff --git a/.landscape.yml b/.landscape.yml new file mode 100644 index 00000000..680ee0e7 --- /dev/null +++ b/.landscape.yml @@ -0,0 +1,3 @@ +pylint: + disable: + - unpacking-non-sequence \ No newline at end of file -- cgit v1.2.3 From 836b1eab9700230991822102d411aed067308123 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 17 Jun 2015 13:10:27 +0200 Subject: fix warnings and code smells use prospector to find them --- .landscape.yml | 12 ++++++++++- netlib/http2/__init__.py | 1 - netlib/http2/frame.py | 55 ++++++++++++++++++++++++------------------------ netlib/http_cookies.py | 8 +++---- netlib/http_uastrings.py | 24 +++++++++++---------- netlib/tcp.py | 8 +++---- netlib/utils.py | 2 +- netlib/websockets.py | 16 +++++++------- setup.cfg | 8 ------- 9 files changed, 67 insertions(+), 67 deletions(-) delete mode 100644 setup.cfg diff --git a/.landscape.yml b/.landscape.yml index 680ee0e7..5926e7bf 100644 --- a/.landscape.yml +++ b/.landscape.yml @@ -1,3 +1,13 @@ +max-line-length: 120 pylint: disable: - - unpacking-non-sequence \ No newline at end of file + - missing-docstring + - protected-access + - too-few-public-methods + - too-many-arguments + - too-many-instance-attributes + - too-many-locals + - too-many-public-methods + - too-many-return-statements + - too-many-statements + - unpacking-non-sequence diff --git a/netlib/http2/__init__.py b/netlib/http2/__init__.py index 92897b5d..5acf7696 100644 --- a/netlib/http2/__init__.py +++ b/netlib/http2/__init__.py @@ -1,3 +1,2 @@ - from frame import * from protocol import * diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py index 4a305d82..43676623 100644 --- a/netlib/http2/frame.py +++ b/netlib/http2/frame.py @@ -1,6 +1,5 @@ import sys import struct -from functools import reduce from hpack.hpack import Encoder, Decoder from .. import utils @@ -52,7 +51,7 @@ class Frame(object): self.stream_id = stream_id @classmethod - def _check_frame_size(self, length, state): + def _check_frame_size(cls, length, state): if state: settings = state.http2_settings else: @@ -67,7 +66,7 @@ class Frame(object): length, max_frame_size)) @classmethod - def from_file(self, fp, state=None): + def from_file(cls, fp, state=None): """ read a HTTP/2 frame sent by a server or client fp is a "file like" object that could be backed by a network @@ -83,7 +82,7 @@ class Frame(object): if raw_header[:4] == b'HTTP': # pragma no cover print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" - self._check_frame_size(length, state) + cls._check_frame_size(length, state) payload = fp.safe_read(length) return FRAMES[fields[2]].from_bytes( @@ -146,10 +145,10 @@ class DataFrame(Frame): self.pad_length = pad_length @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - if f.flags & self.FLAG_PADDED: + if f.flags & Frame.FLAG_PADDED: f.pad_length = struct.unpack('!B', payload[0])[0] f.payload = payload[1:-f.pad_length] else: @@ -204,16 +203,16 @@ class HeadersFrame(Frame): self.weight = weight @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - if f.flags & self.FLAG_PADDED: + if f.flags & Frame.FLAG_PADDED: f.pad_length = struct.unpack('!B', payload[0])[0] f.header_block_fragment = payload[1:-f.pad_length] else: f.header_block_fragment = payload[0:] - if f.flags & self.FLAG_PRIORITY: + if f.flags & Frame.FLAG_PRIORITY: f.stream_dependency, f.weight = struct.unpack( '!LB', f.header_block_fragment[:5]) f.exclusive = bool(f.stream_dependency >> 31) @@ -279,8 +278,8 @@ class PriorityFrame(Frame): self.weight = weight @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) f.stream_dependency, f.weight = struct.unpack('!LB', payload) f.exclusive = bool(f.stream_dependency >> 31) @@ -325,8 +324,8 @@ class RstStreamFrame(Frame): self.error_code = error_code @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) f.error_code = struct.unpack('!L', payload)[0] return f @@ -369,8 +368,8 @@ class SettingsFrame(Frame): self.settings = settings @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) for i in xrange(0, len(payload), 6): identifier, value = struct.unpack("!HL", payload[i:i + 6]) @@ -420,10 +419,10 @@ class PushPromiseFrame(Frame): self.header_block_fragment = header_block_fragment @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - if f.flags & self.FLAG_PADDED: + if f.flags & Frame.FLAG_PADDED: f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) f.header_block_fragment = payload[5:-f.pad_length] else: @@ -480,8 +479,8 @@ class PingFrame(Frame): self.payload = payload @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) f.payload = payload return f @@ -517,8 +516,8 @@ class GoAwayFrame(Frame): self.data = data @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) f.last_stream &= 0x7FFFFFFF @@ -558,8 +557,8 @@ class WindowUpdateFrame(Frame): self.window_size_increment = window_size_increment @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) f.window_size_increment = struct.unpack("!L", payload)[0] f.window_size_increment &= 0x7FFFFFFF @@ -592,8 +591,8 @@ class ContinuationFrame(Frame): self.header_block_fragment = header_block_fragment @classmethod - def from_bytes(self, state, length, flags, stream_id, payload): - f = self(state=state, length=length, flags=flags, stream_id=stream_id) + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) f.header_block_fragment = payload return f diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index 5cb39e5c..b7311714 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -158,7 +158,7 @@ def _parse_set_cookie_pairs(s): return pairs -def parse_set_cookie_header(str): +def parse_set_cookie_header(line): """ Parse a Set-Cookie header value @@ -166,7 +166,7 @@ def parse_set_cookie_header(str): ODictCaseless set of attributes. No attempt is made to parse attribute values - they are treated purely as strings. """ - pairs = _parse_set_cookie_pairs(str) + pairs = _parse_set_cookie_pairs(line) if pairs: return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) @@ -180,12 +180,12 @@ def format_set_cookie_header(name, value, attrs): return _format_set_cookie_pairs(pairs) -def parse_cookie_header(str): +def parse_cookie_header(line): """ Parse a Cookie header value. Returns a (possibly empty) ODict object. """ - pairs, off = _read_pairs(str) + pairs, off = _read_pairs(line) return odict.ODict(pairs) diff --git a/netlib/http_uastrings.py b/netlib/http_uastrings.py index d9869531..c1ef557c 100644 --- a/netlib/http_uastrings.py +++ b/netlib/http_uastrings.py @@ -5,40 +5,42 @@ from __future__ import (absolute_import, print_function, division) kept reasonably current to reflect common usage. """ +# pylint: line-too-long + # A collection of (name, shortcut, string) tuples. UASTRINGS = [ ("android", "a", - "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), + "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), # noqa ("blackberry", "l", - "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), + "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), # noqa ("bingbot", "b", - "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), + "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), # noqa ("chrome", "c", - "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), + "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), # noqa ("firefox", "f", - "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), + "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), # noqa ("googlebot", "g", - "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), + "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), # noqa ("ie9", "i", - "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US))"), + "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US))"), # noqa ("ipad", "p", - "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko ) Version/5.1 Mobile/9B176 Safari/7534.48.3"), + "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko ) Version/5.1 Mobile/9B176 Safari/7534.48.3"), # noqa ("iphone", "h", - "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5", - ), + "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5"), # noqa ("safari", "s", - "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10")] + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10"), # noqa +] def get_by_shortcut(s): diff --git a/netlib/tcp.py b/netlib/tcp.py index 953cef6e..807015c8 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -297,7 +297,7 @@ def close_socket(sock): """ try: # We already indicate that we close our end. - # may raise "Transport endpoint is not connected" on Linux + # may raise "Transport endpoint is not connected" on Linux sock.shutdown(socket.SHUT_WR) # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending @@ -368,10 +368,6 @@ class _Connection(object): except SSL.Error: pass - """ - Creates an SSL Context. - """ - def _create_ssl_context(self, method=SSLv23_METHOD, options=(OP_NO_SSLv2 | OP_NO_SSLv3), @@ -383,6 +379,8 @@ class _Connection(object): alpn_select=None, ): """ + Creates an SSL Context. + :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, TLSv1_1_METHOD, or TLSv1_2_METHOD :param options: A bit field consisting of OpenSSL.SSL.OP_* values :param verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values diff --git a/netlib/utils.py b/netlib/utils.py index 9c5404e6..ac42bd53 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -67,7 +67,7 @@ def getbit(byte, offset): return True -class BiDi: +class BiDi(object): """ A wee utility class for keeping bi-directional mappings, like field diff --git a/netlib/websockets.py b/netlib/websockets.py index 346adf1b..c45db4df 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -35,7 +35,7 @@ OPCODE = utils.BiDi( ) -class Masker: +class Masker(object): """ Data sent from the server must be masked to prevent malicious clients @@ -94,15 +94,15 @@ def server_handshake_headers(key): ) -def make_length_code(len): +def make_length_code(length): """ A websockets frame contains an initial length_code, and an optional extended length code to represent the actual length if length code is larger than 125 """ - if len <= 125: - return len - elif len >= 126 and len <= 65535: + if length <= 125: + return length + elif length >= 126 and length <= 65535: return 126 else: return 127 @@ -129,7 +129,7 @@ def create_server_nonce(client_nonce): DEFAULT = object() -class FrameHeader: +class FrameHeader(object): def __init__( self, @@ -216,7 +216,7 @@ class FrameHeader: return b @classmethod - def from_file(klass, fp): + def from_file(cls, fp): """ read a websockets frame header """ @@ -248,7 +248,7 @@ class FrameHeader: else: masking_key = None - return klass( + return cls( fin=fin, rsv1=rsv1, rsv2=rsv2, diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 4207020e..00000000 --- a/setup.cfg +++ /dev/null @@ -1,8 +0,0 @@ -[flake8] -max-line-length = 80 -max-complexity = 15 - -[pep8] -max-line-length = 80 -exclude = */contrib/* -ignore = E251,E309 -- cgit v1.2.3 From a652e050b759ca27aa3b794b8e11009853edef34 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 17 Jun 2015 13:19:44 +0200 Subject: add landscape.io badge --- README.mkd | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.mkd b/README.mkd index 7039e203..2f87d111 100644 --- a/README.mkd +++ b/README.mkd @@ -1,4 +1,5 @@ -[![Build Status](https://img.shields.io/travis/mitmproxy/netlib/master.svg)](https://travis-ci.org/mitmproxy/netlib) +[![Build Status](https://img.shields.io/travis/mitmproxy/netlib/master.svg)](https://travis-ci.org/mitmproxy/netlib) +[![Code Health](https://landscape.io/github/mitmproxy/netlib/master/landscape.svg?style=flat)](https://landscape.io/github/mitmproxy/netlib/master) [![Coverage Status](https://img.shields.io/coveralls/mitmproxy/netlib/master.svg)](https://coveralls.io/r/mitmproxy/netlib) [![Downloads](https://img.shields.io/pypi/dm/netlib.svg?color=orange)](https://pypi.python.org/pypi/netlib) [![Latest Version](https://img.shields.io/pypi/v/netlib.svg)](https://pypi.python.org/pypi/netlib) -- cgit v1.2.3 From 6e301f37d0597d86008c440f62526f906f0ae9f4 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 18 Jun 2015 12:18:22 +1200 Subject: Only set OP_NO_COMPRESSION by default if it exists in our version of OpenSSL We'll need to start testing under both new and old versions of OpenSSL somehow to catch these... --- netlib/tcp.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index a1d1fe62..52ebc3c0 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -22,6 +22,17 @@ TLSv1_METHOD = SSL.TLSv1_METHOD TLSv1_1_METHOD = SSL.TLSv1_1_METHOD TLSv1_2_METHOD = SSL.TLSv1_2_METHOD + +SSL_DEFAULT_OPTIONS = ( + SSL.OP_NO_SSLv2 | + SSL.OP_NO_SSLv3 | + SSL.OP_CIPHER_SERVER_PREFERENCE +) + +if hasattr(SSL, "OP_NO_COMPRESSION"): + SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION + + class NetLibError(Exception): pass @@ -365,7 +376,7 @@ class _Connection(object): def _create_ssl_context(self, method=SSLv23_METHOD, - options=(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_CIPHER_SERVER_PREFERENCE | SSL.OP_NO_COMPRESSION), + options=SSL_DEFAULT_OPTIONS, verify_options=SSL.VERIFY_NONE, ca_path=None, ca_pemfile=None, -- cgit v1.2.3 From 61cbe36e4016d77b93386e3df9b17b36b1633d7e Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 18 Jun 2015 10:38:26 +0200 Subject: http2: rename test file --- test/http2/test_http2_protocol.py | 326 -------------------------------------- test/http2/test_protocol.py | 326 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 326 insertions(+), 326 deletions(-) delete mode 100644 test/http2/test_http2_protocol.py create mode 100644 test/http2/test_protocol.py diff --git a/test/http2/test_http2_protocol.py b/test/http2/test_http2_protocol.py deleted file mode 100644 index 9b49acd3..00000000 --- a/test/http2/test_http2_protocol.py +++ /dev/null @@ -1,326 +0,0 @@ -import OpenSSL - -from netlib import http2 -from netlib import tcp -from netlib import test -from netlib.http2.frame import * -from test import tutils - - -class EchoHandler(tcp.BaseHandler): - sni = None - - def handle(self): - while True: - v = self.rfile.safe_read(1) - self.wfile.write(v) - self.wfile.flush() - - -class TestCheckALPNMatch(test.ServerTestBase): - handler = EchoHandler - ssl = dict( - alpn_select=http2.HTTP2Protocol.ALPN_PROTO_H2, - ) - - if OpenSSL._util.lib.Cryptography_HAS_ALPN: - - def test_check_alpn(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) - protocol = http2.HTTP2Protocol(c) - assert protocol.check_alpn() - - -class TestCheckALPNMismatch(test.ServerTestBase): - handler = EchoHandler - ssl = dict( - alpn_select=None, - ) - - if OpenSSL._util.lib.Cryptography_HAS_ALPN: - - def test_check_alpn(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) - protocol = http2.HTTP2Protocol(c) - tutils.raises(NotImplementedError, protocol.check_alpn) - - -class TestPerformServerConnectionPreface(test.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - # send magic - self.wfile.write( - '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex')) - self.wfile.flush() - - # send empty settings frame - self.wfile.write('000000040000000000'.decode('hex')) - self.wfile.flush() - - # check empty settings frame - assert self.rfile.read(9) ==\ - '000000040000000000'.decode('hex') - - # check settings acknowledgement - assert self.rfile.read(9) == \ - '000000040100000000'.decode('hex') - - # send settings acknowledgement - self.wfile.write('000000040100000000'.decode('hex')) - self.wfile.flush() - - def test_perform_server_connection_preface(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - protocol = http2.HTTP2Protocol(c) - protocol.perform_server_connection_preface() - - -class TestPerformClientConnectionPreface(test.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - # check magic - assert self.rfile.read(24) ==\ - '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') - - # check empty settings frame - assert self.rfile.read(9) ==\ - '000000040000000000'.decode('hex') - - # send empty settings frame - self.wfile.write('000000040000000000'.decode('hex')) - self.wfile.flush() - - # check settings acknowledgement - assert self.rfile.read(9) == \ - '000000040100000000'.decode('hex') - - # send settings acknowledgement - self.wfile.write('000000040100000000'.decode('hex')) - self.wfile.flush() - - def test_perform_client_connection_preface(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - protocol = http2.HTTP2Protocol(c) - protocol.perform_client_connection_preface() - - -class TestClientStreamIds(): - c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c) - - def test_client_stream_ids(self): - assert self.protocol.current_stream_id is None - assert self.protocol.next_stream_id() == 1 - assert self.protocol.current_stream_id == 1 - assert self.protocol.next_stream_id() == 3 - assert self.protocol.current_stream_id == 3 - assert self.protocol.next_stream_id() == 5 - assert self.protocol.current_stream_id == 5 - - -class TestServerStreamIds(): - c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c, is_server=True) - - def test_server_stream_ids(self): - assert self.protocol.current_stream_id is None - assert self.protocol.next_stream_id() == 2 - assert self.protocol.current_stream_id == 2 - assert self.protocol.next_stream_id() == 4 - assert self.protocol.current_stream_id == 4 - assert self.protocol.next_stream_id() == 6 - assert self.protocol.current_stream_id == 6 - - -class TestApplySettings(test.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - # check settings acknowledgement - assert self.rfile.read(9) == '000000040100000000'.decode('hex') - self.wfile.write("OK") - self.wfile.flush() - - ssl = True - - def test_apply_settings(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) - - protocol._apply_settings({ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 'foo', - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 'bar', - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 'deadbeef', - }) - - assert c.rfile.safe_read(2) == "OK" - - assert protocol.http2_settings[ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH] == 'foo' - assert protocol.http2_settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS] == 'bar' - assert protocol.http2_settings[ - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE] == 'deadbeef' - - -class TestCreateHeaders(): - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_create_headers(self): - headers = [ - (b':method', b'GET'), - (b':path', b'index.html'), - (b':scheme', b'https'), - (b'foo', b'bar')] - - bytes = http2.HTTP2Protocol(self.c)._create_headers( - headers, 1, end_stream=True) - assert b''.join(bytes) ==\ - '000014010500000001824488355217caf3a69a3f87408294e7838c767f'\ - .decode('hex') - - bytes = http2.HTTP2Protocol(self.c)._create_headers( - headers, 1, end_stream=False) - assert b''.join(bytes) ==\ - '000014010400000001824488355217caf3a69a3f87408294e7838c767f'\ - .decode('hex') - - # TODO: add test for too large header_block_fragments - - -class TestCreateBody(): - c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c) - - def test_create_body_empty(self): - bytes = self.protocol._create_body(b'', 1) - assert b''.join(bytes) == ''.decode('hex') - - def test_create_body_single_frame(self): - bytes = self.protocol._create_body('foobar', 1) - assert b''.join(bytes) == '000006000100000001666f6f626172'.decode('hex') - - def test_create_body_multiple_frames(self): - pass - # bytes = self.protocol._create_body('foobar' * 3000, 1) - # TODO: add test for too large frames - - -class TestCreateRequest(): - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_create_request_simple(self): - bytes = http2.HTTP2Protocol(self.c).create_request('GET', '/') - assert len(bytes) == 1 - assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') - - def test_create_request_with_body(self): - bytes = http2.HTTP2Protocol(self.c).create_request( - 'GET', '/', [(b'foo', b'bar')], 'foobar') - assert len(bytes) == 2 - assert bytes[0] ==\ - '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') - assert bytes[1] ==\ - '000006000100000001666f6f626172'.decode('hex') - - -class TestReadResponse(test.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - self.wfile.write( - b'00000801040000000188628594e78c767f'.decode('hex')) - self.wfile.write( - b'000006000100000001666f6f626172'.decode('hex')) - self.wfile.flush() - - ssl = True - - def test_read_response(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) - - status, headers, body = protocol.read_response() - - assert headers == {':status': '200', 'etag': 'foobar'} - assert status == "200" - assert body == b'foobar' - - -class TestReadEmptyResponse(test.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - self.wfile.write( - b'00000801050000000188628594e78c767f'.decode('hex')) - self.wfile.flush() - - ssl = True - - def test_read_empty_response(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) - - status, headers, body = protocol.read_response() - - assert headers == {':status': '200', 'etag': 'foobar'} - assert status == "200" - assert body == b'' - - -class TestReadRequest(test.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - self.wfile.write( - b'000003010400000001828487'.decode('hex')) - self.wfile.write( - b'000006000100000001666f6f626172'.decode('hex')) - self.wfile.flush() - - ssl = True - - def test_read_request(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c, is_server=True) - - stream_id, headers, body = protocol.read_request() - - assert stream_id - assert headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} - assert body == b'foobar' - - -class TestCreateResponse(): - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_create_response_simple(self): - bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200) - assert len(bytes) == 1 - assert bytes[0] ==\ - '00000101050000000288'.decode('hex') - - def test_create_response_with_body(self): - bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response( - 200, 1, [(b'foo', b'bar')], 'foobar') - assert len(bytes) == 2 - assert bytes[0] ==\ - '00000901040000000188408294e7838c767f'.decode('hex') - assert bytes[1] ==\ - '000006000100000001666f6f626172'.decode('hex') diff --git a/test/http2/test_protocol.py b/test/http2/test_protocol.py new file mode 100644 index 00000000..9b49acd3 --- /dev/null +++ b/test/http2/test_protocol.py @@ -0,0 +1,326 @@ +import OpenSSL + +from netlib import http2 +from netlib import tcp +from netlib import test +from netlib.http2.frame import * +from test import tutils + + +class EchoHandler(tcp.BaseHandler): + sni = None + + def handle(self): + while True: + v = self.rfile.safe_read(1) + self.wfile.write(v) + self.wfile.flush() + + +class TestCheckALPNMatch(test.ServerTestBase): + handler = EchoHandler + ssl = dict( + alpn_select=http2.HTTP2Protocol.ALPN_PROTO_H2, + ) + + if OpenSSL._util.lib.Cryptography_HAS_ALPN: + + def test_check_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) + protocol = http2.HTTP2Protocol(c) + assert protocol.check_alpn() + + +class TestCheckALPNMismatch(test.ServerTestBase): + handler = EchoHandler + ssl = dict( + alpn_select=None, + ) + + if OpenSSL._util.lib.Cryptography_HAS_ALPN: + + def test_check_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) + protocol = http2.HTTP2Protocol(c) + tutils.raises(NotImplementedError, protocol.check_alpn) + + +class TestPerformServerConnectionPreface(test.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + # send magic + self.wfile.write( + '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex')) + self.wfile.flush() + + # send empty settings frame + self.wfile.write('000000040000000000'.decode('hex')) + self.wfile.flush() + + # check empty settings frame + assert self.rfile.read(9) ==\ + '000000040000000000'.decode('hex') + + # check settings acknowledgement + assert self.rfile.read(9) == \ + '000000040100000000'.decode('hex') + + # send settings acknowledgement + self.wfile.write('000000040100000000'.decode('hex')) + self.wfile.flush() + + def test_perform_server_connection_preface(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + protocol = http2.HTTP2Protocol(c) + protocol.perform_server_connection_preface() + + +class TestPerformClientConnectionPreface(test.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + # check magic + assert self.rfile.read(24) ==\ + '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') + + # check empty settings frame + assert self.rfile.read(9) ==\ + '000000040000000000'.decode('hex') + + # send empty settings frame + self.wfile.write('000000040000000000'.decode('hex')) + self.wfile.flush() + + # check settings acknowledgement + assert self.rfile.read(9) == \ + '000000040100000000'.decode('hex') + + # send settings acknowledgement + self.wfile.write('000000040100000000'.decode('hex')) + self.wfile.flush() + + def test_perform_client_connection_preface(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + protocol = http2.HTTP2Protocol(c) + protocol.perform_client_connection_preface() + + +class TestClientStreamIds(): + c = tcp.TCPClient(("127.0.0.1", 0)) + protocol = http2.HTTP2Protocol(c) + + def test_client_stream_ids(self): + assert self.protocol.current_stream_id is None + assert self.protocol.next_stream_id() == 1 + assert self.protocol.current_stream_id == 1 + assert self.protocol.next_stream_id() == 3 + assert self.protocol.current_stream_id == 3 + assert self.protocol.next_stream_id() == 5 + assert self.protocol.current_stream_id == 5 + + +class TestServerStreamIds(): + c = tcp.TCPClient(("127.0.0.1", 0)) + protocol = http2.HTTP2Protocol(c, is_server=True) + + def test_server_stream_ids(self): + assert self.protocol.current_stream_id is None + assert self.protocol.next_stream_id() == 2 + assert self.protocol.current_stream_id == 2 + assert self.protocol.next_stream_id() == 4 + assert self.protocol.current_stream_id == 4 + assert self.protocol.next_stream_id() == 6 + assert self.protocol.current_stream_id == 6 + + +class TestApplySettings(test.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + # check settings acknowledgement + assert self.rfile.read(9) == '000000040100000000'.decode('hex') + self.wfile.write("OK") + self.wfile.flush() + + ssl = True + + def test_apply_settings(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c) + + protocol._apply_settings({ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 'foo', + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 'bar', + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 'deadbeef', + }) + + assert c.rfile.safe_read(2) == "OK" + + assert protocol.http2_settings[ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH] == 'foo' + assert protocol.http2_settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS] == 'bar' + assert protocol.http2_settings[ + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE] == 'deadbeef' + + +class TestCreateHeaders(): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_create_headers(self): + headers = [ + (b':method', b'GET'), + (b':path', b'index.html'), + (b':scheme', b'https'), + (b'foo', b'bar')] + + bytes = http2.HTTP2Protocol(self.c)._create_headers( + headers, 1, end_stream=True) + assert b''.join(bytes) ==\ + '000014010500000001824488355217caf3a69a3f87408294e7838c767f'\ + .decode('hex') + + bytes = http2.HTTP2Protocol(self.c)._create_headers( + headers, 1, end_stream=False) + assert b''.join(bytes) ==\ + '000014010400000001824488355217caf3a69a3f87408294e7838c767f'\ + .decode('hex') + + # TODO: add test for too large header_block_fragments + + +class TestCreateBody(): + c = tcp.TCPClient(("127.0.0.1", 0)) + protocol = http2.HTTP2Protocol(c) + + def test_create_body_empty(self): + bytes = self.protocol._create_body(b'', 1) + assert b''.join(bytes) == ''.decode('hex') + + def test_create_body_single_frame(self): + bytes = self.protocol._create_body('foobar', 1) + assert b''.join(bytes) == '000006000100000001666f6f626172'.decode('hex') + + def test_create_body_multiple_frames(self): + pass + # bytes = self.protocol._create_body('foobar' * 3000, 1) + # TODO: add test for too large frames + + +class TestCreateRequest(): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_create_request_simple(self): + bytes = http2.HTTP2Protocol(self.c).create_request('GET', '/') + assert len(bytes) == 1 + assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') + + def test_create_request_with_body(self): + bytes = http2.HTTP2Protocol(self.c).create_request( + 'GET', '/', [(b'foo', b'bar')], 'foobar') + assert len(bytes) == 2 + assert bytes[0] ==\ + '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') + assert bytes[1] ==\ + '000006000100000001666f6f626172'.decode('hex') + + +class TestReadResponse(test.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + self.wfile.write( + b'00000801040000000188628594e78c767f'.decode('hex')) + self.wfile.write( + b'000006000100000001666f6f626172'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_read_response(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c) + + status, headers, body = protocol.read_response() + + assert headers == {':status': '200', 'etag': 'foobar'} + assert status == "200" + assert body == b'foobar' + + +class TestReadEmptyResponse(test.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + self.wfile.write( + b'00000801050000000188628594e78c767f'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_read_empty_response(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c) + + status, headers, body = protocol.read_response() + + assert headers == {':status': '200', 'etag': 'foobar'} + assert status == "200" + assert body == b'' + + +class TestReadRequest(test.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + self.wfile.write( + b'000003010400000001828487'.decode('hex')) + self.wfile.write( + b'000006000100000001666f6f626172'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_read_request(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c, is_server=True) + + stream_id, headers, body = protocol.read_request() + + assert stream_id + assert headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} + assert body == b'foobar' + + +class TestCreateResponse(): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_create_response_simple(self): + bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200) + assert len(bytes) == 1 + assert bytes[0] ==\ + '00000101050000000288'.decode('hex') + + def test_create_response_with_body(self): + bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response( + 200, 1, [(b'foo', b'bar')], 'foobar') + assert len(bytes) == 2 + assert bytes[0] ==\ + '00000901040000000188408294e7838c767f'.decode('hex') + assert bytes[1] ==\ + '000006000100000001666f6f626172'.decode('hex') -- cgit v1.2.3 From 6a4dcaf3561cf279937114a8a80ebad8adcc1eec Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 18 Jun 2015 11:33:43 +0200 Subject: remove implementation badge line too short :-/ --- README.mkd | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.mkd b/README.mkd index 2f87d111..f5e66d99 100644 --- a/README.mkd +++ b/README.mkd @@ -4,7 +4,6 @@ [![Downloads](https://img.shields.io/pypi/dm/netlib.svg?color=orange)](https://pypi.python.org/pypi/netlib) [![Latest Version](https://img.shields.io/pypi/v/netlib.svg)](https://pypi.python.org/pypi/netlib) [![Supported Python versions](https://img.shields.io/pypi/pyversions/netlib.svg)](https://pypi.python.org/pypi/netlib) -[![Supported Python implementations](https://img.shields.io/pypi/implementation/netlib.svg)](https://pypi.python.org/pypi/netlib) Netlib is a collection of network utility classes, used by the pathod and mitmproxy projects. It differs from other projects in some fundamental @@ -16,7 +15,7 @@ functions, and are designed to allow misbehaviour when needed. Requirements ------------ -* [Python](http://www.python.org) 2.7.x. +* [Python](http://www.python.org) 2.7.x or a compatible version of pypy. * Third-party packages listed in [setup.py](https://github.com/mitmproxy/netlib/blob/master/setup.py) Hacking -- cgit v1.2.3 From 014b76bff7b9af0a9ff3704be49aa84232c7fa3e Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 18 Jun 2015 11:36:58 +0200 Subject: include wheel as dev dependency --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 0051ea77..1f215baa 100644 --- a/setup.py +++ b/setup.py @@ -79,6 +79,7 @@ setup( "coveralls>=0.4.1", "autopep8>=1.0.3", "autoflake>=0.6.6", + "wheel>=0.24.0" "pathod>=%s, <%s" % (version.MINORVERSION, version.NEXT_MINORVERSION)]}, -- cgit v1.2.3 From 40436ffb1f8293dde9217e2a0167e6c66b11d1f1 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 18 Jun 2015 13:12:06 +0200 Subject: fix setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d08ea17a..3a1d7811 100644 --- a/setup.py +++ b/setup.py @@ -81,7 +81,7 @@ setup( "coveralls>=0.4.1", "autopep8>=1.0.3", "autoflake>=0.6.6", - "wheel>=0.24.0" + "wheel>=0.24.0", "pathod>=%s, <%s" % (version.MINORVERSION, version.NEXT_MINORVERSION)]}, -- cgit v1.2.3 From 69e71097f7a9633a43d566b2a46aab370f07dce3 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 18 Jun 2015 15:32:52 +0200 Subject: mark unused variables and arguments --- .landscape.yml | 3 +++ netlib/certutils.py | 2 +- netlib/http2/frame.py | 3 ++- netlib/http2/protocol.py | 15 +++++++-------- netlib/http_auth.py | 9 +++++---- netlib/http_cookies.py | 9 +++------ netlib/tcp.py | 10 +++++----- netlib/wsgi.py | 2 +- 8 files changed, 27 insertions(+), 26 deletions(-) diff --git a/.landscape.yml b/.landscape.yml index 5926e7bf..ccaa5fc3 100644 --- a/.landscape.yml +++ b/.landscape.yml @@ -1,5 +1,8 @@ max-line-length: 120 pylint: + options: + dummy-variables-rgx: _$|.+_$|dummy_.+ + disable: - missing-docstring - protected-access diff --git a/netlib/certutils.py b/netlib/certutils.py index ade61bb5..c6f0e628 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -333,7 +333,7 @@ class CertStore(object): return entry.cert, entry.privatekey, entry.chain_file - def gen_pkey(self, cert): + def gen_pkey(self, cert_): # FIXME: We should do something with cert here? from . import certffi certffi.set_flags(self.default_privatekey, 1) diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py index b4783a02..f7e60471 100644 --- a/netlib/http2/frame.py +++ b/netlib/http2/frame.py @@ -116,7 +116,8 @@ class Frame(object): self.length = len(self.payload_bytes()) return "\n".join([ - "%s: %s | length: %d | flags: %#x | stream_id: %d" % (direction, self.__class__.__name__, self.length, self.flags, self.stream_id), + "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( + direction, self.__class__.__name__, self.length, self.flags, self.stream_id), self.payload_human_readable(), "===============================================================", ]) diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py index ac89bac4..8e5f5429 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http2/protocol.py @@ -59,8 +59,8 @@ class HTTP2Protocol(object): while True: frm = self.read_frame(hide) if isinstance(frm, frame.SettingsFrame): - assert settings_ack_frame.flags & frame.Frame.FLAG_ACK - assert len(settings_ack_frame.settings) == 0 + assert frm.flags & frame.Frame.FLAG_ACK + assert len(frm.settings) == 0 break def perform_server_connection_preface(self, force=False): @@ -118,11 +118,10 @@ class HTTP2Protocol(object): old_value = '-' self.http2_settings[setting] = value - self.send_frame( - frame.SettingsFrame( - state=self, - flags=frame.Frame.FLAG_ACK), - hide) + frm = frame.SettingsFrame( + state=self, + flags=frame.Frame.FLAG_ACK) + self.send_frame(frm, hide) # be liberal in what we expect from the other end # to be more strict use: self._read_settings_ack(hide) @@ -188,7 +187,7 @@ class HTTP2Protocol(object): self._create_body(body, stream_id))) def read_response(self): - stream_id, headers, body = self._receive_transmission() + stream_id_, headers, body = self._receive_transmission() return headers[':status'], headers, body def read_request(self): diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 0143760c..adab4aed 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -12,12 +12,13 @@ class NullProxyAuth(object): def __init__(self, password_manager): self.password_manager = password_manager - def clean(self, headers): + def clean(self, headers_): """ Clean up authentication headers, so they're not passed upstream. """ + pass - def authenticate(self, headers): + def authenticate(self, headers_): """ Tests that the user is allowed to use the proxy """ @@ -62,7 +63,7 @@ class BasicProxyAuth(NullProxyAuth): class PassMan(object): - def test(self, username, password_token): + def test(self, username_, password_token_): return False @@ -72,7 +73,7 @@ class PassManNonAnon(PassMan): Ensure the user specifies a username, accept any password. """ - def test(self, username, password_token): + def test(self, username, password_token_): if username: return True return False diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index b7311714..e91ee5c0 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -87,7 +87,7 @@ def _read_value(s, start, delims): return _read_until(s, start, delims) -def _read_pairs(s, off=0, specials=()): +def _read_pairs(s, off=0): """ Read pairs of lhs=rhs values. @@ -151,10 +151,7 @@ def _parse_set_cookie_pairs(s): For Set-Cookie, we support multiple cookies as described in RFC2109. This function therefore returns a list of lists. """ - pairs, off = _read_pairs( - s, - specials=("expires", "path") - ) + pairs, off_ = _read_pairs(s) return pairs @@ -185,7 +182,7 @@ def parse_cookie_header(line): Parse a Cookie header value. Returns a (possibly empty) ODict object. """ - pairs, off = _read_pairs(line) + pairs, off_ = _read_pairs(line) return odict.ODict(pairs) diff --git a/netlib/tcp.py b/netlib/tcp.py index 65075776..77eb7b52 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -403,7 +403,7 @@ class _Connection(object): # Verify Options (NONE/PEER/PEER|FAIL_IF_... and trusted CAs) if verify_options is not None and verify_options is not SSL.VERIFY_NONE: - def verify_cert(conn, cert, errno, err_depth, is_cert_verified): + def verify_cert(conn_, cert_, errno, err_depth, is_cert_verified): if is_cert_verified: return True raise NetLibError( @@ -439,7 +439,7 @@ class _Connection(object): context.set_alpn_protos(alpn_protos) elif alpn_select is not None: # select application layer protocol - def alpn_select_callback(conn, options): + def alpn_select_callback(conn_, options): if alpn_select in options: return bytes(alpn_select) else: # pragma no cover @@ -601,7 +601,7 @@ class BaseHandler(_Connection): context.set_tlsext_servername_callback(handle_sni) if request_client_cert: - def save_cert(conn, cert, errno, depth, preverify_ok): + def save_cert(conn_, cert, errno_, depth_, preverify_ok_): self.clientcert = certutils.SSLCert(cert) # Return true to prevent cert verification error return True @@ -676,7 +676,7 @@ class TCPServer(object): try: while not self.__shutdown_request: try: - r, w, e = select.select( + r, w_, e_ = select.select( [self.socket], [], [], poll_interval) except select.error as ex: # pragma: no cover if ex[0] == EINTR: @@ -708,7 +708,7 @@ class TCPServer(object): self.socket.close() self.handle_shutdown() - def handle_error(self, connection, client_address, fp=sys.stderr): + def handle_error(self, connection_, client_address, fp=sys.stderr): """ Called when handle_client_connection raises an exception. """ diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 827cf6f0..ad43dc19 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -35,7 +35,7 @@ def date_time_string(): 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec' ] now = time.time() - year, month, day, hh, mm, ss, wd, y, z = time.gmtime(now) + year, month, day, hh, mm, ss, wd, y_, z_ = time.gmtime(now) s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( WEEKS[wd], day, MONTHS[month], year, -- cgit v1.2.3 From f5c5deb2aea047394238f3b993ddf24c60845768 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 18 Jun 2015 17:36:58 +0200 Subject: fix http user agents --- netlib/http_uastrings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netlib/http_uastrings.py b/netlib/http_uastrings.py index c1ef557c..e8681908 100644 --- a/netlib/http_uastrings.py +++ b/netlib/http_uastrings.py @@ -30,10 +30,10 @@ UASTRINGS = [ "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), # noqa ("ie9", "i", - "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US))"), # noqa + "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US)"), # noqa ("ipad", "p", - "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko ) Version/5.1 Mobile/9B176 Safari/7534.48.3"), # noqa + "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9B176 Safari/7534.48.3"), # noqa ("iphone", "h", "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5"), # noqa -- cgit v1.2.3 From 2aa1b98fbf8d03005e022da86e3e534cf25ebf62 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 22 Jun 2015 14:52:23 +1200 Subject: netlib/test.py -> test/tservers.py --- netlib/test.py | 108 -------------------------------------------- test/http2/test_protocol.py | 18 ++++---- test/test_http.py | 6 +-- test/test_tcp.py | 56 +++++++++++------------ test/test_websockets.py | 8 ++-- test/tservers.py | 108 ++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 152 insertions(+), 152 deletions(-) delete mode 100644 netlib/test.py create mode 100644 test/tservers.py diff --git a/netlib/test.py b/netlib/test.py deleted file mode 100644 index 1e1b5e9d..00000000 --- a/netlib/test.py +++ /dev/null @@ -1,108 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import threading -import Queue -import cStringIO -import OpenSSL -from . import tcp, certutils -from test import tutils - - -class ServerThread(threading.Thread): - - def __init__(self, server): - self.server = server - threading.Thread.__init__(self) - - def run(self): - self.server.serve_forever() - - def shutdown(self): - self.server.shutdown() - - -class ServerTestBase(object): - ssl = None - handler = None - addr = ("localhost", 0) - - @classmethod - def setupAll(cls): - cls.q = Queue.Queue() - s = cls.makeserver() - cls.port = s.address.port - cls.server = ServerThread(s) - cls.server.start() - - @classmethod - def makeserver(cls): - return TServer(cls.ssl, cls.q, cls.handler, cls.addr) - - @classmethod - def teardownAll(cls): - cls.server.shutdown() - - @property - def last_handler(self): - return self.server.server.last_handler - - -class TServer(tcp.TCPServer): - - def __init__(self, ssl, q, handler_klass, addr): - """ - ssl: A dictionary of SSL parameters: - - cert, key, request_client_cert, cipher_list, - dhparams, v3_only - """ - tcp.TCPServer.__init__(self, addr) - - if ssl is True: - self.ssl = dict() - elif isinstance(ssl, dict): - self.ssl = ssl - else: - self.ssl = None - - self.q = q - self.handler_klass = handler_klass - self.last_handler = None - - def handle_client_connection(self, request, client_address): - h = self.handler_klass(request, client_address, self) - self.last_handler = h - if self.ssl is not None: - raw_cert = self.ssl.get( - "cert", - tutils.test_data.path("data/server.crt")) - cert = certutils.SSLCert.from_pem(open(raw_cert, "rb").read()) - raw_key = self.ssl.get( - "key", - tutils.test_data.path("data/server.key")) - key = OpenSSL.crypto.load_privatekey( - OpenSSL.crypto.FILETYPE_PEM, - open(raw_key, "rb").read()) - if self.ssl.get("v3_only", False): - method = tcp.SSLv3_METHOD - options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 - else: - method = tcp.SSLv23_METHOD - options = None - h.convert_to_ssl( - cert, key, - method=method, - options=options, - handle_sni=getattr(h, "handle_sni", None), - request_client_cert=self.ssl.get("request_client_cert", None), - cipher_list=self.ssl.get("cipher_list", None), - dhparams=self.ssl.get("dhparams", None), - chain_file=self.ssl.get("chain_file", None), - alpn_select=self.ssl.get("alpn_select", None) - ) - h.handle() - h.finish() - - def handle_error(self, connection, client_address, fp=None): - s = cStringIO.StringIO() - tcp.TCPServer.handle_error(self, connection, client_address, s) - self.q.put(s.getvalue()) diff --git a/test/http2/test_protocol.py b/test/http2/test_protocol.py index 9b49acd3..5e2af34e 100644 --- a/test/http2/test_protocol.py +++ b/test/http2/test_protocol.py @@ -2,9 +2,9 @@ import OpenSSL from netlib import http2 from netlib import tcp -from netlib import test from netlib.http2.frame import * from test import tutils +from .. import tservers class EchoHandler(tcp.BaseHandler): @@ -17,7 +17,7 @@ class EchoHandler(tcp.BaseHandler): self.wfile.flush() -class TestCheckALPNMatch(test.ServerTestBase): +class TestCheckALPNMatch(tservers.ServerTestBase): handler = EchoHandler ssl = dict( alpn_select=http2.HTTP2Protocol.ALPN_PROTO_H2, @@ -33,7 +33,7 @@ class TestCheckALPNMatch(test.ServerTestBase): assert protocol.check_alpn() -class TestCheckALPNMismatch(test.ServerTestBase): +class TestCheckALPNMismatch(tservers.ServerTestBase): handler = EchoHandler ssl = dict( alpn_select=None, @@ -49,7 +49,7 @@ class TestCheckALPNMismatch(test.ServerTestBase): tutils.raises(NotImplementedError, protocol.check_alpn) -class TestPerformServerConnectionPreface(test.ServerTestBase): +class TestPerformServerConnectionPreface(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): @@ -81,7 +81,7 @@ class TestPerformServerConnectionPreface(test.ServerTestBase): protocol.perform_server_connection_preface() -class TestPerformClientConnectionPreface(test.ServerTestBase): +class TestPerformClientConnectionPreface(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): @@ -140,7 +140,7 @@ class TestServerStreamIds(): assert self.protocol.current_stream_id == 6 -class TestApplySettings(test.ServerTestBase): +class TestApplySettings(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): @@ -234,7 +234,7 @@ class TestCreateRequest(): '000006000100000001666f6f626172'.decode('hex') -class TestReadResponse(test.ServerTestBase): +class TestReadResponse(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): @@ -259,7 +259,7 @@ class TestReadResponse(test.ServerTestBase): assert body == b'foobar' -class TestReadEmptyResponse(test.ServerTestBase): +class TestReadEmptyResponse(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): @@ -282,7 +282,7 @@ class TestReadEmptyResponse(test.ServerTestBase): assert body == b'' -class TestReadRequest(test.ServerTestBase): +class TestReadRequest(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): diff --git a/test/test_http.py b/test/test_http.py index 0a9e276f..2ad81d24 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -1,8 +1,8 @@ import cStringIO import textwrap import binascii -from netlib import http, odict, tcp, test -import tutils +from netlib import http, odict, tcp +from . import tutils, tservers def test_httperror(): @@ -284,7 +284,7 @@ class NoContentLengthHTTPHandler(tcp.BaseHandler): self.wfile.flush() -class TestReadResponseNoContentLength(test.ServerTestBase): +class TestReadResponseNoContentLength(tservers.ServerTestBase): handler = NoContentLengthHTTPHandler def test_no_content_length(self): diff --git a/test/test_tcp.py b/test/test_tcp.py index 122c1f0f..4253e073 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -10,8 +10,8 @@ import mock from OpenSSL import SSL import OpenSSL -from netlib import tcp, certutils, test, certffi -import tutils +from netlib import tcp, certutils, certffi +from . import tutils, tservers class EchoHandler(tcp.BaseHandler): @@ -53,7 +53,7 @@ class ALPNHandler(tcp.BaseHandler): self.wfile.flush() -class TestServer(test.ServerTestBase): +class TestServer(tservers.ServerTestBase): handler = EchoHandler def test_echo(self): @@ -74,7 +74,7 @@ class TestServer(test.ServerTestBase): self.test_echo() -class TestServerBind(test.ServerTestBase): +class TestServerBind(tservers.ServerTestBase): class handler(tcp.BaseHandler): @@ -97,7 +97,7 @@ class TestServerBind(test.ServerTestBase): pass -class TestServerIPv6(test.ServerTestBase): +class TestServerIPv6(tservers.ServerTestBase): handler = EchoHandler addr = tcp.Address(("localhost", 0), use_ipv6=True) @@ -110,7 +110,7 @@ class TestServerIPv6(test.ServerTestBase): assert c.rfile.readline() == testval -class TestEcho(test.ServerTestBase): +class TestEcho(tservers.ServerTestBase): handler = EchoHandler def test_echo(self): @@ -128,7 +128,7 @@ class HardDisconnectHandler(tcp.BaseHandler): self.connection.close() -class TestFinishFail(test.ServerTestBase): +class TestFinishFail(tservers.ServerTestBase): """ This tests a difficult-to-trigger exception in the .finish() method of @@ -144,7 +144,7 @@ class TestFinishFail(test.ServerTestBase): c.finish() -class TestServerSSL(test.ServerTestBase): +class TestServerSSL(tservers.ServerTestBase): handler = EchoHandler ssl = dict( cipher_list="AES256-SHA", @@ -170,7 +170,7 @@ class TestServerSSL(test.ServerTestBase): assert "AES" in ret[0] -class TestSSLv3Only(test.ServerTestBase): +class TestSSLv3Only(tservers.ServerTestBase): handler = EchoHandler ssl = dict( request_client_cert=False, @@ -183,7 +183,7 @@ class TestSSLv3Only(test.ServerTestBase): tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com") -class TestSSLUpstreamCertVerification(test.ServerTestBase): +class TestSSLUpstreamCertVerification(tservers.ServerTestBase): handler = EchoHandler ssl = dict( @@ -236,7 +236,7 @@ class TestSSLUpstreamCertVerification(test.ServerTestBase): assert c.rfile.readline() == testval -class TestSSLClientCert(test.ServerTestBase): +class TestSSLClientCert(tservers.ServerTestBase): class handler(tcp.BaseHandler): sni = None @@ -270,7 +270,7 @@ class TestSSLClientCert(test.ServerTestBase): ) -class TestSNI(test.ServerTestBase): +class TestSNI(tservers.ServerTestBase): class handler(tcp.BaseHandler): sni = None @@ -292,7 +292,7 @@ class TestSNI(test.ServerTestBase): assert c.rfile.readline() == "foo.com" -class TestServerCipherList(test.ServerTestBase): +class TestServerCipherList(tservers.ServerTestBase): handler = ClientCipherListHandler ssl = dict( cipher_list='RC4-SHA' @@ -305,7 +305,7 @@ class TestServerCipherList(test.ServerTestBase): assert c.rfile.readline() == "['RC4-SHA']" -class TestServerCurrentCipher(test.ServerTestBase): +class TestServerCurrentCipher(tservers.ServerTestBase): class handler(tcp.BaseHandler): sni = None @@ -325,7 +325,7 @@ class TestServerCurrentCipher(test.ServerTestBase): assert "RC4-SHA" in c.rfile.readline() -class TestServerCipherListError(test.ServerTestBase): +class TestServerCipherListError(tservers.ServerTestBase): handler = ClientCipherListHandler ssl = dict( cipher_list='bogus' @@ -337,7 +337,7 @@ class TestServerCipherListError(test.ServerTestBase): tutils.raises("handshake error", c.convert_to_ssl, sni="foo.com") -class TestClientCipherListError(test.ServerTestBase): +class TestClientCipherListError(tservers.ServerTestBase): handler = ClientCipherListHandler ssl = dict( cipher_list='RC4-SHA' @@ -353,7 +353,7 @@ class TestClientCipherListError(test.ServerTestBase): cipher_list="bogus") -class TestSSLDisconnect(test.ServerTestBase): +class TestSSLDisconnect(tservers.ServerTestBase): class handler(tcp.BaseHandler): @@ -373,7 +373,7 @@ class TestSSLDisconnect(test.ServerTestBase): tutils.raises(Queue.Empty, self.q.get_nowait) -class TestSSLHardDisconnect(test.ServerTestBase): +class TestSSLHardDisconnect(tservers.ServerTestBase): handler = HardDisconnectHandler ssl = True @@ -387,7 +387,7 @@ class TestSSLHardDisconnect(test.ServerTestBase): tutils.raises(tcp.NetLibDisconnect, c.wfile.write, "foo") -class TestDisconnect(test.ServerTestBase): +class TestDisconnect(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) @@ -398,7 +398,7 @@ class TestDisconnect(test.ServerTestBase): c.close() -class TestServerTimeOut(test.ServerTestBase): +class TestServerTimeOut(tservers.ServerTestBase): class handler(tcp.BaseHandler): @@ -417,7 +417,7 @@ class TestServerTimeOut(test.ServerTestBase): assert self.last_handler.timeout -class TestTimeOut(test.ServerTestBase): +class TestTimeOut(tservers.ServerTestBase): handler = HangHandler def test_timeout(self): @@ -428,7 +428,7 @@ class TestTimeOut(test.ServerTestBase): tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) -class TestALPNClient(test.ServerTestBase): +class TestALPNClient(tservers.ServerTestBase): handler = ALPNHandler ssl = dict( alpn_select="bar" @@ -457,7 +457,7 @@ class TestALPNClient(test.ServerTestBase): assert c.get_alpn_proto_negotiated() == "" assert c.rfile.readline() == "NONE" -class TestNoSSLNoALPNClient(test.ServerTestBase): +class TestNoSSLNoALPNClient(tservers.ServerTestBase): handler = ALPNHandler def test_no_ssl_no_alpn(self): @@ -467,7 +467,7 @@ class TestNoSSLNoALPNClient(test.ServerTestBase): assert c.rfile.readline().strip() == "NONE" -class TestSSLTimeOut(test.ServerTestBase): +class TestSSLTimeOut(tservers.ServerTestBase): handler = HangHandler ssl = True @@ -479,7 +479,7 @@ class TestSSLTimeOut(test.ServerTestBase): tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) -class TestDHParams(test.ServerTestBase): +class TestDHParams(tservers.ServerTestBase): handler = HangHandler ssl = dict( dhparams=certutils.CertStore.load_dhparam( @@ -502,7 +502,7 @@ class TestDHParams(test.ServerTestBase): assert os.path.exists(filename) -class TestPrivkeyGen(test.ServerTestBase): +class TestPrivkeyGen(tservers.ServerTestBase): class handler(tcp.BaseHandler): @@ -520,7 +520,7 @@ class TestPrivkeyGen(test.ServerTestBase): tutils.raises("bad record mac", c.convert_to_ssl) -class TestPrivkeyGenNoFlags(test.ServerTestBase): +class TestPrivkeyGenNoFlags(tservers.ServerTestBase): class handler(tcp.BaseHandler): @@ -684,7 +684,7 @@ class TestAddress: assert repr(a) -class TestSSLKeyLogger(test.ServerTestBase): +class TestSSLKeyLogger(tservers.ServerTestBase): handler = EchoHandler ssl = dict( cipher_list="AES256-SHA" diff --git a/test/test_websockets.py b/test/test_websockets.py index 8ed14708..9956543b 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -2,8 +2,8 @@ import os from nose.tools import raises -from netlib import tcp, test, websockets, http -import tutils +from netlib import tcp, websockets, http +from . import tutils, tservers class WebSocketsEchoHandler(tcp.BaseHandler): @@ -75,7 +75,7 @@ class WebSocketsClient(tcp.TCPClient): frame.to_file(self.wfile) -class TestWebSockets(test.ServerTestBase): +class TestWebSockets(tservers.ServerTestBase): handler = WebSocketsEchoHandler def random_bytes(self, n=100): @@ -155,7 +155,7 @@ class BadHandshakeHandler(WebSocketsEchoHandler): self.handshake_done = True -class TestBadHandshake(test.ServerTestBase): +class TestBadHandshake(tservers.ServerTestBase): """ Ensure that the client disconnects if the server handshake is malformed diff --git a/test/tservers.py b/test/tservers.py new file mode 100644 index 00000000..899b51bd --- /dev/null +++ b/test/tservers.py @@ -0,0 +1,108 @@ +from __future__ import (absolute_import, print_function, division) +import threading +import Queue +import cStringIO +import OpenSSL +from netlib import tcp, certutils +from . import tutils + + +class ServerThread(threading.Thread): + + def __init__(self, server): + self.server = server + threading.Thread.__init__(self) + + def run(self): + self.server.serve_forever() + + def shutdown(self): + self.server.shutdown() + + +class ServerTestBase(object): + ssl = None + handler = None + addr = ("localhost", 0) + + @classmethod + def setupAll(cls): + cls.q = Queue.Queue() + s = cls.makeserver() + cls.port = s.address.port + cls.server = ServerThread(s) + cls.server.start() + + @classmethod + def makeserver(cls): + return TServer(cls.ssl, cls.q, cls.handler, cls.addr) + + @classmethod + def teardownAll(cls): + cls.server.shutdown() + + @property + def last_handler(self): + return self.server.server.last_handler + + +class TServer(tcp.TCPServer): + + def __init__(self, ssl, q, handler_klass, addr): + """ + ssl: A dictionary of SSL parameters: + + cert, key, request_client_cert, cipher_list, + dhparams, v3_only + """ + tcp.TCPServer.__init__(self, addr) + + if ssl is True: + self.ssl = dict() + elif isinstance(ssl, dict): + self.ssl = ssl + else: + self.ssl = None + + self.q = q + self.handler_klass = handler_klass + self.last_handler = None + + def handle_client_connection(self, request, client_address): + h = self.handler_klass(request, client_address, self) + self.last_handler = h + if self.ssl is not None: + raw_cert = self.ssl.get( + "cert", + tutils.test_data.path("data/server.crt")) + cert = certutils.SSLCert.from_pem(open(raw_cert, "rb").read()) + raw_key = self.ssl.get( + "key", + tutils.test_data.path("data/server.key")) + key = OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, + open(raw_key, "rb").read()) + if self.ssl.get("v3_only", False): + method = tcp.SSLv3_METHOD + options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 + else: + method = tcp.SSLv23_METHOD + options = None + h.convert_to_ssl( + cert, key, + method=method, + options=options, + handle_sni=getattr(h, "handle_sni", None), + request_client_cert=self.ssl.get("request_client_cert", None), + cipher_list=self.ssl.get("cipher_list", None), + dhparams=self.ssl.get("dhparams", None), + chain_file=self.ssl.get("chain_file", None), + alpn_select=self.ssl.get("alpn_select", None) + ) + h.handle() + h.finish() + + def handle_error(self, connection, client_address, fp=None): + s = cStringIO.StringIO() + tcp.TCPServer.handle_error(self, connection, client_address, s) + self.q.put(s.getvalue()) -- cgit v1.2.3 From 58118d607e810e95fe8a0c0e6d7b8f4423f1f558 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 22 Jun 2015 20:39:30 +0200 Subject: unify SSL version/method handling --- netlib/tcp.py | 25 ++++++++++++++++++------- test/tservers.py | 4 ++-- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 77eb7b52..705cc311 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -16,13 +16,24 @@ from . import certutils EINTR = 4 -SSLv2_METHOD = SSL.SSLv2_METHOD -SSLv3_METHOD = SSL.SSLv3_METHOD -SSLv23_METHOD = SSL.SSLv23_METHOD -TLSv1_METHOD = SSL.TLSv1_METHOD -TLSv1_1_METHOD = SSL.TLSv1_1_METHOD -TLSv1_2_METHOD = SSL.TLSv1_2_METHOD +# To enable all SSL methods use: SSLv23 +# then add options to disable certain methods +# https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 +# Use ONLY for parsing of CLI arguments! +# All code internals should use OpenSSL constants directly! +SSL_VERSIONS = { + 'TLSv1.2': SSL.TLSv1_2_METHOD, + 'TLSv1.1': SSL.TLSv1_1_METHOD, + 'TLSv1': SSL.TLSv1_METHOD, + 'SSLv3': SSL.SSLv3_METHOD, + 'SSLv2': SSL.SSLv2_METHOD, + 'SSLv23': SSL.SSLv23_METHOD, +} + +SSL_DEFAULT_VERSION = 'SSLv23' + +SSL_DEFAULT_METHOD = SSL_VERSIONS[SSL_DEFAULT_VERSION] SSL_DEFAULT_OPTIONS = ( SSL.OP_NO_SSLv2 | @@ -376,7 +387,7 @@ class _Connection(object): pass def _create_ssl_context(self, - method=SSLv23_METHOD, + method=SSL_DEFAULT_METHOD, options=SSL_DEFAULT_OPTIONS, verify_options=SSL.VERIFY_NONE, ca_path=certifi.where(), diff --git a/test/tservers.py b/test/tservers.py index 899b51bd..09f1b095 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -83,10 +83,10 @@ class TServer(tcp.TCPServer): OpenSSL.crypto.FILETYPE_PEM, open(raw_key, "rb").read()) if self.ssl.get("v3_only", False): - method = tcp.SSLv3_METHOD + method = OpenSSL.SSL.SSLv3_METHOD options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 else: - method = tcp.SSLv23_METHOD + method = OpenSSL.SSL.SSLv23_METHOD options = None h.convert_to_ssl( cert, key, -- cgit v1.2.3 From 7afe44ba4ee8810e24abfa32f74dfac61e5551d3 Mon Sep 17 00:00:00 2001 From: Kyle Morton Date: Sat, 20 Jun 2015 12:54:03 -0700 Subject: Updating TCPServer to allow tests (and potentially other use cases) to serve certificate chains instead of only single certificates. --- netlib/tcp.py | 8 ++++++-- test/tservers.py | 3 +-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 77eb7b52..61306e4e 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -567,7 +567,8 @@ class BaseHandler(_Connection): dhparams=None, **sslctx_kwargs): """ - cert: A certutils.SSLCert object. + cert: A certutils.SSLCert object or the path to a certificate + chain file. handle_sni: SNI handler, should take a connection object. Server name can be retrieved like this: @@ -594,7 +595,10 @@ class BaseHandler(_Connection): context = self._create_ssl_context(**sslctx_kwargs) context.use_privatekey(key) - context.use_certificate(cert.x509) + if isinstance(cert, certutils.SSLCert): + context.use_certificate(cert.x509) + else: + context.use_certificate_chain_file(cert) if handle_sni: # SNI callback happens during do_handshake() diff --git a/test/tservers.py b/test/tservers.py index 899b51bd..5c1ea08b 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -72,10 +72,9 @@ class TServer(tcp.TCPServer): h = self.handler_klass(request, client_address, self) self.last_handler = h if self.ssl is not None: - raw_cert = self.ssl.get( + cert = self.ssl.get( "cert", tutils.test_data.path("data/server.crt")) - cert = certutils.SSLCert.from_pem(open(raw_cert, "rb").read()) raw_key = self.ssl.get( "key", tutils.test_data.path("data/server.key")) -- cgit v1.2.3 From d1452424beced04dc42bbadd68878d9e1c24da9c Mon Sep 17 00:00:00 2001 From: Kyle Morton Date: Sat, 20 Jun 2015 13:07:23 -0700 Subject: Cleaning up upstream server verification. Adding storage of cerificate verification errors on TCPClient object to enable warnings in downstream projects. --- netlib/tcp.py | 16 ++-- test/data/not-server.crt | 15 ---- test/data/verificationcerts/9d45e6a9.0 | 1 + test/data/verificationcerts/interm.key | 16 ++++ test/data/verificationcerts/trusted-chain.crt | 35 +++++++++ test/data/verificationcerts/trusted-interm.crt | 19 +++++ test/data/verificationcerts/trusted.key | 15 ++++ test/data/verificationcerts/trusted.pem | 15 ++++ test/data/verificationcerts/untrusted-chain.crt | 33 +++++++++ test/data/verificationcerts/untrusted-interm.crt | 17 +++++ test/data/verificationcerts/untrusted.crt | 16 ++++ .../data/verificationcerts/verification-server.key | 16 ++++ test/test_tcp.py | 86 +++++++++++++++++++--- 13 files changed, 266 insertions(+), 34 deletions(-) delete mode 100644 test/data/not-server.crt create mode 120000 test/data/verificationcerts/9d45e6a9.0 create mode 100644 test/data/verificationcerts/interm.key create mode 100644 test/data/verificationcerts/trusted-chain.crt create mode 100644 test/data/verificationcerts/trusted-interm.crt create mode 100644 test/data/verificationcerts/trusted.key create mode 100644 test/data/verificationcerts/trusted.pem create mode 100644 test/data/verificationcerts/untrusted-chain.crt create mode 100644 test/data/verificationcerts/untrusted-interm.crt create mode 100644 test/data/verificationcerts/untrusted.crt create mode 100644 test/data/verificationcerts/verification-server.key diff --git a/netlib/tcp.py b/netlib/tcp.py index 61306e4e..2cae34ec 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -401,14 +401,13 @@ class _Connection(object): if options is not None: context.set_options(options) - # Verify Options (NONE/PEER/PEER|FAIL_IF_... and trusted CAs) - if verify_options is not None and verify_options is not SSL.VERIFY_NONE: - def verify_cert(conn_, cert_, errno, err_depth, is_cert_verified): - if is_cert_verified: - return True - raise NetLibError( - "Upstream certificate validation failed at depth: %s with error number: %s" % - (err_depth, errno)) + # Verify Options (NONE/PEER and trusted CAs) + if verify_options is not None: + def verify_cert(conn, x509, errno, err_depth, is_cert_verified): + if not is_cert_verified: + self.ssl_verification_error = dict(errno=errno, + depth=err_depth) + return is_cert_verified context.set_verify(verify_options, verify_cert) context.load_verify_locations(ca_pemfile, ca_path) @@ -469,6 +468,7 @@ class TCPClient(_Connection): self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.ssl_established = False + self.ssl_verification_error = None self.sni = None def create_ssl_context(self, cert=None, alpn_protos=None, **sslctx_kwargs): diff --git a/test/data/not-server.crt b/test/data/not-server.crt deleted file mode 100644 index 08c015c2..00000000 --- a/test/data/not-server.crt +++ /dev/null @@ -1,15 +0,0 @@ ------BEGIN CERTIFICATE----- -MIICRTCCAa4CCQD/j4qq1h3iCjANBgkqhkiG9w0BAQUFADBnMQswCQYDVQQGEwJV -UzELMAkGA1UECBMCQ0ExETAPBgNVBAcTCFNvbWVDaXR5MRcwFQYDVQQKEw5Ob3RU -aGVSaWdodE9yZzELMAkGA1UECxMCTkExEjAQBgNVBAMTCU5vdFNlcnZlcjAeFw0x -NTA2MTMwMTE2MDZaFw0yNTA2MTAwMTE2MDZaMGcxCzAJBgNVBAYTAlVTMQswCQYD -VQQIEwJDQTERMA8GA1UEBxMIU29tZUNpdHkxFzAVBgNVBAoTDk5vdFRoZVJpZ2h0 -T3JnMQswCQYDVQQLEwJOQTESMBAGA1UEAxMJTm90U2VydmVyMIGfMA0GCSqGSIb3 -DQEBAQUAA4GNADCBiQKBgQDPkJlXAOCMKF0R7aDn5QJ7HtrJgOUDk/LpbhKhRZZR -dRGnJ4/HQxYYHh9k/4yZamYcvQPUxvFJt7UJUocf+84LUcIusUk7GvJMgsMVtFMq -7UKNXBN5tl3oOtoFDWGMZ8ksaIxS6oW3V/9v2WgU23PfvwE0EZqy+QhMLZZP5GOH -RwIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAJI6UtMKdCS2ghjqhAek2W1rt9u+Wuvx -776WYm5VyrJEtBDc/axLh0OteXzy/A31JrYe15fnVWIeFbDF0Ief9/Ezv6Jn+Pk8 -DErw5IHk2B399O4K3L3Eig06piu7uf3vE4l8ZanY02ZEnw7DyL6kmG9lX98VGenF -uXPfu3yxKbR4 ------END CERTIFICATE----- diff --git a/test/data/verificationcerts/9d45e6a9.0 b/test/data/verificationcerts/9d45e6a9.0 new file mode 120000 index 00000000..2f34cfaa --- /dev/null +++ b/test/data/verificationcerts/9d45e6a9.0 @@ -0,0 +1 @@ +trusted.pem \ No newline at end of file diff --git a/test/data/verificationcerts/interm.key b/test/data/verificationcerts/interm.key new file mode 100644 index 00000000..76c05cf4 --- /dev/null +++ b/test/data/verificationcerts/interm.key @@ -0,0 +1,16 @@ +# Key used to sign trusted-interm.crt and untrusted-interm.crt +-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQC1E80qCHhZ1gaZTYB7pN/Yxt3ehpEj+5hCbpop5iTWLuDjULS9 +WjA1wP+p02kZQ2dqL8pqT1qcc5jKmk2jvMeB/cQ7zNDg1NCmQMqx0KptRByMZ+GN +Zcqc7D4jl6vhGP4zAzV/lxvBvxtgeJI+ZdrHN0vT9I1cYADKz9SzCDCRTwIDAQAB +AoGAfKHocKnrzEmXuSSy7meI+vfF9kfA1ndxUSg3S+dwK0uQ1mTSQhI1ZIo2bnlo +uU6/e0Lxm0KLJ2wZGjoifjSNTC8pcxIfAQY4kM9fqoUcXVSBVSS2kByTunhNSVZQ +yQyc+UTq9g1zBnJsZAltn7/PaihU4heWgP/++lposuShqmECQQDaG+7l0qul1xak +9kuZgc88BSTfn9iMK2zIQRcVKuidK4dT3QEp0wmWR5Ue8jq8lvTmVTGNGZbHcheh +KhoZfLgLAkEA1IjwAw/8z02yV3lbc2QUjIl9m9lvjHBoE2sGuSfq/cZskLKrGat+ +CVj3spqVAg22tpQwVBuHiipBziWVnEtiTQJAB9FKfchQSLBt6lm9mfHyKJeSm8VR +8Kw5yO+0URjpn4CI6DOasBIVXOKR8LsD6fCLNJpHHWSWZ+2p9SfaKaGzwwJBAM31 +Scld89qca4fzNZkT0goCrvOZeUy6HVE79Q72zPVSFSD/02kT1BaQ3bB5to5/5aD2 +6AKJjwZoPs7bgykrsD0CQBzU8U/8x2dNQnG0QeqaKQu5kKhZSZ9bsawvrCkxSl6b +WAjl/Jehi5bbQ07zQo3cge6qeR38FCWVCHQ/5wNbc54= +-----END RSA PRIVATE KEY----- diff --git a/test/data/verificationcerts/trusted-chain.crt b/test/data/verificationcerts/trusted-chain.crt new file mode 100644 index 00000000..dd30bff3 --- /dev/null +++ b/test/data/verificationcerts/trusted-chain.crt @@ -0,0 +1,35 @@ +# untrusted.crt, signed by trusted-interm.crt +-----BEGIN CERTIFICATE----- +MIICYzCCAcwCAhAIMA0GCSqGSIb3DQEBBQUAMH4xCzAJBgNVBAYTAkFVMRMwEQYD +VQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0eSBM +dGQxFDASBgNVBAsTC0lOVEVSTSBVTklUMSEwHwYDVQQDExhPUkcgV0lUSCBJTlRF +Uk1FRElBVEUgQ0EwIBcNMTUwNjIwMDEyMDI1WhgPMjExNTA1MjcwMTIwMjVaMHMx +CzAJBgNVBAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRl +cm5ldCBXaWRnaXRzIFB0eSBMdGQxEjAQBgNVBAsTCUxFQUYgVU5JVDEYMBYGA1UE +AxMPTk9UIFRSVVNURUQgT1JHMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDf +NZx/tugICrWGcpP8sa+EBX9WhazCsYIm8YgQrQO9B19dK7cHsWB+vIdFuDKHxfS2 +JBIeVSaZ6H4onWGnZRAMpi5xnitVhBQKCZP1yOewtrg2umZIbcTz8A+BwAcvmmQN +7RZMfpxN9PMccWDfgtAXsjZ2E47o9EfhpGvxfcFc0wIDAQABMA0GCSqGSIb3DQEB +BQUAA4GBABtmc8zn5efVi3iVIgODadKkTv43elIwNZBqEJ6IaoVXvi5Mp1m4VxML +LQGPTNG1lpuVDz2z/Ml78942316ailCTOx48oDnb/yy4jI6hsp+N8p6T28/Wvkbm +cCgohk6/Cwat5gf+HwoIe5Z3B3HRJaIcB0OteluuLsHAvverBjc4 +-----END CERTIFICATE----- +# trusted-interm.crt, signed by trusted.pem +-----BEGIN CERTIFICATE----- +MIIC8jCCAlugAwIBAgICEAcwDQYJKoZIhvcNAQEFBQAwVzELMAkGA1UEBhMCQVUx +EzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNVBAoTGEludGVybmV0IFdpZGdpdHMg +UHR5IEx0ZDEQMA4GA1UEAxMHVFJVU1RFRDAgFw0xNTA2MjAwMTE4MjdaGA8yMTE1 +MDUyNzAxMTgyN1owfjELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUx +ITAfBgNVBAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEUMBIGA1UECxMLSU5U +RVJNIFVOSVQxITAfBgNVBAMTGE9SRyBXSVRIIElOVEVSTUVESUFURSBDQTCBnzAN +BgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAtRPNKgh4WdYGmU2Ae6Tf2Mbd3oaRI/uY +Qm6aKeYk1i7g41C0vVowNcD/qdNpGUNnai/Kak9anHOYyppNo7zHgf3EO8zQ4NTQ +pkDKsdCqbUQcjGfhjWXKnOw+I5er4Rj+MwM1f5cbwb8bYHiSPmXaxzdL0/SNXGAA +ys/UswgwkU8CAwEAAaOBozCBoDAMBgNVHRMEBTADAQH/MB0GA1UdDgQWBBTPkPQW +DAPOIy8mipuEsZcP1694EDBxBgNVHSMEajBooVukWTBXMQswCQYDVQQGEwJBVTET +MBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0cyBQ +dHkgTHRkMRAwDgYDVQQDEwdUUlVTVEVEggkAqNQXaKXXTf0wDQYJKoZIhvcNAQEF +BQADgYEApaPbwonY8l+zSxlY2Fw4WNKfl5nwcTW4fuv/0tZLzvsS6P4hTXxbYJNa +k3hQ1qlrr8DiWJewF85hYvEI2F/7eqS5dhhPTEUFPpsjhbgiqnASvW+WKQIgoY2r +aHgOXi7RNFtTcCgk0UZISWOY7ORLy8Xu6vKrLRjDhyfIbGlqnAs= +-----END CERTIFICATE----- diff --git a/test/data/verificationcerts/trusted-interm.crt b/test/data/verificationcerts/trusted-interm.crt new file mode 100644 index 00000000..d577db7d --- /dev/null +++ b/test/data/verificationcerts/trusted-interm.crt @@ -0,0 +1,19 @@ +# trusted-interm.crt, signed by trusted.pem +-----BEGIN CERTIFICATE----- +MIIC8jCCAlugAwIBAgICEAcwDQYJKoZIhvcNAQEFBQAwVzELMAkGA1UEBhMCQVUx +EzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNVBAoTGEludGVybmV0IFdpZGdpdHMg +UHR5IEx0ZDEQMA4GA1UEAxMHVFJVU1RFRDAgFw0xNTA2MjAwMTE4MjdaGA8yMTE1 +MDUyNzAxMTgyN1owfjELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUx +ITAfBgNVBAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEUMBIGA1UECxMLSU5U +RVJNIFVOSVQxITAfBgNVBAMTGE9SRyBXSVRIIElOVEVSTUVESUFURSBDQTCBnzAN +BgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAtRPNKgh4WdYGmU2Ae6Tf2Mbd3oaRI/uY +Qm6aKeYk1i7g41C0vVowNcD/qdNpGUNnai/Kak9anHOYyppNo7zHgf3EO8zQ4NTQ +pkDKsdCqbUQcjGfhjWXKnOw+I5er4Rj+MwM1f5cbwb8bYHiSPmXaxzdL0/SNXGAA +ys/UswgwkU8CAwEAAaOBozCBoDAMBgNVHRMEBTADAQH/MB0GA1UdDgQWBBTPkPQW +DAPOIy8mipuEsZcP1694EDBxBgNVHSMEajBooVukWTBXMQswCQYDVQQGEwJBVTET +MBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0cyBQ +dHkgTHRkMRAwDgYDVQQDEwdUUlVTVEVEggkAqNQXaKXXTf0wDQYJKoZIhvcNAQEF +BQADgYEApaPbwonY8l+zSxlY2Fw4WNKfl5nwcTW4fuv/0tZLzvsS6P4hTXxbYJNa +k3hQ1qlrr8DiWJewF85hYvEI2F/7eqS5dhhPTEUFPpsjhbgiqnASvW+WKQIgoY2r +aHgOXi7RNFtTcCgk0UZISWOY7ORLy8Xu6vKrLRjDhyfIbGlqnAs= +-----END CERTIFICATE----- diff --git a/test/data/verificationcerts/trusted.key b/test/data/verificationcerts/trusted.key new file mode 100644 index 00000000..3c26edf6 --- /dev/null +++ b/test/data/verificationcerts/trusted.key @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQC00Jf3KrBAmLQWl+Dz8Qrig8ActB94kv0/Lu03P/2DwOR8kH2h +3w4OC3b3CFKX31h7hm/H1PPHq7cIX6IRfwrYCtBE77UbxklSlrwn06j6YSotz0/d +wLEQEFDXWITJq7AyntaiafDHazbbXESNm/+I/YEl2wKemEHE//qWbeM9kwIDAQAB +AoGAVs2FBs1hi8FDQ01qWvGuzgt94MnACfxWw0xd6RY5OFUT25DqHxmb/7YVSIag +T/SS38osQ3zCA2s2FTkD7u5UX5AzJyqYJwmJhe6ZmaVly6IpebMxkX5w/hy15/N4 +uy+kzdtEBUUTNLL3DM7THkDYUxmeDzCBrHsMvYUqFgsBLOECQQDeNc1pDC++ovg5 +d9sKqMnEykBfvuvR6ra/343tYxy9zNFBvYjU3BA83MITIbEa/KtlSkIppz/K/jk5 +IRwSrwsJAkEA0E9aZfjDZbC9Z4oL7T8gtj2ftSh2g37KE5AWW2OxMJwrzoJ/6wjB +nG26ATlHEFP9bRzL2O1iovFLalqEjQo+uwJAMjtZXvjZRjATCvK0Onmjeu/5k2tW +ZdK4UzGXJOW11pYZa9ILv4qrxQZmfOqt3Zrmp/QcdswPGLVVfDum2/Zj+QJABJO5 +yMPOh0162+uMl4nrjhWMjM52zCzdA9EGrLtkCU1lKQR1CxUGLAm9LIm1pgYya1NW +p02P/USQA6Y5g1/WQQJBAIwl42Bebgaxl7dUbQX/vF+TryoCkM3B3eSM+P4XKB4f +kKSkNxvp59uq+b40gkoqEowhdq97y+pmrCxJHK43NJM= +-----END RSA PRIVATE KEY----- diff --git a/test/data/verificationcerts/trusted.pem b/test/data/verificationcerts/trusted.pem new file mode 100644 index 00000000..8ebc0e5c --- /dev/null +++ b/test/data/verificationcerts/trusted.pem @@ -0,0 +1,15 @@ +# Self signed +-----BEGIN CERTIFICATE----- +MIICJzCCAZACCQCo1BdopddN/TANBgkqhkiG9w0BAQUFADBXMQswCQYDVQQGEwJB +VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0 +cyBQdHkgTHRkMRAwDgYDVQQDEwdUUlVTVEVEMCAXDTE1MDYxOTE4MDEzMVoYDzIx +MTUwNTI2MTgwMTMxWjBXMQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0 +ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMRAwDgYDVQQDEwdU +UlVTVEVEMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC00Jf3KrBAmLQWl+Dz +8Qrig8ActB94kv0/Lu03P/2DwOR8kH2h3w4OC3b3CFKX31h7hm/H1PPHq7cIX6IR +fwrYCtBE77UbxklSlrwn06j6YSotz0/dwLEQEFDXWITJq7AyntaiafDHazbbXESN +m/+I/YEl2wKemEHE//qWbeM9kwIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAF0NREP3 +X+fTebzJGttzrFkDhGVFKRNyLXblXRVanlGOYF+q8grgZY2ufC/55gqf+ub6FRT5 +gKPhL4V2rqL8UAvCE7jq8ujpVfTB8kRAKC675W2DBZk2EJX9mjlr89t7qXGsI5nF +onpfJ1UtiJshNoV7h/NFHeoag91kx628807n +-----END CERTIFICATE----- diff --git a/test/data/verificationcerts/untrusted-chain.crt b/test/data/verificationcerts/untrusted-chain.crt new file mode 100644 index 00000000..272779d8 --- /dev/null +++ b/test/data/verificationcerts/untrusted-chain.crt @@ -0,0 +1,33 @@ +# untrusted.crt, signed by trusted-interm.crt +-----BEGIN CERTIFICATE----- +MIICYzCCAcwCAhAIMA0GCSqGSIb3DQEBBQUAMH4xCzAJBgNVBAYTAkFVMRMwEQYD +VQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0eSBM +dGQxFDASBgNVBAsTC0lOVEVSTSBVTklUMSEwHwYDVQQDExhPUkcgV0lUSCBJTlRF +Uk1FRElBVEUgQ0EwIBcNMTUwNjIwMDEyMDI1WhgPMjExNTA1MjcwMTIwMjVaMHMx +CzAJBgNVBAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRl +cm5ldCBXaWRnaXRzIFB0eSBMdGQxEjAQBgNVBAsTCUxFQUYgVU5JVDEYMBYGA1UE +AxMPTk9UIFRSVVNURUQgT1JHMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDf +NZx/tugICrWGcpP8sa+EBX9WhazCsYIm8YgQrQO9B19dK7cHsWB+vIdFuDKHxfS2 +JBIeVSaZ6H4onWGnZRAMpi5xnitVhBQKCZP1yOewtrg2umZIbcTz8A+BwAcvmmQN +7RZMfpxN9PMccWDfgtAXsjZ2E47o9EfhpGvxfcFc0wIDAQABMA0GCSqGSIb3DQEB +BQUAA4GBABtmc8zn5efVi3iVIgODadKkTv43elIwNZBqEJ6IaoVXvi5Mp1m4VxML +LQGPTNG1lpuVDz2z/Ml78942316ailCTOx48oDnb/yy4jI6hsp+N8p6T28/Wvkbm +cCgohk6/Cwat5gf+HwoIe5Z3B3HRJaIcB0OteluuLsHAvverBjc4 +-----END CERTIFICATE----- +# untrusted-interm.crt, self-signed +-----BEGIN CERTIFICATE----- +MIICdTCCAd4CCQDRSKOnIMbTgDANBgkqhkiG9w0BAQUFADB+MQswCQYDVQQGEwJB +VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0 +cyBQdHkgTHRkMRQwEgYDVQQLEwtJTlRFUk0gVU5JVDEhMB8GA1UEAxMYT1JHIFdJ +VEggSU5URVJNRURJQVRFIENBMCAXDTE1MDYyMDAxMzY0M1oYDzIxMTUwNTI3MDEz +NjQzWjB+MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UE +ChMYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMRQwEgYDVQQLEwtJTlRFUk0gVU5J +VDEhMB8GA1UEAxMYT1JHIFdJVEggSU5URVJNRURJQVRFIENBMIGfMA0GCSqGSIb3 +DQEBAQUAA4GNADCBiQKBgQC1E80qCHhZ1gaZTYB7pN/Yxt3ehpEj+5hCbpop5iTW +LuDjULS9WjA1wP+p02kZQ2dqL8pqT1qcc5jKmk2jvMeB/cQ7zNDg1NCmQMqx0Kpt +RByMZ+GNZcqc7D4jl6vhGP4zAzV/lxvBvxtgeJI+ZdrHN0vT9I1cYADKz9SzCDCR +TwIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAGbObAMEajCz4kj7OP2/DB5SRy2+H/G3 +8Qvc43xlMMNQyYxsDuLOFL0UMRzoKgntrrm2nni8jND+tuMt+hv3ZlBcJlYJ6ynR +sC1ITTC/1SwwwO0AFIyduUEIJYr/B3sgcVYPLcEfeDZgmEQc9Tnc01aEu3lx2+l9 +0JTSPL2L9LdA +-----END CERTIFICATE----- diff --git a/test/data/verificationcerts/untrusted-interm.crt b/test/data/verificationcerts/untrusted-interm.crt new file mode 100644 index 00000000..875cdcd6 --- /dev/null +++ b/test/data/verificationcerts/untrusted-interm.crt @@ -0,0 +1,17 @@ +# untrusted-interm.crt, self-signed +-----BEGIN CERTIFICATE----- +MIICdTCCAd4CCQDRSKOnIMbTgDANBgkqhkiG9w0BAQUFADB+MQswCQYDVQQGEwJB +VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0 +cyBQdHkgTHRkMRQwEgYDVQQLEwtJTlRFUk0gVU5JVDEhMB8GA1UEAxMYT1JHIFdJ +VEggSU5URVJNRURJQVRFIENBMCAXDTE1MDYyMDAxMzY0M1oYDzIxMTUwNTI3MDEz +NjQzWjB+MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UE +ChMYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMRQwEgYDVQQLEwtJTlRFUk0gVU5J +VDEhMB8GA1UEAxMYT1JHIFdJVEggSU5URVJNRURJQVRFIENBMIGfMA0GCSqGSIb3 +DQEBAQUAA4GNADCBiQKBgQC1E80qCHhZ1gaZTYB7pN/Yxt3ehpEj+5hCbpop5iTW +LuDjULS9WjA1wP+p02kZQ2dqL8pqT1qcc5jKmk2jvMeB/cQ7zNDg1NCmQMqx0Kpt +RByMZ+GNZcqc7D4jl6vhGP4zAzV/lxvBvxtgeJI+ZdrHN0vT9I1cYADKz9SzCDCR +TwIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAGbObAMEajCz4kj7OP2/DB5SRy2+H/G3 +8Qvc43xlMMNQyYxsDuLOFL0UMRzoKgntrrm2nni8jND+tuMt+hv3ZlBcJlYJ6ynR +sC1ITTC/1SwwwO0AFIyduUEIJYr/B3sgcVYPLcEfeDZgmEQc9Tnc01aEu3lx2+l9 +0JTSPL2L9LdA +-----END CERTIFICATE----- diff --git a/test/data/verificationcerts/untrusted.crt b/test/data/verificationcerts/untrusted.crt new file mode 100644 index 00000000..2dab470b --- /dev/null +++ b/test/data/verificationcerts/untrusted.crt @@ -0,0 +1,16 @@ +# untrusted.crt, signed by trusted-interm.crt +-----BEGIN CERTIFICATE----- +MIICYzCCAcwCAhAIMA0GCSqGSIb3DQEBBQUAMH4xCzAJBgNVBAYTAkFVMRMwEQYD +VQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0eSBM +dGQxFDASBgNVBAsTC0lOVEVSTSBVTklUMSEwHwYDVQQDExhPUkcgV0lUSCBJTlRF +Uk1FRElBVEUgQ0EwIBcNMTUwNjIwMDEyMDI1WhgPMjExNTA1MjcwMTIwMjVaMHMx +CzAJBgNVBAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRl +cm5ldCBXaWRnaXRzIFB0eSBMdGQxEjAQBgNVBAsTCUxFQUYgVU5JVDEYMBYGA1UE +AxMPTk9UIFRSVVNURUQgT1JHMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDf +NZx/tugICrWGcpP8sa+EBX9WhazCsYIm8YgQrQO9B19dK7cHsWB+vIdFuDKHxfS2 +JBIeVSaZ6H4onWGnZRAMpi5xnitVhBQKCZP1yOewtrg2umZIbcTz8A+BwAcvmmQN +7RZMfpxN9PMccWDfgtAXsjZ2E47o9EfhpGvxfcFc0wIDAQABMA0GCSqGSIb3DQEB +BQUAA4GBABtmc8zn5efVi3iVIgODadKkTv43elIwNZBqEJ6IaoVXvi5Mp1m4VxML +LQGPTNG1lpuVDz2z/Ml78942316ailCTOx48oDnb/yy4jI6hsp+N8p6T28/Wvkbm +cCgohk6/Cwat5gf+HwoIe5Z3B3HRJaIcB0OteluuLsHAvverBjc4 +-----END CERTIFICATE----- diff --git a/test/data/verificationcerts/verification-server.key b/test/data/verificationcerts/verification-server.key new file mode 100644 index 00000000..c527b09f --- /dev/null +++ b/test/data/verificationcerts/verification-server.key @@ -0,0 +1,16 @@ +# Key used for untrusted.crt, untrusted-chain.crt and trusted-chain.crt +-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQDfNZx/tugICrWGcpP8sa+EBX9WhazCsYIm8YgQrQO9B19dK7cH +sWB+vIdFuDKHxfS2JBIeVSaZ6H4onWGnZRAMpi5xnitVhBQKCZP1yOewtrg2umZI +bcTz8A+BwAcvmmQN7RZMfpxN9PMccWDfgtAXsjZ2E47o9EfhpGvxfcFc0wIDAQAB +AoGAE4B9ofL7Jui4n3yXTXbA3QoV7BtV0tTriDeGKd7T+soQHPXa0gM/aRNTxlWn +pJE5JkjUhG3wJ3ZWv3mwtI1x718y0yL9uEgQJYsrNN+VJQwbGxXPio5SaG39gs+y +/8xklytMIgvuCXxmcfljemW9+PGT8otYlHeIU3wvHQennDECQQD2vWAEU9k02R9w +EkCM7mZEaW+WwrzyAD1NqatsVWErbNeXFPcHwU6y+DiDg2s5iEk89+xN2rX5mW2S +PF/2RpaNAkEA55YpZN5nN4P8yCYNz5mWN0kuSPytSgJ3fQY3BY2GkdIft/KcAuDV +1pf6jxubwP4vlamnZpqLfylbGdlRBoMY3wJBALQVE3cVG3qO3XsWVzaE6O8VZPRL +vUuDETsVkp/G0Ny428DQ9FscoyvMLrMNv7yF065D5JwN/LLnYClTF1bPviECQQCo +1BavO1eh6C3DN8K/wmb5PPdqLBKkrrGvSnWYLbmZ2sZW0p4blw8tVzRJWcYtZuEH +yVuJeEcT1/FbIcto5O+fAkASbZXZka3nm41wWNYg479Sl8I+qvtScfJgpyByYhCx +QaUAtZ791U+WNNHLqfZhSzP9lFZNRI0WNBSAy3SBR2Ur +-----END RSA PRIVATE KEY----- diff --git a/test/test_tcp.py b/test/test_tcp.py index 4253e073..52398ef3 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -183,52 +183,115 @@ class TestSSLv3Only(tservers.ServerTestBase): tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com") -class TestSSLUpstreamCertVerification(tservers.ServerTestBase): +class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): handler = EchoHandler ssl = dict( - cert=tutils.test_data.path("data/server.crt") - ) + cert=tutils.test_data.path("data/verificationcerts/untrusted.crt"), + key=tutils.test_data.path("data/verificationcerts/verification-server.key")) - def test_mode_default(self): + def test_mode_default_should_pass(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() + # Verification errors should be saved even if connection isn't aborted + # aborted + assert c.ssl_verification_error is not None + testval = "echo!\n" c.wfile.write(testval) c.wfile.flush() assert c.rfile.readline() == testval - def test_mode_none(self): + def test_mode_none_should_pass(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl(verify_options=SSL.VERIFY_NONE) + # Verification errors should be saved even if connection isn't aborted + assert c.ssl_verification_error is not None + testval = "echo!\n" c.wfile.write(testval) c.wfile.flush() assert c.rfile.readline() == testval - def test_mode_strict_w_bad_cert(self): + def test_mode_strict_should_fail(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() tutils.raises( tcp.NetLibError, c.convert_to_ssl, - verify_options=SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT, - ca_pemfile=tutils.test_data.path("data/not-server.crt")) + verify_options=SSL.VERIFY_PEER, + ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted.pem")) + + assert c.ssl_verification_error is not None + + # Unknown issuing certificate authority for first certificate + assert c.ssl_verification_error['errno'] == 20 + assert c.ssl_verification_error['depth'] == 0 + + +class TestSSLUpstreamCertVerificationWBadCertChain(tservers.ServerTestBase): + handler = EchoHandler + + ssl = dict( + cert=tutils.test_data.path("data/verificationcerts/untrusted-chain.crt"), + key=tutils.test_data.path("data/verificationcerts/verification-server.key")) + + def test_mode_strict_should_fail(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + + tutils.raises( + "certificate verify failed", + c.convert_to_ssl, + verify_options=SSL.VERIFY_PEER, + ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted.pem")) + + assert c.ssl_verification_error is not None + + # Untrusted self-signed certificate at second position in certificate + # chain + assert c.ssl_verification_error['errno'] == 19 + assert c.ssl_verification_error['depth'] == 1 - def test_mode_strict_w_cert(self): + +class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase): + handler = EchoHandler + + ssl = dict( + cert=tutils.test_data.path("data/verificationcerts/trusted-chain.crt"), + key=tutils.test_data.path("data/verificationcerts/verification-server.key")) + + def test_mode_strict_w_pemfile_should_pass(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + + c.convert_to_ssl( + verify_options=SSL.VERIFY_PEER, + ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted.pem")) + + assert c.ssl_verification_error is None + + testval = "echo!\n" + c.wfile.write(testval) + c.wfile.flush() + assert c.rfile.readline() == testval + + def test_mode_strict_w_cadir_should_pass(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl( - verify_options=SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT, - ca_pemfile=tutils.test_data.path("data/server.crt")) + verify_options=SSL.VERIFY_PEER, + ca_path=tutils.test_data.path("data/verificationcerts/")) + + assert c.ssl_verification_error is None testval = "echo!\n" c.wfile.write(testval) @@ -457,6 +520,7 @@ class TestALPNClient(tservers.ServerTestBase): assert c.get_alpn_proto_negotiated() == "" assert c.rfile.readline() == "NONE" + class TestNoSSLNoALPNClient(tservers.ServerTestBase): handler = ALPNHandler -- cgit v1.2.3 From 45c2ac2cf7646dc272e8d9717caefcec30b97456 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 23 Jun 2015 13:16:28 +1200 Subject: Travis notifications for Slack. --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index 83fcc265..3b62c51b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -48,6 +48,7 @@ notifications: - "irc.oftc.net#mitmproxy" on_success: change on_failure: always + slack: mitmproxy:YaDGC9Gt9TEM7o8zkC2OLNsu # exclude cryptography from cache # it depends on libssl-dev version -- cgit v1.2.3 From 85b46cd88819a6a7243ddba7e1935482e7b4b271 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 23 Jun 2015 13:28:40 +1200 Subject: Refine travis. And, lest some meticulous code reader (I'm looking at you, Thomas) notices the extra colon: https://github.com/travis-ci/travis-ci/issues/2894 --- .travis.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 3b62c51b..d961b6c9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -48,7 +48,10 @@ notifications: - "irc.oftc.net#mitmproxy" on_success: change on_failure: always - slack: mitmproxy:YaDGC9Gt9TEM7o8zkC2OLNsu + slack: + mitmproxy:YaDGC9Gt9TEM7o8zkC2OLNsu + on_success: :change + on_failure: always # exclude cryptography from cache # it depends on libssl-dev version -- cgit v1.2.3 From 5588e57ca4f002f1e569847901616397282f6b3b Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 23 Jun 2015 13:51:08 +1200 Subject: Moar Travis. --- .travis.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index d961b6c9..4991e241 100644 --- a/.travis.yml +++ b/.travis.yml @@ -49,7 +49,8 @@ notifications: on_success: change on_failure: always slack: - mitmproxy:YaDGC9Gt9TEM7o8zkC2OLNsu + rooms: + - mitmproxy:YaDGC9Gt9TEM7o8zkC2OLNsu on_success: :change on_failure: always -- cgit v1.2.3 From 239f4758afa65995769e896d8f4faa9e12414d28 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 23 Jun 2015 22:16:03 +1200 Subject: Remove dependence on pathod in test suite. --- netlib/utils.py | 21 ++++++++++++++++++++- test/tutils.py | 3 +-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/netlib/utils.py b/netlib/utils.py index ac42bd53..bee412f9 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,5 +1,5 @@ from __future__ import (absolute_import, print_function, division) - +import os.path def isascii(s): try: @@ -110,3 +110,22 @@ def pretty_size(size): if x == int(x): x = int(x) return str(x) + suf + + +class Data(object): + def __init__(self, name): + m = __import__(name) + dirname, _ = os.path.split(m.__file__) + self.dirname = os.path.abspath(dirname) + + def path(self, path): + """ + Returns a path to the package data housed at 'path' under this + module.Path can be a path to a file, or to a directory. + + This function will raise ValueError if the path does not exist. + """ + fullpath = os.path.join(self.dirname, path) + if not os.path.exists(fullpath): + raise ValueError("dataPath: %s does not exist." % fullpath) + return fullpath diff --git a/test/tutils.py b/test/tutils.py index 95c8b80a..94139f6f 100644 --- a/test/tutils.py +++ b/test/tutils.py @@ -3,9 +3,8 @@ import tempfile import os import shutil from contextlib import contextmanager -from libpathod import utils -from netlib import tcp +from netlib import tcp, utils def treader(bytes): -- cgit v1.2.3 From 4766bce63d787888f9d4ed07e9ce0b63764b22d1 Mon Sep 17 00:00:00 2001 From: Kyle Morton Date: Tue, 23 Jun 2015 10:46:42 -0700 Subject: Adding test data to support post OpenSSL v1.0 cert hashing --- test/data/verificationcerts/8117bdb9.0 | 1 + 1 file changed, 1 insertion(+) create mode 120000 test/data/verificationcerts/8117bdb9.0 diff --git a/test/data/verificationcerts/8117bdb9.0 b/test/data/verificationcerts/8117bdb9.0 new file mode 120000 index 00000000..2f34cfaa --- /dev/null +++ b/test/data/verificationcerts/8117bdb9.0 @@ -0,0 +1 @@ +trusted.pem \ No newline at end of file -- cgit v1.2.3 From 41925b01f71831c33424d5cd9e612d003b99a69d Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 25 Jun 2015 10:37:01 +1200 Subject: Fix printing of SSL version error Fixes #73 --- netlib/version_check.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/version_check.py b/netlib/version_check.py index 09dc23ae..df1612a2 100644 --- a/netlib/version_check.py +++ b/netlib/version_check.py @@ -33,7 +33,7 @@ def version_check( if v < pyopenssl_min_version: print( "You are using an outdated version of pyOpenSSL:" - " mitmproxy requires pyOpenSSL %x or greater." % + " mitmproxy requires pyOpenSSL %s or greater." % pyopenssl_min_version, file=fp ) -- cgit v1.2.3 From 2723a0e5739412953f60c37d0dab81d684ba5f26 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 26 Jun 2015 13:26:35 +0200 Subject: remove certffi --- netlib/certffi.py | 41 ----------------------------------------- netlib/certutils.py | 6 ------ setup.py | 33 +++------------------------------ test/test_certutils.py | 20 +------------------- test/test_tcp.py | 38 +------------------------------------- 5 files changed, 5 insertions(+), 133 deletions(-) delete mode 100644 netlib/certffi.py diff --git a/netlib/certffi.py b/netlib/certffi.py deleted file mode 100644 index 451f4493..00000000 --- a/netlib/certffi.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -from cffi import FFI -import OpenSSL - -xffi = FFI() -xffi.cdef(""" - struct rsa_meth_st { - int flags; - ...; - }; - struct rsa_st { - int pad; - long version; - struct rsa_meth_st *meth; - ...; - }; -""") -xffi.verify( - """#include """, - extra_compile_args=['-w'] -) - - -def handle(privkey): - new = xffi.new("struct rsa_st*") - newbuf = xffi.buffer(new) - rsa = OpenSSL.SSL._lib.EVP_PKEY_get1_RSA(privkey._pkey) - oldbuf = OpenSSL.SSL._ffi.buffer(rsa) - newbuf[:] = oldbuf[:] - return new - - -def set_flags(privkey, val): - hdl = handle(privkey) - hdl.meth.flags = val - return privkey - - -def get_flags(privkey): - hdl = handle(privkey) - return hdl.meth.flags diff --git a/netlib/certutils.py b/netlib/certutils.py index c6f0e628..c699af00 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -333,12 +333,6 @@ class CertStore(object): return entry.cert, entry.privatekey, entry.chain_file - def gen_pkey(self, cert_): - # FIXME: We should do something with cert here? - from . import certffi - certffi.set_flags(self.default_privatekey, 1) - return self.default_privatekey - class _GeneralName(univ.Choice): # We are only interested in dNSNames. We use a default handler to ignore diff --git a/setup.py b/setup.py index 3a1d7811..d51977ee 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,3 @@ -from distutils.command.build import build -from setuptools.command.install import install from setuptools import setup, find_packages from codecs import open import os @@ -15,25 +13,6 @@ here = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(here, 'README.mkd'), encoding='utf-8') as f: long_description = f.read() - -def get_ext_modules(): - from netlib import certffi - return [certffi.xffi.verifier.get_extension()] - - -class CFFIBuild(build): - - def finalize_options(self): - self.distribution.ext_modules = get_ext_modules() - build.finalize_options(self) - - -class CFFIInstall(install): - - def finalize_options(self): - self.distribution.ext_modules = get_ext_modules() - install.finalize_options(self) - setup( name="netlib", version=version.VERSION, @@ -62,16 +41,12 @@ setup( include_package_data=True, zip_safe=False, install_requires=[ - "cffi", "pyasn1>=0.1.7", "pyOpenSSL>=0.15.1", "cryptography>=0.9", "passlib>=1.6.2", "hpack>=1.0.1", - "certifi"], - setup_requires=[ - "cffi", - "pyOpenSSL>=0.15.1", + "certifi" ], extras_require={ 'dev': [ @@ -84,9 +59,7 @@ setup( "wheel>=0.24.0", "pathod>=%s, <%s" % (version.MINORVERSION, - version.NEXT_MINORVERSION)]}, - cmdclass={ - "build": CFFIBuild, - "install": CFFIInstall, + version.NEXT_MINORVERSION) + ] }, ) diff --git a/test/test_certutils.py b/test/test_certutils.py index e079ec40..50df36ae 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -1,5 +1,5 @@ import os -from netlib import certutils, certffi +from netlib import certutils import tutils # class TestDNTree: @@ -92,24 +92,6 @@ class TestCertStore: ret = ca1.get_cert("foo.com", []) assert ret[0].serial == dc[0].serial - def test_gen_pkey(self): - try: - with tutils.tmpdir() as d: - ca1 = certutils.CertStore.from_store( - os.path.join( - d, - "ca1"), - "test") - ca2 = certutils.CertStore.from_store( - os.path.join( - d, - "ca2"), - "test") - cert = ca1.get_cert("foo.com", []) - assert certffi.get_flags(ca2.gen_pkey(cert[0])) == 1 - finally: - certffi.set_flags(ca2.default_privatekey, 0) - class TestDummyCert: diff --git a/test/test_tcp.py b/test/test_tcp.py index 52398ef3..8a3299b6 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -10,7 +10,7 @@ import mock from OpenSSL import SSL import OpenSSL -from netlib import tcp, certutils, certffi +from netlib import tcp, certutils from . import tutils, tservers @@ -566,42 +566,6 @@ class TestDHParams(tservers.ServerTestBase): assert os.path.exists(filename) -class TestPrivkeyGen(tservers.ServerTestBase): - - class handler(tcp.BaseHandler): - - def handle(self): - with tutils.tmpdir() as d: - ca1 = certutils.CertStore.from_store(d, "test2") - ca2 = certutils.CertStore.from_store(d, "test3") - cert, _, _ = ca1.get_cert("foo.com", []) - key = ca2.gen_pkey(cert) - self.convert_to_ssl(cert, key) - - def test_privkey(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - tutils.raises("bad record mac", c.convert_to_ssl) - - -class TestPrivkeyGenNoFlags(tservers.ServerTestBase): - - class handler(tcp.BaseHandler): - - def handle(self): - with tutils.tmpdir() as d: - ca1 = certutils.CertStore.from_store(d, "test2") - ca2 = certutils.CertStore.from_store(d, "test3") - cert, _, _ = ca1.get_cert("foo.com", []) - certffi.set_flags(ca2.default_privatekey, 0) - self.convert_to_ssl(cert, ca2.default_privatekey) - - def test_privkey(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - tutils.raises("sslv3 alert handshake failure", c.convert_to_ssl) - - class TestTCPClient: def test_conerr(self): -- cgit v1.2.3 From 2fb3d6caed229ac1900561e2f1a5289694fd2dd3 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 26 Jun 2015 18:10:04 +0200 Subject: add appveyor --- .appveyor.yml | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 .appveyor.yml diff --git a/.appveyor.yml b/.appveyor.yml new file mode 100644 index 00000000..785e7dcc --- /dev/null +++ b/.appveyor.yml @@ -0,0 +1,7 @@ +shallow_clone: true +install: + - ps: "pip install --src . -r requirements.txt" + - ps: "python -c 'from OpenSSL import SSL; print SSL.SSLeay_version(SSL.SSLEAY_VERSION)'" +build: off # Not a C# project +test_script: + - ps: "nosetests --with-cov --cov-report term-missing" \ No newline at end of file -- cgit v1.2.3 From 93e515c02ff811a1e721c4a4e89396382576ce86 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 26 Jun 2015 18:24:33 +0200 Subject: appveyor: use explicit python version --- .appveyor.yml | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/.appveyor.yml b/.appveyor.yml index 785e7dcc..ef0e59d7 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -1,7 +1,10 @@ shallow_clone: true +environment: + matrix: + - PYTHON: "C:\\Python27" install: - - ps: "pip install --src . -r requirements.txt" - - ps: "python -c 'from OpenSSL import SSL; print SSL.SSLeay_version(SSL.SSLEAY_VERSION)'" + - "%PYTHON%\\Scripts\\pip install --src . -r requirements.txt" + - "%PYTHON%\\python -c \"from OpenSSL import SSL; print SSL.SSLeay_version(SSL.SSLEAY_VERSION)\"" build: off # Not a C# project test_script: - - ps: "nosetests --with-cov --cov-report term-missing" \ No newline at end of file + - "%PYTHON%\\Scripts\nosetests --with-cov --cov-report term-missing" \ No newline at end of file -- cgit v1.2.3 From 5b02d5417ab48ec65ee51b2b04bfc15643a5bad0 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 26 Jun 2015 18:36:28 +0200 Subject: appveyor: minor fixes --- .appveyor.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.appveyor.yml b/.appveyor.yml index ef0e59d7..4e690c06 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -1,10 +1,11 @@ +version: '{build}' shallow_clone: true environment: matrix: - PYTHON: "C:\\Python27" install: - "%PYTHON%\\Scripts\\pip install --src . -r requirements.txt" - - "%PYTHON%\\python -c \"from OpenSSL import SSL; print SSL.SSLeay_version(SSL.SSLEAY_VERSION)\"" + - "%PYTHON%\\python -c \"from OpenSSL import SSL; print(SSL.SSLeay_version(SSL.SSLEAY_VERSION))\"" build: off # Not a C# project test_script: - - "%PYTHON%\\Scripts\nosetests --with-cov --cov-report term-missing" \ No newline at end of file + - "%PYTHON%\\Scripts\\nosetests" \ No newline at end of file -- cgit v1.2.3 From 74c50d24eb8fe6630486a8ef839fe9319cf53b2a Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 26 Jun 2015 19:21:35 +0200 Subject: fix tests on windows --- test/data/verificationcerts/8117bdb9.0 | 16 +++++++++++++++- test/data/verificationcerts/9d45e6a9.0 | 16 +++++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/test/data/verificationcerts/8117bdb9.0 b/test/data/verificationcerts/8117bdb9.0 index 2f34cfaa..8ebc0e5c 120000 --- a/test/data/verificationcerts/8117bdb9.0 +++ b/test/data/verificationcerts/8117bdb9.0 @@ -1 +1,15 @@ -trusted.pem \ No newline at end of file +# Self signed +-----BEGIN CERTIFICATE----- +MIICJzCCAZACCQCo1BdopddN/TANBgkqhkiG9w0BAQUFADBXMQswCQYDVQQGEwJB +VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0 +cyBQdHkgTHRkMRAwDgYDVQQDEwdUUlVTVEVEMCAXDTE1MDYxOTE4MDEzMVoYDzIx +MTUwNTI2MTgwMTMxWjBXMQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0 +ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMRAwDgYDVQQDEwdU +UlVTVEVEMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC00Jf3KrBAmLQWl+Dz +8Qrig8ActB94kv0/Lu03P/2DwOR8kH2h3w4OC3b3CFKX31h7hm/H1PPHq7cIX6IR +fwrYCtBE77UbxklSlrwn06j6YSotz0/dwLEQEFDXWITJq7AyntaiafDHazbbXESN +m/+I/YEl2wKemEHE//qWbeM9kwIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAF0NREP3 +X+fTebzJGttzrFkDhGVFKRNyLXblXRVanlGOYF+q8grgZY2ufC/55gqf+ub6FRT5 +gKPhL4V2rqL8UAvCE7jq8ujpVfTB8kRAKC675W2DBZk2EJX9mjlr89t7qXGsI5nF +onpfJ1UtiJshNoV7h/NFHeoag91kx628807n +-----END CERTIFICATE----- diff --git a/test/data/verificationcerts/9d45e6a9.0 b/test/data/verificationcerts/9d45e6a9.0 index 2f34cfaa..8ebc0e5c 120000 --- a/test/data/verificationcerts/9d45e6a9.0 +++ b/test/data/verificationcerts/9d45e6a9.0 @@ -1 +1,15 @@ -trusted.pem \ No newline at end of file +# Self signed +-----BEGIN CERTIFICATE----- +MIICJzCCAZACCQCo1BdopddN/TANBgkqhkiG9w0BAQUFADBXMQswCQYDVQQGEwJB +VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0 +cyBQdHkgTHRkMRAwDgYDVQQDEwdUUlVTVEVEMCAXDTE1MDYxOTE4MDEzMVoYDzIx +MTUwNTI2MTgwMTMxWjBXMQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0 +ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMRAwDgYDVQQDEwdU +UlVTVEVEMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC00Jf3KrBAmLQWl+Dz +8Qrig8ActB94kv0/Lu03P/2DwOR8kH2h3w4OC3b3CFKX31h7hm/H1PPHq7cIX6IR +fwrYCtBE77UbxklSlrwn06j6YSotz0/dwLEQEFDXWITJq7AyntaiafDHazbbXESN +m/+I/YEl2wKemEHE//qWbeM9kwIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAF0NREP3 +X+fTebzJGttzrFkDhGVFKRNyLXblXRVanlGOYF+q8grgZY2ufC/55gqf+ub6FRT5 +gKPhL4V2rqL8UAvCE7jq8ujpVfTB8kRAKC675W2DBZk2EJX9mjlr89t7qXGsI5nF +onpfJ1UtiJshNoV7h/NFHeoag91kx628807n +-----END CERTIFICATE----- -- cgit v1.2.3 From 26ea1a065e1648029a1eed578c18113e47f093f9 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 26 Jun 2015 19:23:12 +0200 Subject: fix file type --- test/data/verificationcerts/8117bdb9.0 | 0 test/data/verificationcerts/9d45e6a9.0 | 0 2 files changed, 0 insertions(+), 0 deletions(-) mode change 120000 => 100644 test/data/verificationcerts/8117bdb9.0 mode change 120000 => 100644 test/data/verificationcerts/9d45e6a9.0 diff --git a/test/data/verificationcerts/8117bdb9.0 b/test/data/verificationcerts/8117bdb9.0 deleted file mode 120000 index 8ebc0e5c..00000000 --- a/test/data/verificationcerts/8117bdb9.0 +++ /dev/null @@ -1,15 +0,0 @@ -# Self signed ------BEGIN CERTIFICATE----- -MIICJzCCAZACCQCo1BdopddN/TANBgkqhkiG9w0BAQUFADBXMQswCQYDVQQGEwJB -VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0 -cyBQdHkgTHRkMRAwDgYDVQQDEwdUUlVTVEVEMCAXDTE1MDYxOTE4MDEzMVoYDzIx -MTUwNTI2MTgwMTMxWjBXMQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0 -ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMRAwDgYDVQQDEwdU -UlVTVEVEMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC00Jf3KrBAmLQWl+Dz -8Qrig8ActB94kv0/Lu03P/2DwOR8kH2h3w4OC3b3CFKX31h7hm/H1PPHq7cIX6IR -fwrYCtBE77UbxklSlrwn06j6YSotz0/dwLEQEFDXWITJq7AyntaiafDHazbbXESN -m/+I/YEl2wKemEHE//qWbeM9kwIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAF0NREP3 -X+fTebzJGttzrFkDhGVFKRNyLXblXRVanlGOYF+q8grgZY2ufC/55gqf+ub6FRT5 -gKPhL4V2rqL8UAvCE7jq8ujpVfTB8kRAKC675W2DBZk2EJX9mjlr89t7qXGsI5nF -onpfJ1UtiJshNoV7h/NFHeoag91kx628807n ------END CERTIFICATE----- diff --git a/test/data/verificationcerts/8117bdb9.0 b/test/data/verificationcerts/8117bdb9.0 new file mode 100644 index 00000000..8ebc0e5c --- /dev/null +++ b/test/data/verificationcerts/8117bdb9.0 @@ -0,0 +1,15 @@ +# Self signed +-----BEGIN CERTIFICATE----- +MIICJzCCAZACCQCo1BdopddN/TANBgkqhkiG9w0BAQUFADBXMQswCQYDVQQGEwJB +VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0 +cyBQdHkgTHRkMRAwDgYDVQQDEwdUUlVTVEVEMCAXDTE1MDYxOTE4MDEzMVoYDzIx +MTUwNTI2MTgwMTMxWjBXMQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0 +ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMRAwDgYDVQQDEwdU +UlVTVEVEMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC00Jf3KrBAmLQWl+Dz +8Qrig8ActB94kv0/Lu03P/2DwOR8kH2h3w4OC3b3CFKX31h7hm/H1PPHq7cIX6IR +fwrYCtBE77UbxklSlrwn06j6YSotz0/dwLEQEFDXWITJq7AyntaiafDHazbbXESN +m/+I/YEl2wKemEHE//qWbeM9kwIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAF0NREP3 +X+fTebzJGttzrFkDhGVFKRNyLXblXRVanlGOYF+q8grgZY2ufC/55gqf+ub6FRT5 +gKPhL4V2rqL8UAvCE7jq8ujpVfTB8kRAKC675W2DBZk2EJX9mjlr89t7qXGsI5nF +onpfJ1UtiJshNoV7h/NFHeoag91kx628807n +-----END CERTIFICATE----- diff --git a/test/data/verificationcerts/9d45e6a9.0 b/test/data/verificationcerts/9d45e6a9.0 deleted file mode 120000 index 8ebc0e5c..00000000 --- a/test/data/verificationcerts/9d45e6a9.0 +++ /dev/null @@ -1,15 +0,0 @@ -# Self signed ------BEGIN CERTIFICATE----- -MIICJzCCAZACCQCo1BdopddN/TANBgkqhkiG9w0BAQUFADBXMQswCQYDVQQGEwJB -VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0 -cyBQdHkgTHRkMRAwDgYDVQQDEwdUUlVTVEVEMCAXDTE1MDYxOTE4MDEzMVoYDzIx -MTUwNTI2MTgwMTMxWjBXMQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0 -ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMRAwDgYDVQQDEwdU -UlVTVEVEMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC00Jf3KrBAmLQWl+Dz -8Qrig8ActB94kv0/Lu03P/2DwOR8kH2h3w4OC3b3CFKX31h7hm/H1PPHq7cIX6IR -fwrYCtBE77UbxklSlrwn06j6YSotz0/dwLEQEFDXWITJq7AyntaiafDHazbbXESN -m/+I/YEl2wKemEHE//qWbeM9kwIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAF0NREP3 -X+fTebzJGttzrFkDhGVFKRNyLXblXRVanlGOYF+q8grgZY2ufC/55gqf+ub6FRT5 -gKPhL4V2rqL8UAvCE7jq8ujpVfTB8kRAKC675W2DBZk2EJX9mjlr89t7qXGsI5nF -onpfJ1UtiJshNoV7h/NFHeoag91kx628807n ------END CERTIFICATE----- diff --git a/test/data/verificationcerts/9d45e6a9.0 b/test/data/verificationcerts/9d45e6a9.0 new file mode 100644 index 00000000..8ebc0e5c --- /dev/null +++ b/test/data/verificationcerts/9d45e6a9.0 @@ -0,0 +1,15 @@ +# Self signed +-----BEGIN CERTIFICATE----- +MIICJzCCAZACCQCo1BdopddN/TANBgkqhkiG9w0BAQUFADBXMQswCQYDVQQGEwJB +VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0 +cyBQdHkgTHRkMRAwDgYDVQQDEwdUUlVTVEVEMCAXDTE1MDYxOTE4MDEzMVoYDzIx +MTUwNTI2MTgwMTMxWjBXMQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0 +ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMRAwDgYDVQQDEwdU +UlVTVEVEMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC00Jf3KrBAmLQWl+Dz +8Qrig8ActB94kv0/Lu03P/2DwOR8kH2h3w4OC3b3CFKX31h7hm/H1PPHq7cIX6IR +fwrYCtBE77UbxklSlrwn06j6YSotz0/dwLEQEFDXWITJq7AyntaiafDHazbbXESN +m/+I/YEl2wKemEHE//qWbeM9kwIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAF0NREP3 +X+fTebzJGttzrFkDhGVFKRNyLXblXRVanlGOYF+q8grgZY2ufC/55gqf+ub6FRT5 +gKPhL4V2rqL8UAvCE7jq8ujpVfTB8kRAKC675W2DBZk2EJX9mjlr89t7qXGsI5nF +onpfJ1UtiJshNoV7h/NFHeoag91kx628807n +-----END CERTIFICATE----- -- cgit v1.2.3 From 8ca103cba5aa0e64ca81477dee6a74a183548336 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 26 Jun 2015 23:43:08 +0200 Subject: synchronize metadata files across projects --- .landscape.yml | 2 +- .travis.yml | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.landscape.yml b/.landscape.yml index ccaa5fc3..9a3b615f 100644 --- a/.landscape.yml +++ b/.landscape.yml @@ -13,4 +13,4 @@ pylint: - too-many-public-methods - too-many-return-statements - too-many-statements - - unpacking-non-sequence + - unpacking-non-sequence \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index 4991e241..9fd4fbd9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -37,7 +37,6 @@ before_script: script: - "nosetests --with-cov --cov-report term-missing" - - "./check_coding_style.sh" after_success: - coveralls @@ -48,7 +47,7 @@ notifications: - "irc.oftc.net#mitmproxy" on_success: change on_failure: always - slack: + slack: rooms: - mitmproxy:YaDGC9Gt9TEM7o8zkC2OLNsu on_success: :change -- cgit v1.2.3 From 0a2b25187faea1fa29a3b21935cd55294b173bf8 Mon Sep 17 00:00:00 2001 From: Kyle Morton Date: Fri, 26 Jun 2015 14:57:00 -0700 Subject: Fixing how certifi is made the default ca_path to simplify calling logic. --- netlib/tcp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 74a275c9..38b77c9e 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -390,7 +390,7 @@ class _Connection(object): method=SSL_DEFAULT_METHOD, options=SSL_DEFAULT_OPTIONS, verify_options=SSL.VERIFY_NONE, - ca_path=certifi.where(), + ca_path=None, ca_pemfile=None, cipher_list=None, alpn_protos=None, @@ -421,6 +421,8 @@ class _Connection(object): return is_cert_verified context.set_verify(verify_options, verify_cert) + if ca_path is None and ca_pemfile is None: + ca_path = certifi.where() context.load_verify_locations(ca_pemfile, ca_path) # Workaround for -- cgit v1.2.3 From 9aaf10120d08e12e7aa82fc2184ca7faa35349c3 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 3 Jul 2015 02:01:30 +0200 Subject: socks: add assert_socks5 method --- netlib/socks.py | 41 ++++++++++++++++++++++++++++++++------- test/test_socks.py | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 7 deletions(-) diff --git a/netlib/socks.py b/netlib/socks.py index 5a73c61a..eef98f5c 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -6,7 +6,6 @@ from . import tcp, utils class SocksError(Exception): - def __init__(self, code, message): super(SocksError, self).__init__(message) self.code = code @@ -17,21 +16,18 @@ VERSION = utils.BiDi( SOCKS5=0x05 ) - CMD = utils.BiDi( CONNECT=0x01, BIND=0x02, UDP_ASSOCIATE=0x03 ) - ATYP = utils.BiDi( IPV4_ADDRESS=0x01, DOMAINNAME=0x03, IPV6_ADDRESS=0x04 ) - REP = utils.BiDi( SUCCEEDED=0x00, GENERAL_SOCKS_SERVER_FAILURE=0x01, @@ -44,7 +40,6 @@ REP = utils.BiDi( ADDRESS_TYPE_NOT_SUPPORTED=0x08, ) - METHOD = utils.BiDi( NO_AUTHENTICATION_REQUIRED=0x00, GSSAPI=0x01, @@ -58,14 +53,27 @@ class ClientGreeting(object): def __init__(self, ver, methods): self.ver = ver - self.methods = methods + self.methods = array.array("B") + self.methods.extend(methods) + + def assert_socks5(self): + if self.ver != VERSION.SOCKS5: + if self.ver == ord("G") and len(self.methods) == ord("E"): + guess = "Probably not a SOCKS request but a regular HTTP request. " + else: + guess = "" + + raise SocksError( + REP.GENERAL_SOCKS_SERVER_FAILURE, + guess + "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver + ) @classmethod def from_file(cls, f): ver, nmethods = struct.unpack("!BB", f.safe_read(2)) methods = array.array("B") methods.fromstring(f.safe_read(nmethods)) - return cls(ver, methods) + return cls(ver, methods.tolist()) def to_file(self, f): f.write(struct.pack("!BB", self.ver, len(self.methods))) @@ -79,6 +87,18 @@ class ServerGreeting(object): self.ver = ver self.method = method + def assert_socks5(self): + if self.ver != VERSION.SOCKS5: + if self.ver == ord("H") and self.method == ord("T"): + guess = "Probably not a SOCKS request but a regular HTTP response. " + else: + guess = "" + + raise SocksError( + REP.GENERAL_SOCKS_SERVER_FAILURE, + guess + "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver + ) + @classmethod def from_file(cls, f): ver, method = struct.unpack("!BB", f.safe_read(2)) @@ -97,6 +117,13 @@ class Message(object): self.atyp = atyp self.addr = addr + def assert_socks5(self): + if self.ver != VERSION.SOCKS5: + raise SocksError( + REP.GENERAL_SOCKS_SERVER_FAILURE, + "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver + ) + @classmethod def from_file(cls, f): ver, msg, rsv, atyp = struct.unpack("!BBBB", f.safe_read(4)) diff --git a/test/test_socks.py b/test/test_socks.py index a9db4706..eb5d55f9 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -9,6 +9,7 @@ def test_client_greeting(): raw = tutils.treader("\x05\x02\x00\xBE\xEF") out = StringIO() msg = socks.ClientGreeting.from_file(raw) + msg.assert_socks5() msg.to_file(out) assert out.getvalue() == raw.getvalue()[:-1] @@ -18,10 +19,37 @@ def test_client_greeting(): assert 0xEF not in msg.methods +def test_client_greeting_assert_socks5(): + raw = tutils.treader("\x00\x00") + msg = socks.ClientGreeting.from_file(raw) + tutils.raises(socks.SocksError, msg.assert_socks5) + + raw = tutils.treader("HTTP/1.1 200 OK" + " " * 100) + msg = socks.ClientGreeting.from_file(raw) + try: + msg.assert_socks5() + except socks.SocksError as e: + assert "Invalid SOCKS version" in str(e) + assert "HTTP" not in str(e) + else: + assert False + + raw = tutils.treader("GET / HTTP/1.1" + " " * 100) + msg = socks.ClientGreeting.from_file(raw) + try: + msg.assert_socks5() + except socks.SocksError as e: + assert "Invalid SOCKS version" in str(e) + assert "HTTP" in str(e) + else: + assert False + + def test_server_greeting(): raw = tutils.treader("\x05\x02") out = StringIO() msg = socks.ServerGreeting.from_file(raw) + msg.assert_socks5() msg.to_file(out) assert out.getvalue() == raw.getvalue() @@ -29,10 +57,33 @@ def test_server_greeting(): assert msg.method == 0x02 +def test_server_greeting_assert_socks5(): + raw = tutils.treader("HTTP/1.1 200 OK" + " " * 100) + msg = socks.ServerGreeting.from_file(raw) + try: + msg.assert_socks5() + except socks.SocksError as e: + assert "Invalid SOCKS version" in str(e) + assert "HTTP" in str(e) + else: + assert False + + raw = tutils.treader("GET / HTTP/1.1" + " " * 100) + msg = socks.ServerGreeting.from_file(raw) + try: + msg.assert_socks5() + except socks.SocksError as e: + assert "Invalid SOCKS version" in str(e) + assert "HTTP" not in str(e) + else: + assert False + + def test_message(): raw = tutils.treader("\x05\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") out = StringIO() msg = socks.Message.from_file(raw) + msg.assert_socks5() assert raw.read(2) == "\xBE\xEF" msg.to_file(out) @@ -43,6 +94,12 @@ def test_message(): assert msg.addr == ("example.com", 0xDEAD) +def test_message_assert_socks5(): + raw = tutils.treader("\xEE\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") + msg = socks.Message.from_file(raw) + tutils.raises(socks.SocksError, msg.assert_socks5) + + def test_message_ipv4(): # Test ATYP=0x01 (IPV4) raw = tutils.treader("\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") -- cgit v1.2.3 From 880c66fe48c5a6bb4779a8149a3551f007ff5b09 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 3 Jul 2015 02:45:12 +0200 Subject: socks: optionally fail early --- netlib/socks.py | 15 ++++++++++----- test/test_socks.py | 3 +++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/netlib/socks.py b/netlib/socks.py index eef98f5c..d38b88c8 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -69,11 +69,16 @@ class ClientGreeting(object): ) @classmethod - def from_file(cls, f): + def from_file(cls, f, fail_early=False): + """ + :param fail_early: If true, a SocksError will be raised if the first byte does not indicate socks5. + """ ver, nmethods = struct.unpack("!BB", f.safe_read(2)) - methods = array.array("B") - methods.fromstring(f.safe_read(nmethods)) - return cls(ver, methods.tolist()) + client_greeting = cls(ver, []) + if fail_early: + client_greeting.assert_socks5() + client_greeting.methods.fromstring(f.safe_read(nmethods)) + return client_greeting def to_file(self, f): f.write(struct.pack("!BB", self.ver, len(self.methods))) @@ -115,7 +120,7 @@ class Message(object): self.ver = ver self.msg = msg self.atyp = atyp - self.addr = addr + self.addr = tcp.Address.wrap(addr) def assert_socks5(self): if self.ver != VERSION.SOCKS5: diff --git a/test/test_socks.py b/test/test_socks.py index eb5d55f9..1b6c2a32 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -44,6 +44,9 @@ def test_client_greeting_assert_socks5(): else: assert False + raw = tutils.treader("XX") + tutils.raises(socks.SocksError, socks.ClientGreeting.from_file, raw, fail_early=True) + def test_server_greeting(): raw = tutils.treader("\x05\x02") -- cgit v1.2.3 From 397b3bba5e718da8fca7131d5e1823c4ce5363ca Mon Sep 17 00:00:00 2001 From: "M. Utku Altinkaya" Date: Tue, 21 Jul 2015 13:17:46 +0300 Subject: Fixed version error formatting issue --- netlib/version_check.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/version_check.py b/netlib/version_check.py index df1612a2..2081c410 100644 --- a/netlib/version_check.py +++ b/netlib/version_check.py @@ -34,7 +34,7 @@ def version_check( print( "You are using an outdated version of pyOpenSSL:" " mitmproxy requires pyOpenSSL %s or greater." % - pyopenssl_min_version, + str(pyopenssl_min_version), file=fp ) # Some users apparently have multiple versions of pyOpenSSL installed. -- cgit v1.2.3 From 9fdc412fa043072f44eddec0b07659c161e4ca90 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 22 Jul 2015 00:17:05 +0200 Subject: bump version --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/version.py b/netlib/version.py index bc9a1a57..ba426d74 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 12, 2) +IVERSION = (0, 13) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From 155bdeb12352065bc36256ba8014003480361a0c Mon Sep 17 00:00:00 2001 From: Kyle Morton Date: Tue, 21 Jul 2015 18:01:51 -0700 Subject: Fixing default CA which ought to be read as a pemfile and not a directory --- netlib/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 38b77c9e..47ce8c0e 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -422,7 +422,7 @@ class _Connection(object): context.set_verify(verify_options, verify_cert) if ca_path is None and ca_pemfile is None: - ca_path = certifi.where() + ca_pemfile = certifi.where() context.load_verify_locations(ca_pemfile, ca_path) # Workaround for -- cgit v1.2.3 From c17af4162b5a2946c4bf53bf1d17fca41dc68da7 Mon Sep 17 00:00:00 2001 From: Kyle Morton Date: Tue, 21 Jul 2015 19:06:20 -0700 Subject: Added a fix for pre-1.0 OpenSSL which wasn't correctly erring on failed certificate validation --- netlib/tcp.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/netlib/tcp.py b/netlib/tcp.py index 47ce8c0e..5c4094d7 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -518,6 +518,13 @@ class TCPClient(_Connection): self.connection.do_handshake() except SSL.Error as v: raise NetLibError("SSL handshake error: %s" % repr(v)) + + # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on + # certificate validation failure + verification_mode = sslctx_kwargs.get('verify_options', None) + if self.ssl_verification_error is not None and verification_mode == SSL.VERIFY_PEER: + raise NetLibError("SSL handshake error: certificate verify failed") + self.ssl_established = True self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) self.rfile.set_descriptor(self.connection) -- cgit v1.2.3 From e316a9cdb44444667e26938f8c1c3969e56c2f0e Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 22 Jul 2015 13:39:48 +0200 Subject: bump version --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/version.py b/netlib/version.py index ba426d74..de42ace1 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 13) +IVERSION = (0, 13, 1) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From 6dcfc35011208f4bfde7f37a63d7b980f6c41ce0 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 8 Jul 2015 09:20:25 +0200 Subject: introduce http_semantics module used for generic HTTP representation everything should apply for HTTP/1 and HTTP/2 --- netlib/http.py | 16 ++-------------- netlib/http_semantics.py | 23 +++++++++++++++++++++++ test/test_http.py | 14 +++++++------- 3 files changed, 32 insertions(+), 21 deletions(-) create mode 100644 netlib/http_semantics.py diff --git a/netlib/http.py b/netlib/http.py index a2af9e49..073e9a3f 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -4,7 +4,7 @@ import string import urlparse import binascii import sys -from . import odict, utils, tcp, http_status +from . import odict, utils, tcp, http_semantics, http_status class HttpError(Exception): @@ -527,18 +527,6 @@ def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): ) -Response = collections.namedtuple( - "Response", - [ - "httpversion", - "code", - "msg", - "headers", - "content" - ] -) - - def read_response(rfile, request_method, body_size_limit, include_body=True): """ Return an (httpversion, code, msg, headers, content) tuple. @@ -580,7 +568,7 @@ def read_response(rfile, request_method, body_size_limit, include_body=True): # if include_body==False then a None content means the body should be # read separately content = None - return Response(httpversion, code, msg, headers, content) + return http_semantics.Response(httpversion, code, msg, headers, content) def request_preamble(method, resource, http_major="1", http_minor="1"): diff --git a/netlib/http_semantics.py b/netlib/http_semantics.py new file mode 100644 index 00000000..e8313e3c --- /dev/null +++ b/netlib/http_semantics.py @@ -0,0 +1,23 @@ +class Response(object): + + def __init__( + self, + httpversion, + status_code, + msg, + headers, + content, + sslinfo=None, + ): + self.httpversion = httpversion + self.status_code = status_code + self.msg = msg + self.headers = headers + self.content = content + self.sslinfo = sslinfo + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return "Response(%s - %s)" % (self.status_code, self.msg) diff --git a/test/test_http.py b/test/test_http.py index 2ad81d24..bbc78847 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -1,7 +1,7 @@ import cStringIO import textwrap import binascii -from netlib import http, odict, tcp +from netlib import http, http_semantics, odict, tcp from . import tutils, tservers @@ -307,13 +307,13 @@ def test_read_response(): data = """ HTTP/1.1 200 OK """ - assert tst(data, "GET", None) == ( + assert tst(data, "GET", None) == http_semantics.Response( (1, 1), 200, 'OK', odict.ODictCaseless(), '' ) data = """ HTTP/1.1 200 """ - assert tst(data, "GET", None) == ( + assert tst(data, "GET", None) == http_semantics.Response( (1, 1), 200, '', odict.ODictCaseless(), '' ) data = """ @@ -330,7 +330,7 @@ def test_read_response(): HTTP/1.1 200 OK """ - assert tst(data, "GET", None) == ( + assert tst(data, "GET", None) == http_semantics.Response( (1, 1), 100, 'CONTINUE', odict.ODictCaseless(), '' ) @@ -340,8 +340,8 @@ def test_read_response(): foo """ - assert tst(data, "GET", None)[4] == 'foo' - assert tst(data, "HEAD", None)[4] == '' + assert tst(data, "GET", None).content == 'foo' + assert tst(data, "HEAD", None).content == '' data = """ HTTP/1.1 200 OK @@ -357,7 +357,7 @@ def test_read_response(): foo """ - assert tst(data, "GET", None, include_body=False)[4] is None + assert tst(data, "GET", None, include_body=False).content is None def test_parse_url(): -- cgit v1.2.3 From bd5ee212840e3be731ea93e14ef1375745383d88 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 8 Jul 2015 09:34:10 +0200 Subject: refactor websockets into protocol --- netlib/websockets.py | 381 ------------------------------------------ netlib/websockets/__init__.py | 2 + netlib/websockets/frame.py | 288 +++++++++++++++++++++++++++++++ netlib/websockets/protocol.py | 111 ++++++++++++ test/test_websockets.py | 31 ++-- 5 files changed, 419 insertions(+), 394 deletions(-) delete mode 100644 netlib/websockets.py create mode 100644 netlib/websockets/__init__.py create mode 100644 netlib/websockets/frame.py create mode 100644 netlib/websockets/protocol.py diff --git a/netlib/websockets.py b/netlib/websockets.py deleted file mode 100644 index c45db4df..00000000 --- a/netlib/websockets.py +++ /dev/null @@ -1,381 +0,0 @@ -from __future__ import absolute_import -import base64 -import hashlib -import os -import struct -import io - -from . import utils, odict, tcp - -# Colleciton of utility functions that implement small portions of the RFC6455 -# WebSockets Protocol Useful for building WebSocket clients and servers. -# -# Emphassis is on readabilty, simplicity and modularity, not performance or -# completeness -# -# This is a work in progress and does not yet contain all the utilites need to -# create fully complient client/servers # -# Spec: https://tools.ietf.org/html/rfc6455 - -# The magic sha that websocket servers must know to prove they understand -# RFC6455 -websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' -VERSION = "13" -MAX_16_BIT_INT = (1 << 16) -MAX_64_BIT_INT = (1 << 64) - - -OPCODE = utils.BiDi( - CONTINUE=0x00, - TEXT=0x01, - BINARY=0x02, - CLOSE=0x08, - PING=0x09, - PONG=0x0a -) - - -class Masker(object): - - """ - Data sent from the server must be masked to prevent malicious clients - from sending data over the wire in predictable patterns - - Servers do not have to mask data they send to the client. - https://tools.ietf.org/html/rfc6455#section-5.3 - """ - - def __init__(self, key): - self.key = key - self.masks = [utils.bytes_to_int(byte) for byte in key] - self.offset = 0 - - def mask(self, offset, data): - result = "" - for c in data: - result += chr(ord(c) ^ self.masks[offset % 4]) - offset += 1 - return result - - def __call__(self, data): - ret = self.mask(self.offset, data) - self.offset += len(ret) - return ret - - -def client_handshake_headers(key=None, version=VERSION): - """ - Create the headers for a valid HTTP upgrade request. If Key is not - specified, it is generated, and can be found in sec-websocket-key in - the returned header set. - - Returns an instance of ODictCaseless - """ - if not key: - key = base64.b64encode(os.urandom(16)).decode('utf-8') - return odict.ODictCaseless([ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - ('Sec-WebSocket-Key', key), - ('Sec-WebSocket-Version', version) - ]) - - -def server_handshake_headers(key): - """ - The server response is a valid HTTP 101 response. - """ - return odict.ODictCaseless( - [ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - ('Sec-WebSocket-Accept', create_server_nonce(key)) - ] - ) - - -def make_length_code(length): - """ - A websockets frame contains an initial length_code, and an optional - extended length code to represent the actual length if length code is - larger than 125 - """ - if length <= 125: - return length - elif length >= 126 and length <= 65535: - return 126 - else: - return 127 - - -def check_client_handshake(headers): - if headers.get_first("upgrade", None) != "websocket": - return - return headers.get_first('sec-websocket-key') - - -def check_server_handshake(headers): - if headers.get_first("upgrade", None) != "websocket": - return - return headers.get_first('sec-websocket-accept') - - -def create_server_nonce(client_nonce): - return base64.b64encode( - hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') - ) - - -DEFAULT = object() - - -class FrameHeader(object): - - def __init__( - self, - opcode=OPCODE.TEXT, - payload_length=0, - fin=False, - rsv1=False, - rsv2=False, - rsv3=False, - masking_key=DEFAULT, - mask=DEFAULT, - length_code=DEFAULT - ): - if not 0 <= opcode < 2 ** 4: - raise ValueError("opcode must be 0-16") - self.opcode = opcode - self.payload_length = payload_length - self.fin = fin - self.rsv1 = rsv1 - self.rsv2 = rsv2 - self.rsv3 = rsv3 - - if length_code is DEFAULT: - self.length_code = make_length_code(self.payload_length) - else: - self.length_code = length_code - - if mask is DEFAULT and masking_key is DEFAULT: - self.mask = False - self.masking_key = "" - elif mask is DEFAULT: - self.mask = 1 - self.masking_key = masking_key - elif masking_key is DEFAULT: - self.mask = mask - self.masking_key = os.urandom(4) - else: - self.mask = mask - self.masking_key = masking_key - - if self.masking_key and len(self.masking_key) != 4: - raise ValueError("Masking key must be 4 bytes.") - - def human_readable(self): - vals = [ - "ws frame:", - OPCODE.get_name(self.opcode, hex(self.opcode)).lower() - ] - flags = [] - for i in ["fin", "rsv1", "rsv2", "rsv3", "mask"]: - if getattr(self, i): - flags.append(i) - if flags: - vals.extend([":", "|".join(flags)]) - if self.masking_key: - vals.append(":key=%s" % repr(self.masking_key)) - if self.payload_length: - vals.append(" %s" % utils.pretty_size(self.payload_length)) - return "".join(vals) - - def to_bytes(self): - first_byte = utils.setbit(0, 7, self.fin) - first_byte = utils.setbit(first_byte, 6, self.rsv1) - first_byte = utils.setbit(first_byte, 5, self.rsv2) - first_byte = utils.setbit(first_byte, 4, self.rsv3) - first_byte = first_byte | self.opcode - - second_byte = utils.setbit(self.length_code, 7, self.mask) - - b = chr(first_byte) + chr(second_byte) - - if self.payload_length < 126: - pass - elif self.payload_length < MAX_16_BIT_INT: - # '!H' pack as 16 bit unsigned short - # add 2 byte extended payload length - b += struct.pack('!H', self.payload_length) - elif self.payload_length < MAX_64_BIT_INT: - # '!Q' = pack as 64 bit unsigned long long - # add 8 bytes extended payload length - b += struct.pack('!Q', self.payload_length) - if self.masking_key is not None: - b += self.masking_key - return b - - @classmethod - def from_file(cls, fp): - """ - read a websockets frame header - """ - first_byte = utils.bytes_to_int(fp.safe_read(1)) - second_byte = utils.bytes_to_int(fp.safe_read(1)) - - fin = utils.getbit(first_byte, 7) - rsv1 = utils.getbit(first_byte, 6) - rsv2 = utils.getbit(first_byte, 5) - rsv3 = utils.getbit(first_byte, 4) - # grab right-most 4 bits - opcode = first_byte & 15 - mask_bit = utils.getbit(second_byte, 7) - # grab the next 7 bits - length_code = second_byte & 127 - - # payload_lengthy > 125 indicates you need to read more bytes - # to get the actual payload length - if length_code <= 125: - payload_length = length_code - elif length_code == 126: - payload_length = utils.bytes_to_int(fp.safe_read(2)) - elif length_code == 127: - payload_length = utils.bytes_to_int(fp.safe_read(8)) - - # masking key only present if mask bit set - if mask_bit == 1: - masking_key = fp.safe_read(4) - else: - masking_key = None - - return cls( - fin=fin, - rsv1=rsv1, - rsv2=rsv2, - rsv3=rsv3, - opcode=opcode, - mask=mask_bit, - length_code=length_code, - payload_length=payload_length, - masking_key=masking_key, - ) - - def __eq__(self, other): - return self.to_bytes() == other.to_bytes() - - -class Frame(object): - - """ - Represents one websockets frame. - Constructor takes human readable forms of the frame components - from_bytes() is also avaliable. - - WebSockets Frame as defined in RFC6455 - - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-------+-+-------------+-------------------------------+ - |F|R|R|R| opcode|M| Payload len | Extended payload length | - |I|S|S|S| (4) |A| (7) | (16/64) | - |N|V|V|V| |S| | (if payload len==126/127) | - | |1|2|3| |K| | | - +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + - | Extended payload length continued, if payload len == 127 | - + - - - - - - - - - - - - - - - +-------------------------------+ - | |Masking-key, if MASK set to 1 | - +-------------------------------+-------------------------------+ - | Masking-key (continued) | Payload Data | - +-------------------------------- - - - - - - - - - - - - - - - + - : Payload Data continued ... : - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + - | Payload Data continued ... | - +---------------------------------------------------------------+ - """ - - def __init__(self, payload="", **kwargs): - self.payload = payload - kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) - self.header = FrameHeader(**kwargs) - - @classmethod - def default(cls, message, from_client=False): - """ - Construct a basic websocket frame from some default values. - Creates a non-fragmented text frame. - """ - if from_client: - mask_bit = 1 - masking_key = os.urandom(4) - else: - mask_bit = 0 - masking_key = None - - return cls( - message, - fin=1, # final frame - opcode=OPCODE.TEXT, # text - mask=mask_bit, - masking_key=masking_key, - ) - - @classmethod - def from_bytes(cls, bytestring): - """ - Construct a websocket frame from an in-memory bytestring - to construct a frame from a stream of bytes, use from_file() directly - """ - return cls.from_file(tcp.Reader(io.BytesIO(bytestring))) - - def human_readable(self): - ret = self.header.human_readable() - if self.payload: - ret = ret + "\nPayload:\n" + utils.cleanBin(self.payload) - return ret - - def __repr__(self): - return self.header.human_readable() - - def to_bytes(self): - """ - Serialize the frame to wire format. Returns a string. - """ - b = self.header.to_bytes() - if self.header.masking_key: - b += Masker(self.header.masking_key)(self.payload) - else: - b += self.payload - return b - - def to_file(self, writer): - writer.write(self.to_bytes()) - writer.flush() - - @classmethod - def from_file(cls, fp): - """ - read a websockets frame sent by a server or client - - fp is a "file like" object that could be backed by a network - stream or a disk or an in memory stream reader - """ - header = FrameHeader.from_file(fp) - payload = fp.safe_read(header.payload_length) - - if header.mask == 1 and header.masking_key: - payload = Masker(header.masking_key)(payload) - - return cls( - payload, - fin=header.fin, - opcode=header.opcode, - mask=header.mask, - payload_length=header.payload_length, - masking_key=header.masking_key, - rsv1=header.rsv1, - rsv2=header.rsv2, - rsv3=header.rsv3, - length_code=header.length_code - ) - - def __eq__(self, other): - return self.to_bytes() == other.to_bytes() diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py new file mode 100644 index 00000000..5acf7696 --- /dev/null +++ b/netlib/websockets/__init__.py @@ -0,0 +1,2 @@ +from frame import * +from protocol import * diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py new file mode 100644 index 00000000..d41059fa --- /dev/null +++ b/netlib/websockets/frame.py @@ -0,0 +1,288 @@ +from __future__ import absolute_import +import base64 +import hashlib +import os +import struct +import io + +from .protocol import Masker +from .. import utils, odict, tcp + +DEFAULT = object() + +MAX_16_BIT_INT = (1 << 16) +MAX_64_BIT_INT = (1 << 64) + +OPCODE = utils.BiDi( + CONTINUE=0x00, + TEXT=0x01, + BINARY=0x02, + CLOSE=0x08, + PING=0x09, + PONG=0x0a +) + +class FrameHeader(object): + + def __init__( + self, + opcode=OPCODE.TEXT, + payload_length=0, + fin=False, + rsv1=False, + rsv2=False, + rsv3=False, + masking_key=DEFAULT, + mask=DEFAULT, + length_code=DEFAULT + ): + if not 0 <= opcode < 2 ** 4: + raise ValueError("opcode must be 0-16") + self.opcode = opcode + self.payload_length = payload_length + self.fin = fin + self.rsv1 = rsv1 + self.rsv2 = rsv2 + self.rsv3 = rsv3 + + if length_code is DEFAULT: + self.length_code = self._make_length_code(self.payload_length) + else: + self.length_code = length_code + + if mask is DEFAULT and masking_key is DEFAULT: + self.mask = False + self.masking_key = "" + elif mask is DEFAULT: + self.mask = 1 + self.masking_key = masking_key + elif masking_key is DEFAULT: + self.mask = mask + self.masking_key = os.urandom(4) + else: + self.mask = mask + self.masking_key = masking_key + + if self.masking_key and len(self.masking_key) != 4: + raise ValueError("Masking key must be 4 bytes.") + + @classmethod + def _make_length_code(self, length): + """ + A websockets frame contains an initial length_code, and an optional + extended length code to represent the actual length if length code is + larger than 125 + """ + if length <= 125: + return length + elif length >= 126 and length <= 65535: + return 126 + else: + return 127 + + def human_readable(self): + vals = [ + "ws frame:", + OPCODE.get_name(self.opcode, hex(self.opcode)).lower() + ] + flags = [] + for i in ["fin", "rsv1", "rsv2", "rsv3", "mask"]: + if getattr(self, i): + flags.append(i) + if flags: + vals.extend([":", "|".join(flags)]) + if self.masking_key: + vals.append(":key=%s" % repr(self.masking_key)) + if self.payload_length: + vals.append(" %s" % utils.pretty_size(self.payload_length)) + return "".join(vals) + + def to_bytes(self): + first_byte = utils.setbit(0, 7, self.fin) + first_byte = utils.setbit(first_byte, 6, self.rsv1) + first_byte = utils.setbit(first_byte, 5, self.rsv2) + first_byte = utils.setbit(first_byte, 4, self.rsv3) + first_byte = first_byte | self.opcode + + second_byte = utils.setbit(self.length_code, 7, self.mask) + + b = chr(first_byte) + chr(second_byte) + + if self.payload_length < 126: + pass + elif self.payload_length < MAX_16_BIT_INT: + # '!H' pack as 16 bit unsigned short + # add 2 byte extended payload length + b += struct.pack('!H', self.payload_length) + elif self.payload_length < MAX_64_BIT_INT: + # '!Q' = pack as 64 bit unsigned long long + # add 8 bytes extended payload length + b += struct.pack('!Q', self.payload_length) + if self.masking_key is not None: + b += self.masking_key + return b + + @classmethod + def from_file(cls, fp): + """ + read a websockets frame header + """ + first_byte = utils.bytes_to_int(fp.safe_read(1)) + second_byte = utils.bytes_to_int(fp.safe_read(1)) + + fin = utils.getbit(first_byte, 7) + rsv1 = utils.getbit(first_byte, 6) + rsv2 = utils.getbit(first_byte, 5) + rsv3 = utils.getbit(first_byte, 4) + # grab right-most 4 bits + opcode = first_byte & 15 + mask_bit = utils.getbit(second_byte, 7) + # grab the next 7 bits + length_code = second_byte & 127 + + # payload_lengthy > 125 indicates you need to read more bytes + # to get the actual payload length + if length_code <= 125: + payload_length = length_code + elif length_code == 126: + payload_length = utils.bytes_to_int(fp.safe_read(2)) + elif length_code == 127: + payload_length = utils.bytes_to_int(fp.safe_read(8)) + + # masking key only present if mask bit set + if mask_bit == 1: + masking_key = fp.safe_read(4) + else: + masking_key = None + + return cls( + fin=fin, + rsv1=rsv1, + rsv2=rsv2, + rsv3=rsv3, + opcode=opcode, + mask=mask_bit, + length_code=length_code, + payload_length=payload_length, + masking_key=masking_key, + ) + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() + + +class Frame(object): + + """ + Represents one websockets frame. + Constructor takes human readable forms of the frame components + from_bytes() is also avaliable. + + WebSockets Frame as defined in RFC6455 + + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-------+-+-------------+-------------------------------+ + |F|R|R|R| opcode|M| Payload len | Extended payload length | + |I|S|S|S| (4) |A| (7) | (16/64) | + |N|V|V|V| |S| | (if payload len==126/127) | + | |1|2|3| |K| | | + +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + | Extended payload length continued, if payload len == 127 | + + - - - - - - - - - - - - - - - +-------------------------------+ + | |Masking-key, if MASK set to 1 | + +-------------------------------+-------------------------------+ + | Masking-key (continued) | Payload Data | + +-------------------------------- - - - - - - - - - - - - - - - + + : Payload Data continued ... : + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + | Payload Data continued ... | + +---------------------------------------------------------------+ + """ + + def __init__(self, payload="", **kwargs): + self.payload = payload + kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) + self.header = FrameHeader(**kwargs) + + @classmethod + def default(cls, message, from_client=False): + """ + Construct a basic websocket frame from some default values. + Creates a non-fragmented text frame. + """ + if from_client: + mask_bit = 1 + masking_key = os.urandom(4) + else: + mask_bit = 0 + masking_key = None + + return cls( + message, + fin=1, # final frame + opcode=OPCODE.TEXT, # text + mask=mask_bit, + masking_key=masking_key, + ) + + @classmethod + def from_bytes(cls, bytestring): + """ + Construct a websocket frame from an in-memory bytestring + to construct a frame from a stream of bytes, use from_file() directly + """ + return cls.from_file(tcp.Reader(io.BytesIO(bytestring))) + + def human_readable(self): + ret = self.header.human_readable() + if self.payload: + ret = ret + "\nPayload:\n" + utils.cleanBin(self.payload) + return ret + + def __repr__(self): + return self.header.human_readable() + + def to_bytes(self): + """ + Serialize the frame to wire format. Returns a string. + """ + b = self.header.to_bytes() + if self.header.masking_key: + b += Masker(self.header.masking_key)(self.payload) + else: + b += self.payload + return b + + def to_file(self, writer): + writer.write(self.to_bytes()) + writer.flush() + + @classmethod + def from_file(cls, fp): + """ + read a websockets frame sent by a server or client + + fp is a "file like" object that could be backed by a network + stream or a disk or an in memory stream reader + """ + header = FrameHeader.from_file(fp) + payload = fp.safe_read(header.payload_length) + + if header.mask == 1 and header.masking_key: + payload = Masker(header.masking_key)(payload) + + return cls( + payload, + fin=header.fin, + opcode=header.opcode, + mask=header.mask, + payload_length=header.payload_length, + masking_key=header.masking_key, + rsv1=header.rsv1, + rsv2=header.rsv2, + rsv3=header.rsv3, + length_code=header.length_code + ) + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py new file mode 100644 index 00000000..dcab53fb --- /dev/null +++ b/netlib/websockets/protocol.py @@ -0,0 +1,111 @@ +from __future__ import absolute_import +import base64 +import hashlib +import os +import struct +import io + +from .. import utils, odict, tcp + +# Colleciton of utility functions that implement small portions of the RFC6455 +# WebSockets Protocol Useful for building WebSocket clients and servers. +# +# Emphassis is on readabilty, simplicity and modularity, not performance or +# completeness +# +# This is a work in progress and does not yet contain all the utilites need to +# create fully complient client/servers # +# Spec: https://tools.ietf.org/html/rfc6455 + +# The magic sha that websocket servers must know to prove they understand +# RFC6455 +websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +VERSION = "13" + +HEADER_WEBSOCKET_KEY = 'sec-websocket-key' +HEADER_WEBSOCKET_ACCEPT = 'sec-websocket-accept' +HEADER_WEBSOCKET_VERSION = 'sec-websocket-version' + +class Masker(object): + + """ + Data sent from the server must be masked to prevent malicious clients + from sending data over the wire in predictable patterns + + Servers do not have to mask data they send to the client. + https://tools.ietf.org/html/rfc6455#section-5.3 + """ + + def __init__(self, key): + self.key = key + self.masks = [utils.bytes_to_int(byte) for byte in key] + self.offset = 0 + + def mask(self, offset, data): + result = "" + for c in data: + result += chr(ord(c) ^ self.masks[offset % 4]) + offset += 1 + return result + + def __call__(self, data): + ret = self.mask(self.offset, data) + self.offset += len(ret) + return ret + +class WebsocketsProtocol(object): + + def __init__(self): + pass + + @classmethod + def client_handshake_headers(self, key=None, version=VERSION): + """ + Create the headers for a valid HTTP upgrade request. If Key is not + specified, it is generated, and can be found in sec-websocket-key in + the returned header set. + + Returns an instance of ODictCaseless + """ + if not key: + key = base64.b64encode(os.urandom(16)).decode('utf-8') + return odict.ODictCaseless([ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + (HEADER_WEBSOCKET_KEY, key), + (HEADER_WEBSOCKET_VERSION, version) + ]) + + @classmethod + def server_handshake_headers(self, key): + """ + The server response is a valid HTTP 101 response. + """ + return odict.ODictCaseless( + [ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + (HEADER_WEBSOCKET_ACCEPT, self.create_server_nonce(key)) + ] + ) + + + @classmethod + def check_client_handshake(self, headers): + if headers.get_first("upgrade", None) != "websocket": + return + return headers.get_first(HEADER_WEBSOCKET_KEY) + + + @classmethod + def check_server_handshake(self, headers): + if headers.get_first("upgrade", None) != "websocket": + return + return headers.get_first(HEADER_WEBSOCKET_ACCEPT) + + + @classmethod + def create_server_nonce(self, client_nonce): + return base64.b64encode( + hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') + ) diff --git a/test/test_websockets.py b/test/test_websockets.py index 9956543b..ae0a5e33 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -12,6 +12,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): super(WebSocketsEchoHandler, self).__init__( connection, address, server ) + self.protocol = websockets.WebsocketsProtocol() self.handshake_done = False def handle(self): @@ -31,10 +32,10 @@ class WebSocketsEchoHandler(tcp.BaseHandler): def handshake(self): req = http.read_request(self.rfile) - key = websockets.check_client_handshake(req.headers) + key = self.protocol.check_client_handshake(req.headers) self.wfile.write(http.response_preamble(101) + "\r\n") - headers = websockets.server_handshake_headers(key) + headers = self.protocol.server_handshake_headers(key) self.wfile.write(headers.format() + "\r\n") self.wfile.flush() self.handshake_done = True @@ -48,6 +49,7 @@ class WebSocketsClient(tcp.TCPClient): def __init__(self, address, source_address=None): super(WebSocketsClient, self).__init__(address, source_address) + self.protocol = websockets.WebsocketsProtocol() self.client_nonce = None def connect(self): @@ -55,15 +57,15 @@ class WebSocketsClient(tcp.TCPClient): preamble = http.request_preamble("GET", "/") self.wfile.write(preamble + "\r\n") - headers = websockets.client_handshake_headers() + headers = self.protocol.client_handshake_headers() self.client_nonce = headers.get_first("sec-websocket-key") self.wfile.write(headers.format() + "\r\n") self.wfile.flush() resp = http.read_response(self.rfile, "get", None) - server_nonce = websockets.check_server_handshake(resp.headers) + server_nonce = self.protocol.check_server_handshake(resp.headers) - if not server_nonce == websockets.create_server_nonce( + if not server_nonce == self.protocol.create_server_nonce( self.client_nonce): self.close() @@ -78,6 +80,9 @@ class WebSocketsClient(tcp.TCPClient): class TestWebSockets(tservers.ServerTestBase): handler = WebSocketsEchoHandler + def __init__(self): + self.protocol = websockets.WebsocketsProtocol() + def random_bytes(self, n=100): return os.urandom(n) @@ -130,26 +135,26 @@ class TestWebSockets(tservers.ServerTestBase): assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes def test_check_server_handshake(self): - headers = websockets.server_handshake_headers("key") - assert websockets.check_server_handshake(headers) + headers = self.protocol.server_handshake_headers("key") + assert self.protocol.check_server_handshake(headers) headers["Upgrade"] = ["not_websocket"] - assert not websockets.check_server_handshake(headers) + assert not self.protocol.check_server_handshake(headers) def test_check_client_handshake(self): - headers = websockets.client_handshake_headers("key") - assert websockets.check_client_handshake(headers) == "key" + headers = self.protocol.client_handshake_headers("key") + assert self.protocol.check_client_handshake(headers) == "key" headers["Upgrade"] = ["not_websocket"] - assert not websockets.check_client_handshake(headers) + assert not self.protocol.check_client_handshake(headers) class BadHandshakeHandler(WebSocketsEchoHandler): def handshake(self): client_hs = http.read_request(self.rfile) - websockets.check_client_handshake(client_hs.headers) + self.protocol.check_client_handshake(client_hs.headers) self.wfile.write(http.response_preamble(101) + "\r\n") - headers = websockets.server_handshake_headers("malformed key") + headers = self.protocol.server_handshake_headers("malformed key") self.wfile.write(headers.format() + "\r\n") self.wfile.flush() self.handshake_done = True -- cgit v1.2.3 From f50deb7b763d093a22a4d331e16465a2fb0329cf Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Tue, 14 Jul 2015 23:02:14 +0200 Subject: move bits around --- netlib/http.py | 583 ------------------------------ netlib/http/__init__.py | 2 + netlib/http/authentication.py | 149 ++++++++ netlib/http/cookies.py | 193 ++++++++++ netlib/http/exceptions.py | 9 + netlib/http/http1/__init__.py | 1 + netlib/http/http1/protocol.py | 518 +++++++++++++++++++++++++++ netlib/http/http2/__init__.py | 2 + netlib/http/http2/frame.py | 636 +++++++++++++++++++++++++++++++++ netlib/http/http2/protocol.py | 240 +++++++++++++ netlib/http/semantics.py | 94 +++++ netlib/http/status_codes.py | 104 ++++++ netlib/http/user_agents.py | 52 +++ netlib/http2/__init__.py | 2 - netlib/http2/frame.py | 636 --------------------------------- netlib/http2/protocol.py | 240 ------------- netlib/http_auth.py | 148 -------- netlib/http_cookies.py | 193 ---------- netlib/http_semantics.py | 23 -- netlib/http_status.py | 104 ------ netlib/http_uastrings.py | 52 --- netlib/websockets/frame.py | 2 +- netlib/websockets/protocol.py | 2 +- test/http/__init__.py | 0 test/http/http1/__init__.py | 0 test/http/http1/test_protocol.py | 445 +++++++++++++++++++++++ test/http/http2/__init__.py | 0 test/http/http2/test_frames.py | 704 +++++++++++++++++++++++++++++++++++++ test/http/http2/test_protocol.py | 325 +++++++++++++++++ test/http/test_authentication.py | 110 ++++++ test/http/test_cookies.py | 219 ++++++++++++ test/http/test_semantics.py | 54 +++ test/http/test_user_agents.py | 6 + test/http2/__init__.py | 0 test/http2/test_frames.py | 704 ------------------------------------- test/http2/test_protocol.py | 326 ----------------- test/test_http.py | 491 -------------------------- test/test_http_auth.py | 109 ------ test/test_http_cookies.py | 219 ------------ test/test_http_uastrings.py | 6 - test/test_websockets.py | 261 -------------- test/websockets/__init__.py | 0 test/websockets/test_websockets.py | 262 ++++++++++++++ 43 files changed, 4127 insertions(+), 4099 deletions(-) delete mode 100644 netlib/http.py create mode 100644 netlib/http/__init__.py create mode 100644 netlib/http/authentication.py create mode 100644 netlib/http/cookies.py create mode 100644 netlib/http/exceptions.py create mode 100644 netlib/http/http1/__init__.py create mode 100644 netlib/http/http1/protocol.py create mode 100644 netlib/http/http2/__init__.py create mode 100644 netlib/http/http2/frame.py create mode 100644 netlib/http/http2/protocol.py create mode 100644 netlib/http/semantics.py create mode 100644 netlib/http/status_codes.py create mode 100644 netlib/http/user_agents.py delete mode 100644 netlib/http2/__init__.py delete mode 100644 netlib/http2/frame.py delete mode 100644 netlib/http2/protocol.py delete mode 100644 netlib/http_auth.py delete mode 100644 netlib/http_cookies.py delete mode 100644 netlib/http_semantics.py delete mode 100644 netlib/http_status.py delete mode 100644 netlib/http_uastrings.py create mode 100644 test/http/__init__.py create mode 100644 test/http/http1/__init__.py create mode 100644 test/http/http1/test_protocol.py create mode 100644 test/http/http2/__init__.py create mode 100644 test/http/http2/test_frames.py create mode 100644 test/http/http2/test_protocol.py create mode 100644 test/http/test_authentication.py create mode 100644 test/http/test_cookies.py create mode 100644 test/http/test_semantics.py create mode 100644 test/http/test_user_agents.py delete mode 100644 test/http2/__init__.py delete mode 100644 test/http2/test_frames.py delete mode 100644 test/http2/test_protocol.py delete mode 100644 test/test_http.py delete mode 100644 test/test_http_auth.py delete mode 100644 test/test_http_cookies.py delete mode 100644 test/test_http_uastrings.py delete mode 100644 test/test_websockets.py create mode 100644 test/websockets/__init__.py create mode 100644 test/websockets/test_websockets.py diff --git a/netlib/http.py b/netlib/http.py deleted file mode 100644 index 073e9a3f..00000000 --- a/netlib/http.py +++ /dev/null @@ -1,583 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import collections -import string -import urlparse -import binascii -import sys -from . import odict, utils, tcp, http_semantics, http_status - - -class HttpError(Exception): - - def __init__(self, code, message): - super(HttpError, self).__init__(message) - self.code = code - - -class HttpErrorConnClosed(HttpError): - pass - - -def _is_valid_port(port): - if not 0 <= port <= 65535: - return False - return True - - -def _is_valid_host(host): - try: - host.decode("idna") - except ValueError: - return False - if "\0" in host: - return None - return True - - -def get_request_line(fp): - """ - Get a line, possibly preceded by a blank. - """ - line = fp.readline() - if line == "\r\n" or line == "\n": - # Possible leftover from previous message - line = fp.readline() - return line - - -def parse_url(url): - """ - Returns a (scheme, host, port, path) tuple, or None on error. - - Checks that: - port is an integer 0-65535 - host is a valid IDNA-encoded hostname with no null-bytes - path is valid ASCII - """ - try: - scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) - except ValueError: - return None - if not scheme: - return None - if '@' in netloc: - # FIXME: Consider what to do with the discarded credentials here Most - # probably we should extend the signature to return these as a separate - # value. - _, netloc = string.rsplit(netloc, '@', maxsplit=1) - if ':' in netloc: - host, port = string.rsplit(netloc, ':', maxsplit=1) - try: - port = int(port) - except ValueError: - return None - else: - host = netloc - if scheme == "https": - port = 443 - else: - port = 80 - path = urlparse.urlunparse(('', '', path, params, query, fragment)) - if not path.startswith("/"): - path = "/" + path - if not _is_valid_host(host): - return None - if not utils.isascii(path): - return None - if not _is_valid_port(port): - return None - return scheme, host, port, path - - -def read_headers(fp): - """ - Read a set of headers from a file pointer. Stop once a blank line is - reached. Return a ODictCaseless object, or None if headers are invalid. - """ - ret = [] - name = '' - while True: - line = fp.readline() - if not line or line == '\r\n' or line == '\n': - break - if line[0] in ' \t': - if not ret: - return None - # continued header - ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() - else: - i = line.find(':') - # We're being liberal in what we accept, here. - if i > 0: - name = line[:i] - value = line[i + 1:].strip() - ret.append([name, value]) - else: - return None - return odict.ODictCaseless(ret) - - -def read_chunked(fp, limit, is_request): - """ - Read a chunked HTTP body. - - May raise HttpError. - """ - # FIXME: Should check if chunked is the final encoding in the headers - # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 - # 3.3 2. - total = 0 - code = 400 if is_request else 502 - while True: - line = fp.readline(128) - if line == "": - raise HttpErrorConnClosed(code, "Connection closed prematurely") - if line != '\r\n' and line != '\n': - try: - length = int(line, 16) - except ValueError: - raise HttpError( - code, - "Invalid chunked encoding length: %s" % line - ) - total += length - if limit is not None and total > limit: - msg = "HTTP Body too large. Limit is %s," \ - " chunked content longer than %s" % (limit, total) - raise HttpError(code, msg) - chunk = fp.read(length) - suffix = fp.readline(5) - if suffix != '\r\n': - raise HttpError(code, "Malformed chunked body") - yield line, chunk, '\r\n' - if length == 0: - return - - -def get_header_tokens(headers, key): - """ - Retrieve all tokens for a header key. A number of different headers - follow a pattern where each header line can containe comma-separated - tokens, and headers can be set multiple times. - """ - toks = [] - for i in headers[key]: - for j in i.split(","): - toks.append(j.strip()) - return toks - - -def has_chunked_encoding(headers): - return "chunked" in [ - i.lower() for i in get_header_tokens(headers, "transfer-encoding") - ] - - -def parse_http_protocol(s): - """ - Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or - None. - """ - if not s.startswith("HTTP/"): - return None - _, version = s.split('/', 1) - if "." not in version: - return None - major, minor = version.split('.', 1) - try: - major = int(major) - minor = int(minor) - except ValueError: - return None - return major, minor - - -def parse_http_basic_auth(s): - words = s.split() - if len(words) != 2: - return None - scheme = words[0] - try: - user = binascii.a2b_base64(words[1]) - except binascii.Error: - return None - parts = user.split(':') - if len(parts) != 2: - return None - return scheme, parts[0], parts[1] - - -def assemble_http_basic_auth(scheme, username, password): - v = binascii.b2a_base64(username + ":" + password) - return scheme + " " + v - - -def parse_init(line): - try: - method, url, protocol = string.split(line) - except ValueError: - return None - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None - if not utils.isascii(method): - return None - return method, url, httpversion - - -def parse_init_connect(line): - """ - Returns (host, port, httpversion) if line is a valid CONNECT line. - http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 - """ - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - - if method.upper() != 'CONNECT': - return None - try: - host, port = url.split(":") - except ValueError: - return None - try: - port = int(port) - except ValueError: - return None - if not _is_valid_port(port): - return None - if not _is_valid_host(host): - return None - return host, port, httpversion - - -def parse_init_proxy(line): - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - - parts = parse_url(url) - if not parts: - return None - scheme, host, port, path = parts - return method, scheme, host, port, path, httpversion - - -def parse_init_http(line): - """ - Returns (method, url, httpversion) - """ - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - if not utils.isascii(url): - return None - if not (url.startswith("/") or url == "*"): - return None - return method, url, httpversion - - -def connection_close(httpversion, headers): - """ - Checks the message to see if the client connection should be closed - according to RFC 2616 Section 8.1 Note that a connection should be - closed as well if the response has been read until end of the stream. - """ - # At first, check if we have an explicit Connection header. - if "connection" in headers: - toks = get_header_tokens(headers, "connection") - if "close" in toks: - return True - elif "keep-alive" in toks: - return False - # If we don't have a Connection header, HTTP 1.1 connections are assumed to - # be persistent - if httpversion == (1, 1): - return False - return True - - -def parse_response_line(line): - parts = line.strip().split(" ", 2) - if len(parts) == 2: # handle missing message gracefully - parts.append("") - if len(parts) != 3: - return None - proto, code, msg = parts - try: - code = int(code) - except ValueError: - return None - return (proto, code, msg) - - -def read_http_body(*args, **kwargs): - return "".join( - content for _, content, _ in read_http_body_chunked(*args, **kwargs) - ) - - -def read_http_body_chunked( - rfile, - headers, - limit, - request_method, - response_code, - is_request, - max_chunk_size=None -): - """ - Read an HTTP message body: - - rfile: A file descriptor to read from - headers: An ODictCaseless object - limit: Size limit. - is_request: True if the body to read belongs to a request, False - otherwise - """ - if max_chunk_size is None: - max_chunk_size = limit or sys.maxsize - - expected_size = expected_http_body_size( - headers, is_request, request_method, response_code - ) - - if expected_size is None: - if has_chunked_encoding(headers): - # Python 3: yield from - for x in read_chunked(rfile, limit, is_request): - yield x - else: # pragma: nocover - raise HttpError( - 400 if is_request else 502, - "Content-Length unknown but no chunked encoding" - ) - elif expected_size >= 0: - if limit is not None and expected_size > limit: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s, content-length was %s" % ( - limit, expected_size - ) - ) - bytes_left = expected_size - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - yield "", rfile.read(chunk_size), "" - bytes_left -= chunk_size - else: - bytes_left = limit or -1 - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = rfile.read(chunk_size) - if not content: - return - yield "", content, "" - bytes_left -= chunk_size - not_done = rfile.read(1) - if not_done: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s," % limit - ) - - -def expected_http_body_size(headers, is_request, request_method, response_code): - """ - Returns the expected body length: - - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding or invalid - data) - - -1, if all data should be read until end of stream. - - May raise HttpError. - """ - # Determine response size according to - # http://tools.ietf.org/html/rfc7230#section-3.3 - if request_method: - request_method = request_method.upper() - - if (not is_request and ( - request_method == "HEAD" or - (request_method == "CONNECT" and response_code == 200) or - response_code in [204, 304] or - 100 <= response_code <= 199)): - return 0 - if has_chunked_encoding(headers): - return None - if "content-length" in headers: - try: - size = int(headers["content-length"][0]) - if size < 0: - raise ValueError() - return size - except ValueError: - return None - if is_request: - return 0 - return -1 - - -Request = collections.namedtuple( - "Request", - [ - "form_in", - "method", - "scheme", - "host", - "port", - "path", - "httpversion", - "headers", - "content" - ] -) - - -def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): - """ - Parse an HTTP request from a file stream - - Args: - rfile (file): Input file to read from - include_body (bool): Read response body as well - body_size_limit (bool): Maximum body size - wfile (file): If specified, HTTP Expect headers are handled - automatically, by writing a HTTP 100 CONTINUE response to the stream. - - Returns: - Request: The HTTP request - - Raises: - HttpError: If the input is invalid. - """ - httpversion, host, port, scheme, method, path, headers, content = ( - None, None, None, None, None, None, None, None) - - request_line = get_request_line(rfile) - if not request_line: - raise tcp.NetLibDisconnect() - - request_line_parts = parse_init(request_line) - if not request_line_parts: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - method, path, httpversion = request_line_parts - - if path == '*' or path.startswith("/"): - form_in = "relative" - if not utils.isascii(path): - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - elif method.upper() == 'CONNECT': - form_in = "authority" - r = parse_init_connect(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - host, port, _ = r - path = None - else: - form_in = "absolute" - r = parse_init_proxy(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - _, scheme, host, port, path, _ = r - - headers = read_headers(rfile) - if headers is None: - raise HttpError(400, "Invalid headers") - - expect_header = headers.get_first("expect", "").lower() - if expect_header == "100-continue" and httpversion >= (1, 1): - wfile.write( - 'HTTP/1.1 100 Continue\r\n' - '\r\n' - ) - wfile.flush() - del headers['expect'] - - if include_body: - content = read_http_body( - rfile, headers, body_size_limit, method, None, True - ) - - return Request( - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers, - content - ) - - -def read_response(rfile, request_method, body_size_limit, include_body=True): - """ - Return an (httpversion, code, msg, headers, content) tuple. - - By default, both response header and body are read. - If include_body=False is specified, content may be one of the - following: - - None, if the response is technically allowed to have a response body - - "", if the response must not have a response body (e.g. it's a - response to a HEAD request) - """ - line = rfile.readline() - # Possible leftover from previous message - if line == "\r\n" or line == "\n": - line = rfile.readline() - if not line: - raise HttpErrorConnClosed(502, "Server disconnect.") - parts = parse_response_line(line) - if not parts: - raise HttpError(502, "Invalid server response: %s" % repr(line)) - proto, code, msg = parts - httpversion = parse_http_protocol(proto) - if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) - headers = read_headers(rfile) - if headers is None: - raise HttpError(502, "Invalid headers.") - - if include_body: - content = read_http_body( - rfile, - headers, - body_size_limit, - request_method, - code, - False - ) - else: - # if include_body==False then a None content means the body should be - # read separately - content = None - return http_semantics.Response(httpversion, code, msg, headers, content) - - -def request_preamble(method, resource, http_major="1", http_minor="1"): - return '%s %s HTTP/%s.%s' % ( - method, resource, http_major, http_minor - ) - - -def response_preamble(code, message=None, http_major="1", http_minor="1"): - if message is None: - message = http_status.RESPONSES.get(code) - return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py new file mode 100644 index 00000000..9b4b0e6b --- /dev/null +++ b/netlib/http/__init__.py @@ -0,0 +1,2 @@ +from exceptions import * +from semantics import * diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py new file mode 100644 index 00000000..26e3c2c4 --- /dev/null +++ b/netlib/http/authentication.py @@ -0,0 +1,149 @@ +from __future__ import (absolute_import, print_function, division) +from argparse import Action, ArgumentTypeError + +from .. import http + + +class NullProxyAuth(object): + + """ + No proxy auth at all (returns empty challange headers) + """ + + def __init__(self, password_manager): + self.password_manager = password_manager + + def clean(self, headers_): + """ + Clean up authentication headers, so they're not passed upstream. + """ + pass + + def authenticate(self, headers_): + """ + Tests that the user is allowed to use the proxy + """ + return True + + def auth_challenge_headers(self): + """ + Returns a dictionary containing the headers require to challenge the user + """ + return {} + + +class BasicProxyAuth(NullProxyAuth): + CHALLENGE_HEADER = 'Proxy-Authenticate' + AUTH_HEADER = 'Proxy-Authorization' + + def __init__(self, password_manager, realm): + NullProxyAuth.__init__(self, password_manager) + self.realm = realm + + def clean(self, headers): + del headers[self.AUTH_HEADER] + + def authenticate(self, headers): + auth_value = headers.get(self.AUTH_HEADER, []) + if not auth_value: + return False + parts = http.http1.parse_http_basic_auth(auth_value[0]) + if not parts: + return False + scheme, username, password = parts + if scheme.lower() != 'basic': + return False + if not self.password_manager.test(username, password): + return False + self.username = username + return True + + def auth_challenge_headers(self): + return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm} + + +class PassMan(object): + + def test(self, username_, password_token_): + return False + + +class PassManNonAnon(PassMan): + + """ + Ensure the user specifies a username, accept any password. + """ + + def test(self, username, password_token_): + if username: + return True + return False + + +class PassManHtpasswd(PassMan): + + """ + Read usernames and passwords from an htpasswd file + """ + + def __init__(self, path): + """ + Raises ValueError if htpasswd file is invalid. + """ + import passlib.apache + self.htpasswd = passlib.apache.HtpasswdFile(path) + + def test(self, username, password_token): + return bool(self.htpasswd.check_password(username, password_token)) + + +class PassManSingleUser(PassMan): + + def __init__(self, username, password): + self.username, self.password = username, password + + def test(self, username, password_token): + return self.username == username and self.password == password_token + + +class AuthAction(Action): + + """ + Helper class to allow seamless integration int argparse. Example usage: + parser.add_argument( + "--nonanonymous", + action=NonanonymousAuthAction, nargs=0, + help="Allow access to any user long as a credentials are specified." + ) + """ + + def __call__(self, parser, namespace, values, option_string=None): + passman = self.getPasswordManager(values) + authenticator = BasicProxyAuth(passman, "mitmproxy") + setattr(namespace, self.dest, authenticator) + + def getPasswordManager(self, s): # pragma: nocover + raise NotImplementedError() + + +class SingleuserAuthAction(AuthAction): + + def getPasswordManager(self, s): + if len(s.split(':')) != 2: + raise ArgumentTypeError( + "Invalid single-user specification. Please use the format username:password" + ) + username, password = s.split(':') + return PassManSingleUser(username, password) + + +class NonanonymousAuthAction(AuthAction): + + def getPasswordManager(self, s): + return PassManNonAnon() + + +class HtpasswdAuthAction(AuthAction): + + def getPasswordManager(self, s): + return PassManHtpasswd(s) diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py new file mode 100644 index 00000000..b77e3503 --- /dev/null +++ b/netlib/http/cookies.py @@ -0,0 +1,193 @@ +import re + +from .. import odict + +""" +A flexible module for cookie parsing and manipulation. + +This module differs from usual standards-compliant cookie modules in a number +of ways. We try to be as permissive as possible, and to retain even mal-formed +information. Duplicate cookies are preserved in parsing, and can be set in +formatting. We do attempt to escape and quote values where needed, but will not +reject data that violate the specs. + +Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We do +not parse the comma-separated variant of Set-Cookie that allows multiple +cookies to be set in a single header. Technically this should be feasible, but +it turns out that violations of RFC6265 that makes the parsing problem +indeterminate are much more common than genuine occurences of the multi-cookie +variants. Serialization follows RFC6265. + + http://tools.ietf.org/html/rfc6265 + http://tools.ietf.org/html/rfc2109 + http://tools.ietf.org/html/rfc2965 +""" + +# TODO +# - Disallow LHS-only Cookie values + + +def _read_until(s, start, term): + """ + Read until one of the characters in term is reached. + """ + if start == len(s): + return "", start + 1 + for i in range(start, len(s)): + if s[i] in term: + return s[start:i], i + return s[start:i + 1], i + 1 + + +def _read_token(s, start): + """ + Read a token - the LHS of a token/value pair in a cookie. + """ + return _read_until(s, start, ";=") + + +def _read_quoted_string(s, start): + """ + start: offset to the first quote of the string to be read + + A sort of loose super-set of the various quoted string specifications. + + RFC6265 disallows backslashes or double quotes within quoted strings. + Prior RFCs use backslashes to escape. This leaves us free to apply + backslash escaping by default and be compatible with everything. + """ + escaping = False + ret = [] + # Skip the first quote + for i in range(start + 1, len(s)): + if escaping: + ret.append(s[i]) + escaping = False + elif s[i] == '"': + break + elif s[i] == "\\": + escaping = True + else: + ret.append(s[i]) + return "".join(ret), i + 1 + + +def _read_value(s, start, delims): + """ + Reads a value - the RHS of a token/value pair in a cookie. + + special: If the value is special, commas are premitted. Else comma + terminates. This helps us support old and new style values. + """ + if start >= len(s): + return "", start + elif s[start] == '"': + return _read_quoted_string(s, start) + else: + return _read_until(s, start, delims) + + +def _read_pairs(s, off=0): + """ + Read pairs of lhs=rhs values. + + off: start offset + specials: a lower-cased list of keys that may contain commas + """ + vals = [] + while True: + lhs, off = _read_token(s, off) + lhs = lhs.lstrip() + if lhs: + rhs = None + if off < len(s): + if s[off] == "=": + rhs, off = _read_value(s, off + 1, ";") + vals.append([lhs, rhs]) + off += 1 + if not off < len(s): + break + return vals, off + + +def _has_special(s): + for i in s: + if i in '",;\\': + return True + o = ord(i) + if o < 0x21 or o > 0x7e: + return True + return False + + +ESCAPE = re.compile(r"([\"\\])") + + +def _format_pairs(lst, specials=(), sep="; "): + """ + specials: A lower-cased list of keys that will not be quoted. + """ + vals = [] + for k, v in lst: + if v is None: + vals.append(k) + else: + if k.lower() not in specials and _has_special(v): + v = ESCAPE.sub(r"\\\1", v) + v = '"%s"' % v + vals.append("%s=%s" % (k, v)) + return sep.join(vals) + + +def _format_set_cookie_pairs(lst): + return _format_pairs( + lst, + specials=("expires", "path") + ) + + +def _parse_set_cookie_pairs(s): + """ + For Set-Cookie, we support multiple cookies as described in RFC2109. + This function therefore returns a list of lists. + """ + pairs, off_ = _read_pairs(s) + return pairs + + +def parse_set_cookie_header(line): + """ + Parse a Set-Cookie header value + + Returns a (name, value, attrs) tuple, or None, where attrs is an + ODictCaseless set of attributes. No attempt is made to parse attribute + values - they are treated purely as strings. + """ + pairs = _parse_set_cookie_pairs(line) + if pairs: + return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) + + +def format_set_cookie_header(name, value, attrs): + """ + Formats a Set-Cookie header value. + """ + pairs = [[name, value]] + pairs.extend(attrs.lst) + return _format_set_cookie_pairs(pairs) + + +def parse_cookie_header(line): + """ + Parse a Cookie header value. + Returns a (possibly empty) ODict object. + """ + pairs, off_ = _read_pairs(line) + return odict.ODict(pairs) + + +def format_cookie_header(od): + """ + Formats a Cookie header value. + """ + return _format_pairs(od.lst) diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py new file mode 100644 index 00000000..8a2bbebc --- /dev/null +++ b/netlib/http/exceptions.py @@ -0,0 +1,9 @@ +class HttpError(Exception): + + def __init__(self, code, message): + super(HttpError, self).__init__(message) + self.code = code + + +class HttpErrorConnClosed(HttpError): + pass diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py new file mode 100644 index 00000000..6b5043af --- /dev/null +++ b/netlib/http/http1/__init__.py @@ -0,0 +1 @@ +from protocol import * diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py new file mode 100644 index 00000000..0f7a0bd3 --- /dev/null +++ b/netlib/http/http1/protocol.py @@ -0,0 +1,518 @@ +from __future__ import (absolute_import, print_function, division) +import binascii +import collections +import string +import sys +import urlparse + +from netlib import odict, utils, tcp, http +from .. import status_codes +from ..exceptions import * + + +def get_request_line(fp): + """ + Get a line, possibly preceded by a blank. + """ + line = fp.readline() + if line == "\r\n" or line == "\n": + # Possible leftover from previous message + line = fp.readline() + return line + +def read_headers(fp): + """ + Read a set of headers from a file pointer. Stop once a blank line is + reached. Return a ODictCaseless object, or None if headers are invalid. + """ + ret = [] + name = '' + while True: + line = fp.readline() + if not line or line == '\r\n' or line == '\n': + break + if line[0] in ' \t': + if not ret: + return None + # continued header + ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() + else: + i = line.find(':') + # We're being liberal in what we accept, here. + if i > 0: + name = line[:i] + value = line[i + 1:].strip() + ret.append([name, value]) + else: + return None + return odict.ODictCaseless(ret) + + +def read_chunked(fp, limit, is_request): + """ + Read a chunked HTTP body. + + May raise HttpError. + """ + # FIXME: Should check if chunked is the final encoding in the headers + # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 + # 3.3 2. + total = 0 + code = 400 if is_request else 502 + while True: + line = fp.readline(128) + if line == "": + raise HttpErrorConnClosed(code, "Connection closed prematurely") + if line != '\r\n' and line != '\n': + try: + length = int(line, 16) + except ValueError: + raise HttpError( + code, + "Invalid chunked encoding length: %s" % line + ) + total += length + if limit is not None and total > limit: + msg = "HTTP Body too large. Limit is %s," \ + " chunked content longer than %s" % (limit, total) + raise HttpError(code, msg) + chunk = fp.read(length) + suffix = fp.readline(5) + if suffix != '\r\n': + raise HttpError(code, "Malformed chunked body") + yield line, chunk, '\r\n' + if length == 0: + return + + +def get_header_tokens(headers, key): + """ + Retrieve all tokens for a header key. A number of different headers + follow a pattern where each header line can containe comma-separated + tokens, and headers can be set multiple times. + """ + toks = [] + for i in headers[key]: + for j in i.split(","): + toks.append(j.strip()) + return toks + + +def has_chunked_encoding(headers): + return "chunked" in [ + i.lower() for i in get_header_tokens(headers, "transfer-encoding") + ] + + +def parse_http_protocol(s): + """ + Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or + None. + """ + if not s.startswith("HTTP/"): + return None + _, version = s.split('/', 1) + if "." not in version: + return None + major, minor = version.split('.', 1) + try: + major = int(major) + minor = int(minor) + except ValueError: + return None + return major, minor + + +def parse_http_basic_auth(s): + # TODO: check if this is HTTP/1 only - otherwise move it to netlib.http.semantics + words = s.split() + if len(words) != 2: + return None + scheme = words[0] + try: + user = binascii.a2b_base64(words[1]) + except binascii.Error: + return None + parts = user.split(':') + if len(parts) != 2: + return None + return scheme, parts[0], parts[1] + + +def assemble_http_basic_auth(scheme, username, password): + # TODO: check if this is HTTP/1 only - otherwise move it to netlib.http.semantics + v = binascii.b2a_base64(username + ":" + password) + return scheme + " " + v + + +def parse_init(line): + try: + method, url, protocol = string.split(line) + except ValueError: + return None + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + if not utils.isascii(method): + return None + return method, url, httpversion + + +def parse_init_connect(line): + """ + Returns (host, port, httpversion) if line is a valid CONNECT line. + http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 + """ + v = parse_init(line) + if not v: + return None + method, url, httpversion = v + + if method.upper() != 'CONNECT': + return None + try: + host, port = url.split(":") + except ValueError: + return None + try: + port = int(port) + except ValueError: + return None + if not http.is_valid_port(port): + return None + if not http.is_valid_host(host): + return None + return host, port, httpversion + + +def parse_init_proxy(line): + v = parse_init(line) + if not v: + return None + method, url, httpversion = v + + parts = http.parse_url(url) + if not parts: + return None + scheme, host, port, path = parts + return method, scheme, host, port, path, httpversion + + +def parse_init_http(line): + """ + Returns (method, url, httpversion) + """ + v = parse_init(line) + if not v: + return None + method, url, httpversion = v + if not utils.isascii(url): + return None + if not (url.startswith("/") or url == "*"): + return None + return method, url, httpversion + + +def connection_close(httpversion, headers): + """ + Checks the message to see if the client connection should be closed + according to RFC 2616 Section 8.1 Note that a connection should be + closed as well if the response has been read until end of the stream. + """ + # At first, check if we have an explicit Connection header. + if "connection" in headers: + toks = get_header_tokens(headers, "connection") + if "close" in toks: + return True + elif "keep-alive" in toks: + return False + # If we don't have a Connection header, HTTP 1.1 connections are assumed to + # be persistent + if httpversion == (1, 1): + return False + return True + + +def parse_response_line(line): + parts = line.strip().split(" ", 2) + if len(parts) == 2: # handle missing message gracefully + parts.append("") + if len(parts) != 3: + return None + proto, code, msg = parts + try: + code = int(code) + except ValueError: + return None + return (proto, code, msg) + + +def read_http_body(*args, **kwargs): + return "".join( + content for _, content, _ in read_http_body_chunked(*args, **kwargs) + ) + + +def read_http_body_chunked( + rfile, + headers, + limit, + request_method, + response_code, + is_request, + max_chunk_size=None +): + """ + Read an HTTP message body: + + rfile: A file descriptor to read from + headers: An ODictCaseless object + limit: Size limit. + is_request: True if the body to read belongs to a request, False + otherwise + """ + if max_chunk_size is None: + max_chunk_size = limit or sys.maxsize + + expected_size = expected_http_body_size( + headers, is_request, request_method, response_code + ) + + if expected_size is None: + if has_chunked_encoding(headers): + # Python 3: yield from + for x in read_chunked(rfile, limit, is_request): + yield x + else: # pragma: nocover + raise HttpError( + 400 if is_request else 502, + "Content-Length unknown but no chunked encoding" + ) + elif expected_size >= 0: + if limit is not None and expected_size > limit: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s, content-length was %s" % ( + limit, expected_size + ) + ) + bytes_left = expected_size + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + yield "", rfile.read(chunk_size), "" + bytes_left -= chunk_size + else: + bytes_left = limit or -1 + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = rfile.read(chunk_size) + if not content: + return + yield "", content, "" + bytes_left -= chunk_size + not_done = rfile.read(1) + if not_done: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s," % limit + ) + + +def expected_http_body_size(headers, is_request, request_method, response_code): + """ + Returns the expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding or invalid + data) + - -1, if all data should be read until end of stream. + + May raise HttpError. + """ + # Determine response size according to + # http://tools.ietf.org/html/rfc7230#section-3.3 + if request_method: + request_method = request_method.upper() + + if (not is_request and ( + request_method == "HEAD" or + (request_method == "CONNECT" and response_code == 200) or + response_code in [204, 304] or + 100 <= response_code <= 199)): + return 0 + if has_chunked_encoding(headers): + return None + if "content-length" in headers: + try: + size = int(headers["content-length"][0]) + if size < 0: + raise ValueError() + return size + except ValueError: + return None + if is_request: + return 0 + return -1 + + +# TODO: make this a regular class - just like Response +Request = collections.namedtuple( + "Request", + [ + "form_in", + "method", + "scheme", + "host", + "port", + "path", + "httpversion", + "headers", + "content" + ] +) + + +def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): + """ + Parse an HTTP request from a file stream + + Args: + rfile (file): Input file to read from + include_body (bool): Read response body as well + body_size_limit (bool): Maximum body size + wfile (file): If specified, HTTP Expect headers are handled + automatically, by writing a HTTP 100 CONTINUE response to the stream. + + Returns: + Request: The HTTP request + + Raises: + HttpError: If the input is invalid. + """ + httpversion, host, port, scheme, method, path, headers, content = ( + None, None, None, None, None, None, None, None) + + request_line = get_request_line(rfile) + if not request_line: + raise tcp.NetLibDisconnect() + + request_line_parts = parse_init(request_line) + if not request_line_parts: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + method, path, httpversion = request_line_parts + + if path == '*' or path.startswith("/"): + form_in = "relative" + if not utils.isascii(path): + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + elif method.upper() == 'CONNECT': + form_in = "authority" + r = parse_init_connect(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + host, port, _ = r + path = None + else: + form_in = "absolute" + r = parse_init_proxy(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + _, scheme, host, port, path, _ = r + + headers = read_headers(rfile) + if headers is None: + raise HttpError(400, "Invalid headers") + + expect_header = headers.get_first("expect", "").lower() + if expect_header == "100-continue" and httpversion >= (1, 1): + wfile.write( + 'HTTP/1.1 100 Continue\r\n' + '\r\n' + ) + wfile.flush() + del headers['expect'] + + if include_body: + content = read_http_body( + rfile, headers, body_size_limit, method, None, True + ) + + return Request( + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + content + ) + + +def read_response(rfile, request_method, body_size_limit, include_body=True): + """ + Returns an http.Response + + By default, both response header and body are read. + If include_body=False is specified, content may be one of the + following: + - None, if the response is technically allowed to have a response body + - "", if the response must not have a response body (e.g. it's a + response to a HEAD request) + """ + + line = rfile.readline() + # Possible leftover from previous message + if line == "\r\n" or line == "\n": + line = rfile.readline() + if not line: + raise HttpErrorConnClosed(502, "Server disconnect.") + parts = parse_response_line(line) + if not parts: + raise HttpError(502, "Invalid server response: %s" % repr(line)) + proto, code, msg = parts + httpversion = parse_http_protocol(proto) + if httpversion is None: + raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) + headers = read_headers(rfile) + if headers is None: + raise HttpError(502, "Invalid headers.") + + if include_body: + content = read_http_body( + rfile, + headers, + body_size_limit, + request_method, + code, + False + ) + else: + # if include_body==False then a None content means the body should be + # read separately + content = None + return http.Response(httpversion, code, msg, headers, content) + + +def request_preamble(method, resource, http_major="1", http_minor="1"): + return '%s %s HTTP/%s.%s' % ( + method, resource, http_major, http_minor + ) + + +def response_preamble(code, message=None, http_major="1", http_minor="1"): + if message is None: + message = status_codes.RESPONSES.get(code) + return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py new file mode 100644 index 00000000..5acf7696 --- /dev/null +++ b/netlib/http/http2/__init__.py @@ -0,0 +1,2 @@ +from frame import * +from protocol import * diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py new file mode 100644 index 00000000..f7e60471 --- /dev/null +++ b/netlib/http/http2/frame.py @@ -0,0 +1,636 @@ +import sys +import struct +from hpack.hpack import Encoder, Decoder + +from .. import utils + + +class FrameSizeError(Exception): + pass + + +class Frame(object): + + """ + Baseclass Frame + contains header + payload is defined in subclasses + """ + + FLAG_NO_FLAGS = 0x0 + FLAG_ACK = 0x1 + FLAG_END_STREAM = 0x1 + FLAG_END_HEADERS = 0x4 + FLAG_PADDED = 0x8 + FLAG_PRIORITY = 0x20 + + def __init__( + self, + state=None, + length=0, + flags=FLAG_NO_FLAGS, + stream_id=0x0): + valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) + if flags | valid_flags != valid_flags: + raise ValueError('invalid flags detected.') + + if state is None: + class State(object): + pass + + state = State() + state.http2_settings = HTTP2_DEFAULT_SETTINGS.copy() + state.encoder = Encoder() + state.decoder = Decoder() + + self.state = state + + self.length = length + self.type = self.TYPE + self.flags = flags + self.stream_id = stream_id + + @classmethod + def _check_frame_size(cls, length, state): + if state: + settings = state.http2_settings + else: + settings = HTTP2_DEFAULT_SETTINGS.copy() + + max_frame_size = settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + + if length > max_frame_size: + raise FrameSizeError( + "Frame size exceeded: %d, but only %d allowed." % ( + length, max_frame_size)) + + @classmethod + def from_file(cls, fp, state=None): + """ + read a HTTP/2 frame sent by a server or client + fp is a "file like" object that could be backed by a network + stream or a disk or an in memory stream reader + """ + raw_header = fp.safe_read(9) + + fields = struct.unpack("!HBBBL", raw_header) + length = (fields[0] << 8) + fields[1] + flags = fields[3] + stream_id = fields[4] + + if raw_header[:4] == b'HTTP': # pragma no cover + print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" + + cls._check_frame_size(length, state) + + payload = fp.safe_read(length) + return FRAMES[fields[2]].from_bytes( + state, + length, + flags, + stream_id, + payload) + + def to_bytes(self): + payload = self.payload_bytes() + self.length = len(payload) + + self._check_frame_size(self.length, self.state) + + b = struct.pack('!HB', self.length & 0xFFFF00, self.length & 0x0000FF) + b += struct.pack('!B', self.TYPE) + b += struct.pack('!B', self.flags) + b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) + b += payload + + return b + + def payload_bytes(self): # pragma: no cover + raise NotImplementedError() + + def payload_human_readable(self): # pragma: no cover + raise NotImplementedError() + + def human_readable(self, direction="-"): + self.length = len(self.payload_bytes()) + + return "\n".join([ + "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( + direction, self.__class__.__name__, self.length, self.flags, self.stream_id), + self.payload_human_readable(), + "===============================================================", + ]) + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() + + +class DataFrame(Frame): + TYPE = 0x0 + VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b'', + pad_length=0): + super(DataFrame, self).__init__(state, length, flags, stream_id) + self.payload = payload + self.pad_length = pad_length + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.payload = payload[1:-f.pad_length] + else: + f.payload = payload + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('DATA frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += bytes(self.payload) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + return "payload: %s" % str(self.payload) + + +class HeadersFrame(Frame): + TYPE = 0x1 + VALID_FLAGS = [ + Frame.FLAG_END_STREAM, + Frame.FLAG_END_HEADERS, + Frame.FLAG_PADDED, + Frame.FLAG_PRIORITY] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b'', + pad_length=0, + exclusive=False, + stream_dependency=0x0, + weight=0): + super(HeadersFrame, self).__init__(state, length, flags, stream_id) + + self.header_block_fragment = header_block_fragment + self.pad_length = pad_length + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.header_block_fragment = payload[1:-f.pad_length] + else: + f.header_block_fragment = payload[0:] + + if f.flags & Frame.FLAG_PRIORITY: + f.stream_dependency, f.weight = struct.unpack( + '!LB', f.header_block_fragment[:5]) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + f.header_block_fragment = f.header_block_fragment[5:] + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('HEADERS frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + if self.flags & self.FLAG_PRIORITY: + b += struct.pack('!LB', + (int(self.exclusive) << 31) | self.stream_dependency, + self.weight) + + b += self.header_block_fragment + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PRIORITY: + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + + return "\n".join(s) + + +class PriorityFrame(Frame): + TYPE = 0x2 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + exclusive=False, + stream_dependency=0x0, + weight=0): + super(PriorityFrame, self).__init__(state, length, flags, stream_id) + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.stream_dependency, f.weight = struct.unpack('!LB', payload) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'PRIORITY frames MUST be associated with a stream.') + + if self.stream_dependency == 0x0: + raise ValueError('stream dependency is invalid.') + + return struct.pack( + '!LB', + (int( + self.exclusive) << 31) | self.stream_dependency, + self.weight) + + def payload_human_readable(self): + s = [] + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + return "\n".join(s) + + +class RstStreamFrame(Frame): + TYPE = 0x3 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + error_code=0x0): + super(RstStreamFrame, self).__init__(state, length, flags, stream_id) + self.error_code = error_code + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.error_code = struct.unpack('!L', payload)[0] + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'RST_STREAM frames MUST be associated with a stream.') + + return struct.pack('!L', self.error_code) + + def payload_human_readable(self): + return "error code: %#x" % self.error_code + + +class SettingsFrame(Frame): + TYPE = 0x4 + VALID_FLAGS = [Frame.FLAG_ACK] + + SETTINGS = utils.BiDi( + SETTINGS_HEADER_TABLE_SIZE=0x1, + SETTINGS_ENABLE_PUSH=0x2, + SETTINGS_MAX_CONCURRENT_STREAMS=0x3, + SETTINGS_INITIAL_WINDOW_SIZE=0x4, + SETTINGS_MAX_FRAME_SIZE=0x5, + SETTINGS_MAX_HEADER_LIST_SIZE=0x6, + ) + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings=None): + super(SettingsFrame, self).__init__(state, length, flags, stream_id) + + if settings is None: + settings = {} + + self.settings = settings + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + for i in xrange(0, len(payload), 6): + identifier, value = struct.unpack("!HL", payload[i:i + 6]) + f.settings[identifier] = value + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'SETTINGS frames MUST NOT be associated with a stream.') + + b = b'' + for identifier, value in self.settings.items(): + b += struct.pack("!HL", identifier & 0xFF, value) + + return b + + def payload_human_readable(self): + s = [] + + for identifier, value in self.settings.items(): + s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) + + if not s: + return "settings: None" + else: + return "\n".join(s) + + +class PushPromiseFrame(Frame): + TYPE = 0x5 + VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + promised_stream=0x0, + header_block_fragment=b'', + pad_length=0): + super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) + self.pad_length = pad_length + self.promised_stream = promised_stream + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) + f.header_block_fragment = payload[5:-f.pad_length] + else: + f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) + f.header_block_fragment = payload[4:] + + f.promised_stream &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'PUSH_PROMISE frames MUST be associated with a stream.') + + if self.promised_stream == 0x0: + raise ValueError('Promised stream id not valid.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) + b += bytes(self.header_block_fragment) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append("promised stream: %#x" % self.promised_stream) + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + + return "\n".join(s) + + +class PingFrame(Frame): + TYPE = 0x6 + VALID_FLAGS = [Frame.FLAG_ACK] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b''): + super(PingFrame, self).__init__(state, length, flags, stream_id) + self.payload = payload + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.payload = payload + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'PING frames MUST NOT be associated with a stream.') + + b = self.payload[0:8] + b += b'\0' * (8 - len(b)) + return b + + def payload_human_readable(self): + return "opaque data: %s" % str(self.payload) + + +class GoAwayFrame(Frame): + TYPE = 0x7 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x0, + error_code=0x0, + data=b''): + super(GoAwayFrame, self).__init__(state, length, flags, stream_id) + self.last_stream = last_stream + self.error_code = error_code + self.data = data + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) + f.last_stream &= 0x7FFFFFFF + f.data = payload[8:] + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'GOAWAY frames MUST NOT be associated with a stream.') + + b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) + b += bytes(self.data) + return b + + def payload_human_readable(self): + s = [] + s.append("last stream: %#x" % self.last_stream) + s.append("error code: %d" % self.error_code) + s.append("debug data: %s" % str(self.data)) + return "\n".join(s) + + +class WindowUpdateFrame(Frame): + TYPE = 0x8 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + window_size_increment=0x0): + super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) + self.window_size_increment = window_size_increment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.window_size_increment = struct.unpack("!L", payload)[0] + f.window_size_increment &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: + raise ValueError( + 'Window Szie Increment MUST be greater than 0 and less than 2^31.') + + return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) + + def payload_human_readable(self): + return "window size increment: %#x" % self.window_size_increment + + +class ContinuationFrame(Frame): + TYPE = 0x9 + VALID_FLAGS = [Frame.FLAG_END_HEADERS] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b''): + super(ContinuationFrame, self).__init__(state, length, flags, stream_id) + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.header_block_fragment = payload + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'CONTINUATION frames MUST be associated with a stream.') + + return self.header_block_fragment + + def payload_human_readable(self): + s = [] + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + return "\n".join(s) + +_FRAME_CLASSES = [ + DataFrame, + HeadersFrame, + PriorityFrame, + RstStreamFrame, + SettingsFrame, + PushPromiseFrame, + PingFrame, + GoAwayFrame, + WindowUpdateFrame, + ContinuationFrame +] +FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} + + +HTTP2_DEFAULT_SETTINGS = { + SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, + SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, +} diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py new file mode 100644 index 00000000..8e5f5429 --- /dev/null +++ b/netlib/http/http2/protocol.py @@ -0,0 +1,240 @@ +from __future__ import (absolute_import, print_function, division) +import itertools + +from hpack.hpack import Encoder, Decoder +from .. import utils +from . import frame + + +class HTTP2Protocol(object): + + ERROR_CODES = utils.BiDi( + NO_ERROR=0x0, + PROTOCOL_ERROR=0x1, + INTERNAL_ERROR=0x2, + FLOW_CONTROL_ERROR=0x3, + SETTINGS_TIMEOUT=0x4, + STREAM_CLOSED=0x5, + FRAME_SIZE_ERROR=0x6, + REFUSED_STREAM=0x7, + CANCEL=0x8, + COMPRESSION_ERROR=0x9, + CONNECT_ERROR=0xa, + ENHANCE_YOUR_CALM=0xb, + INADEQUATE_SECURITY=0xc, + HTTP_1_1_REQUIRED=0xd + ) + + # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + CLIENT_CONNECTION_PREFACE =\ + '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') + + ALPN_PROTO_H2 = 'h2' + + def __init__(self, tcp_handler, is_server=False, dump_frames=False): + self.tcp_handler = tcp_handler + self.is_server = is_server + + self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() + self.current_stream_id = None + self.encoder = Encoder() + self.decoder = Decoder() + self.connection_preface_performed = False + self.dump_frames = dump_frames + + def check_alpn(self): + alp = self.tcp_handler.get_alpn_proto_negotiated() + if alp != self.ALPN_PROTO_H2: + raise NotImplementedError( + "HTTP2Protocol can not handle unknown ALP: %s" % alp) + return True + + def _receive_settings(self, hide=False): + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + break + + def _read_settings_ack(self, hide=False): # pragma no cover + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + assert frm.flags & frame.Frame.FLAG_ACK + assert len(frm.settings) == 0 + break + + def perform_server_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + magic_length = len(self.CLIENT_CONNECTION_PREFACE) + magic = self.tcp_handler.rfile.safe_read(magic_length) + assert magic == self.CLIENT_CONNECTION_PREFACE + + self.send_frame(frame.SettingsFrame(state=self), hide=True) + self._receive_settings(hide=True) + + def perform_client_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) + + self.send_frame(frame.SettingsFrame(state=self), hide=True) + self._receive_settings(hide=True) + + def next_stream_id(self): + if self.current_stream_id is None: + if self.is_server: + # servers must use even stream ids + self.current_stream_id = 2 + else: + # clients must use odd stream ids + self.current_stream_id = 1 + else: + self.current_stream_id += 2 + return self.current_stream_id + + def send_frame(self, frm, hide=False): + raw_bytes = frm.to_bytes() + self.tcp_handler.wfile.write(raw_bytes) + self.tcp_handler.wfile.flush() + if not hide and self.dump_frames: # pragma no cover + print(frm.human_readable(">>")) + + def read_frame(self, hide=False): + frm = frame.Frame.from_file(self.tcp_handler.rfile, self) + if not hide and self.dump_frames: # pragma no cover + print(frm.human_readable("<<")) + if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: + self._apply_settings(frm.settings, hide) + + return frm + + def _apply_settings(self, settings, hide=False): + for setting, value in settings.items(): + old_value = self.http2_settings[setting] + if not old_value: + old_value = '-' + self.http2_settings[setting] = value + + frm = frame.SettingsFrame( + state=self, + flags=frame.Frame.FLAG_ACK) + self.send_frame(frm, hide) + + # be liberal in what we expect from the other end + # to be more strict use: self._read_settings_ack(hide) + + def _create_headers(self, headers, stream_id, end_stream=True): + # TODO: implement max frame size checks and sending in chunks + + flags = frame.Frame.FLAG_END_HEADERS + if end_stream: + flags |= frame.Frame.FLAG_END_STREAM + + header_block_fragment = self.encoder.encode(headers) + + frm = frame.HeadersFrame( + state=self, + flags=flags, + stream_id=stream_id, + header_block_fragment=header_block_fragment) + + if self.dump_frames: # pragma no cover + print(frm.human_readable(">>")) + + return [frm.to_bytes()] + + def _create_body(self, body, stream_id): + if body is None or len(body) == 0: + return b'' + + # TODO: implement max frame size checks and sending in chunks + # TODO: implement flow-control window + + frm = frame.DataFrame( + state=self, + flags=frame.Frame.FLAG_END_STREAM, + stream_id=stream_id, + payload=body) + + if self.dump_frames: # pragma no cover + print(frm.human_readable(">>")) + + return [frm.to_bytes()] + + + def create_request(self, method, path, headers=None, body=None): + if headers is None: + headers = [] + + authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host + if self.tcp_handler.address.port != 443: + authority += ":%d" % self.tcp_handler.address.port + + headers = [ + (b':method', bytes(method)), + (b':path', bytes(path)), + (b':scheme', b'https'), + (b':authority', authority), + ] + headers + + stream_id = self.next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(body is None)), + self._create_body(body, stream_id))) + + def read_response(self): + stream_id_, headers, body = self._receive_transmission() + return headers[':status'], headers, body + + def read_request(self): + return self._receive_transmission() + + def _receive_transmission(self): + body_expected = True + + stream_id = 0 + header_block_fragment = b'' + body = b'' + + while True: + frm = self.read_frame() + if isinstance(frm, frame.HeadersFrame)\ + or isinstance(frm, frame.ContinuationFrame): + stream_id = frm.stream_id + header_block_fragment += frm.header_block_fragment + if frm.flags & frame.Frame.FLAG_END_STREAM: + body_expected = False + if frm.flags & frame.Frame.FLAG_END_HEADERS: + break + + while body_expected: + frm = self.read_frame() + if isinstance(frm, frame.DataFrame): + body += frm.payload + if frm.flags & frame.Frame.FLAG_END_STREAM: + break + # TODO: implement window update & flow + + headers = {} + for header, value in self.decoder.decode(header_block_fragment): + headers[header] = value + + return stream_id, headers, body + + def create_response(self, code, stream_id=None, headers=None, body=None): + if headers is None: + headers = [] + + headers = [(b':status', bytes(str(code)))] + headers + + if not stream_id: + stream_id = self.next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(body is None)), + self._create_body(body, stream_id), + )) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py new file mode 100644 index 00000000..e7e84fe3 --- /dev/null +++ b/netlib/http/semantics.py @@ -0,0 +1,94 @@ +from __future__ import (absolute_import, print_function, division) +import binascii +import collections +import string +import sys +import urlparse + +from .. import utils + +class Response(object): + + def __init__( + self, + httpversion, + status_code, + msg, + headers, + content, + sslinfo=None, + ): + self.httpversion = httpversion + self.status_code = status_code + self.msg = msg + self.headers = headers + self.content = content + self.sslinfo = sslinfo + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return "Response(%s - %s)" % (self.status_code, self.msg) + + + +def is_valid_port(port): + if not 0 <= port <= 65535: + return False + return True + + +def is_valid_host(host): + try: + host.decode("idna") + except ValueError: + return False + if "\0" in host: + return None + return True + + + +def parse_url(url): + """ + Returns a (scheme, host, port, path) tuple, or None on error. + + Checks that: + port is an integer 0-65535 + host is a valid IDNA-encoded hostname with no null-bytes + path is valid ASCII + """ + try: + scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) + except ValueError: + return None + if not scheme: + return None + if '@' in netloc: + # FIXME: Consider what to do with the discarded credentials here Most + # probably we should extend the signature to return these as a separate + # value. + _, netloc = string.rsplit(netloc, '@', maxsplit=1) + if ':' in netloc: + host, port = string.rsplit(netloc, ':', maxsplit=1) + try: + port = int(port) + except ValueError: + return None + else: + host = netloc + if scheme == "https": + port = 443 + else: + port = 80 + path = urlparse.urlunparse(('', '', path, params, query, fragment)) + if not path.startswith("/"): + path = "/" + path + if not is_valid_host(host): + return None + if not utils.isascii(path): + return None + if not is_valid_port(port): + return None + return scheme, host, port, path diff --git a/netlib/http/status_codes.py b/netlib/http/status_codes.py new file mode 100644 index 00000000..dc09f465 --- /dev/null +++ b/netlib/http/status_codes.py @@ -0,0 +1,104 @@ +from __future__ import (absolute_import, print_function, division) + +CONTINUE = 100 +SWITCHING = 101 +OK = 200 +CREATED = 201 +ACCEPTED = 202 +NON_AUTHORITATIVE_INFORMATION = 203 +NO_CONTENT = 204 +RESET_CONTENT = 205 +PARTIAL_CONTENT = 206 +MULTI_STATUS = 207 + +MULTIPLE_CHOICE = 300 +MOVED_PERMANENTLY = 301 +FOUND = 302 +SEE_OTHER = 303 +NOT_MODIFIED = 304 +USE_PROXY = 305 +TEMPORARY_REDIRECT = 307 + +BAD_REQUEST = 400 +UNAUTHORIZED = 401 +PAYMENT_REQUIRED = 402 +FORBIDDEN = 403 +NOT_FOUND = 404 +NOT_ALLOWED = 405 +NOT_ACCEPTABLE = 406 +PROXY_AUTH_REQUIRED = 407 +REQUEST_TIMEOUT = 408 +CONFLICT = 409 +GONE = 410 +LENGTH_REQUIRED = 411 +PRECONDITION_FAILED = 412 +REQUEST_ENTITY_TOO_LARGE = 413 +REQUEST_URI_TOO_LONG = 414 +UNSUPPORTED_MEDIA_TYPE = 415 +REQUESTED_RANGE_NOT_SATISFIABLE = 416 +EXPECTATION_FAILED = 417 + +INTERNAL_SERVER_ERROR = 500 +NOT_IMPLEMENTED = 501 +BAD_GATEWAY = 502 +SERVICE_UNAVAILABLE = 503 +GATEWAY_TIMEOUT = 504 +HTTP_VERSION_NOT_SUPPORTED = 505 +INSUFFICIENT_STORAGE_SPACE = 507 +NOT_EXTENDED = 510 + +RESPONSES = { + # 100 + CONTINUE: "Continue", + SWITCHING: "Switching Protocols", + + # 200 + OK: "OK", + CREATED: "Created", + ACCEPTED: "Accepted", + NON_AUTHORITATIVE_INFORMATION: "Non-Authoritative Information", + NO_CONTENT: "No Content", + RESET_CONTENT: "Reset Content.", + PARTIAL_CONTENT: "Partial Content", + MULTI_STATUS: "Multi-Status", + + # 300 + MULTIPLE_CHOICE: "Multiple Choices", + MOVED_PERMANENTLY: "Moved Permanently", + FOUND: "Found", + SEE_OTHER: "See Other", + NOT_MODIFIED: "Not Modified", + USE_PROXY: "Use Proxy", + # 306 not defined?? + TEMPORARY_REDIRECT: "Temporary Redirect", + + # 400 + BAD_REQUEST: "Bad Request", + UNAUTHORIZED: "Unauthorized", + PAYMENT_REQUIRED: "Payment Required", + FORBIDDEN: "Forbidden", + NOT_FOUND: "Not Found", + NOT_ALLOWED: "Method Not Allowed", + NOT_ACCEPTABLE: "Not Acceptable", + PROXY_AUTH_REQUIRED: "Proxy Authentication Required", + REQUEST_TIMEOUT: "Request Time-out", + CONFLICT: "Conflict", + GONE: "Gone", + LENGTH_REQUIRED: "Length Required", + PRECONDITION_FAILED: "Precondition Failed", + REQUEST_ENTITY_TOO_LARGE: "Request Entity Too Large", + REQUEST_URI_TOO_LONG: "Request-URI Too Long", + UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type", + REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range not satisfiable", + EXPECTATION_FAILED: "Expectation Failed", + + # 500 + INTERNAL_SERVER_ERROR: "Internal Server Error", + NOT_IMPLEMENTED: "Not Implemented", + BAD_GATEWAY: "Bad Gateway", + SERVICE_UNAVAILABLE: "Service Unavailable", + GATEWAY_TIMEOUT: "Gateway Time-out", + HTTP_VERSION_NOT_SUPPORTED: "HTTP Version not supported", + INSUFFICIENT_STORAGE_SPACE: "Insufficient Storage Space", + NOT_EXTENDED: "Not Extended" +} diff --git a/netlib/http/user_agents.py b/netlib/http/user_agents.py new file mode 100644 index 00000000..e8681908 --- /dev/null +++ b/netlib/http/user_agents.py @@ -0,0 +1,52 @@ +from __future__ import (absolute_import, print_function, division) + +""" + A small collection of useful user-agent header strings. These should be + kept reasonably current to reflect common usage. +""" + +# pylint: line-too-long + +# A collection of (name, shortcut, string) tuples. + +UASTRINGS = [ + ("android", + "a", + "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), # noqa + ("blackberry", + "l", + "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), # noqa + ("bingbot", + "b", + "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), # noqa + ("chrome", + "c", + "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), # noqa + ("firefox", + "f", + "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), # noqa + ("googlebot", + "g", + "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), # noqa + ("ie9", + "i", + "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US)"), # noqa + ("ipad", + "p", + "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9B176 Safari/7534.48.3"), # noqa + ("iphone", + "h", + "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5"), # noqa + ("safari", + "s", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10"), # noqa +] + + +def get_by_shortcut(s): + """ + Retrieve a user agent entry by shortcut. + """ + for i in UASTRINGS: + if s == i[1]: + return i diff --git a/netlib/http2/__init__.py b/netlib/http2/__init__.py deleted file mode 100644 index 5acf7696..00000000 --- a/netlib/http2/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from frame import * -from protocol import * diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py deleted file mode 100644 index f7e60471..00000000 --- a/netlib/http2/frame.py +++ /dev/null @@ -1,636 +0,0 @@ -import sys -import struct -from hpack.hpack import Encoder, Decoder - -from .. import utils - - -class FrameSizeError(Exception): - pass - - -class Frame(object): - - """ - Baseclass Frame - contains header - payload is defined in subclasses - """ - - FLAG_NO_FLAGS = 0x0 - FLAG_ACK = 0x1 - FLAG_END_STREAM = 0x1 - FLAG_END_HEADERS = 0x4 - FLAG_PADDED = 0x8 - FLAG_PRIORITY = 0x20 - - def __init__( - self, - state=None, - length=0, - flags=FLAG_NO_FLAGS, - stream_id=0x0): - valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) - if flags | valid_flags != valid_flags: - raise ValueError('invalid flags detected.') - - if state is None: - class State(object): - pass - - state = State() - state.http2_settings = HTTP2_DEFAULT_SETTINGS.copy() - state.encoder = Encoder() - state.decoder = Decoder() - - self.state = state - - self.length = length - self.type = self.TYPE - self.flags = flags - self.stream_id = stream_id - - @classmethod - def _check_frame_size(cls, length, state): - if state: - settings = state.http2_settings - else: - settings = HTTP2_DEFAULT_SETTINGS.copy() - - max_frame_size = settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] - - if length > max_frame_size: - raise FrameSizeError( - "Frame size exceeded: %d, but only %d allowed." % ( - length, max_frame_size)) - - @classmethod - def from_file(cls, fp, state=None): - """ - read a HTTP/2 frame sent by a server or client - fp is a "file like" object that could be backed by a network - stream or a disk or an in memory stream reader - """ - raw_header = fp.safe_read(9) - - fields = struct.unpack("!HBBBL", raw_header) - length = (fields[0] << 8) + fields[1] - flags = fields[3] - stream_id = fields[4] - - if raw_header[:4] == b'HTTP': # pragma no cover - print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" - - cls._check_frame_size(length, state) - - payload = fp.safe_read(length) - return FRAMES[fields[2]].from_bytes( - state, - length, - flags, - stream_id, - payload) - - def to_bytes(self): - payload = self.payload_bytes() - self.length = len(payload) - - self._check_frame_size(self.length, self.state) - - b = struct.pack('!HB', self.length & 0xFFFF00, self.length & 0x0000FF) - b += struct.pack('!B', self.TYPE) - b += struct.pack('!B', self.flags) - b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) - b += payload - - return b - - def payload_bytes(self): # pragma: no cover - raise NotImplementedError() - - def payload_human_readable(self): # pragma: no cover - raise NotImplementedError() - - def human_readable(self, direction="-"): - self.length = len(self.payload_bytes()) - - return "\n".join([ - "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( - direction, self.__class__.__name__, self.length, self.flags, self.stream_id), - self.payload_human_readable(), - "===============================================================", - ]) - - def __eq__(self, other): - return self.to_bytes() == other.to_bytes() - - -class DataFrame(Frame): - TYPE = 0x0 - VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b'', - pad_length=0): - super(DataFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - self.pad_length = pad_length - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.payload = payload[1:-f.pad_length] - else: - f.payload = payload - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('DATA frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += bytes(self.payload) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - return "payload: %s" % str(self.payload) - - -class HeadersFrame(Frame): - TYPE = 0x1 - VALID_FLAGS = [ - Frame.FLAG_END_STREAM, - Frame.FLAG_END_HEADERS, - Frame.FLAG_PADDED, - Frame.FLAG_PRIORITY] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b'', - pad_length=0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(HeadersFrame, self).__init__(state, length, flags, stream_id) - - self.header_block_fragment = header_block_fragment - self.pad_length = pad_length - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.header_block_fragment = payload[1:-f.pad_length] - else: - f.header_block_fragment = payload[0:] - - if f.flags & Frame.FLAG_PRIORITY: - f.stream_dependency, f.weight = struct.unpack( - '!LB', f.header_block_fragment[:5]) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - f.header_block_fragment = f.header_block_fragment[5:] - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('HEADERS frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - if self.flags & self.FLAG_PRIORITY: - b += struct.pack('!LB', - (int(self.exclusive) << 31) | self.stream_dependency, - self.weight) - - b += self.header_block_fragment - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PRIORITY: - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PriorityFrame(Frame): - TYPE = 0x2 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(PriorityFrame, self).__init__(state, length, flags, stream_id) - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.stream_dependency, f.weight = struct.unpack('!LB', payload) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PRIORITY frames MUST be associated with a stream.') - - if self.stream_dependency == 0x0: - raise ValueError('stream dependency is invalid.') - - return struct.pack( - '!LB', - (int( - self.exclusive) << 31) | self.stream_dependency, - self.weight) - - def payload_human_readable(self): - s = [] - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - return "\n".join(s) - - -class RstStreamFrame(Frame): - TYPE = 0x3 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - error_code=0x0): - super(RstStreamFrame, self).__init__(state, length, flags, stream_id) - self.error_code = error_code - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.error_code = struct.unpack('!L', payload)[0] - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'RST_STREAM frames MUST be associated with a stream.') - - return struct.pack('!L', self.error_code) - - def payload_human_readable(self): - return "error code: %#x" % self.error_code - - -class SettingsFrame(Frame): - TYPE = 0x4 - VALID_FLAGS = [Frame.FLAG_ACK] - - SETTINGS = utils.BiDi( - SETTINGS_HEADER_TABLE_SIZE=0x1, - SETTINGS_ENABLE_PUSH=0x2, - SETTINGS_MAX_CONCURRENT_STREAMS=0x3, - SETTINGS_INITIAL_WINDOW_SIZE=0x4, - SETTINGS_MAX_FRAME_SIZE=0x5, - SETTINGS_MAX_HEADER_LIST_SIZE=0x6, - ) - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings=None): - super(SettingsFrame, self).__init__(state, length, flags, stream_id) - - if settings is None: - settings = {} - - self.settings = settings - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - for i in xrange(0, len(payload), 6): - identifier, value = struct.unpack("!HL", payload[i:i + 6]) - f.settings[identifier] = value - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'SETTINGS frames MUST NOT be associated with a stream.') - - b = b'' - for identifier, value in self.settings.items(): - b += struct.pack("!HL", identifier & 0xFF, value) - - return b - - def payload_human_readable(self): - s = [] - - for identifier, value in self.settings.items(): - s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) - - if not s: - return "settings: None" - else: - return "\n".join(s) - - -class PushPromiseFrame(Frame): - TYPE = 0x5 - VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - promised_stream=0x0, - header_block_fragment=b'', - pad_length=0): - super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) - self.pad_length = pad_length - self.promised_stream = promised_stream - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) - f.header_block_fragment = payload[5:-f.pad_length] - else: - f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) - f.header_block_fragment = payload[4:] - - f.promised_stream &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PUSH_PROMISE frames MUST be associated with a stream.') - - if self.promised_stream == 0x0: - raise ValueError('Promised stream id not valid.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) - b += bytes(self.header_block_fragment) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append("promised stream: %#x" % self.promised_stream) - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PingFrame(Frame): - TYPE = 0x6 - VALID_FLAGS = [Frame.FLAG_ACK] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b''): - super(PingFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.payload = payload - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'PING frames MUST NOT be associated with a stream.') - - b = self.payload[0:8] - b += b'\0' * (8 - len(b)) - return b - - def payload_human_readable(self): - return "opaque data: %s" % str(self.payload) - - -class GoAwayFrame(Frame): - TYPE = 0x7 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x0, - error_code=0x0, - data=b''): - super(GoAwayFrame, self).__init__(state, length, flags, stream_id) - self.last_stream = last_stream - self.error_code = error_code - self.data = data - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) - f.last_stream &= 0x7FFFFFFF - f.data = payload[8:] - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'GOAWAY frames MUST NOT be associated with a stream.') - - b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) - b += bytes(self.data) - return b - - def payload_human_readable(self): - s = [] - s.append("last stream: %#x" % self.last_stream) - s.append("error code: %d" % self.error_code) - s.append("debug data: %s" % str(self.data)) - return "\n".join(s) - - -class WindowUpdateFrame(Frame): - TYPE = 0x8 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0x0): - super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) - self.window_size_increment = window_size_increment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.window_size_increment = struct.unpack("!L", payload)[0] - f.window_size_increment &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: - raise ValueError( - 'Window Szie Increment MUST be greater than 0 and less than 2^31.') - - return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) - - def payload_human_readable(self): - return "window size increment: %#x" % self.window_size_increment - - -class ContinuationFrame(Frame): - TYPE = 0x9 - VALID_FLAGS = [Frame.FLAG_END_HEADERS] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b''): - super(ContinuationFrame, self).__init__(state, length, flags, stream_id) - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.header_block_fragment = payload - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'CONTINUATION frames MUST be associated with a stream.') - - return self.header_block_fragment - - def payload_human_readable(self): - s = [] - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - return "\n".join(s) - -_FRAME_CLASSES = [ - DataFrame, - HeadersFrame, - PriorityFrame, - RstStreamFrame, - SettingsFrame, - PushPromiseFrame, - PingFrame, - GoAwayFrame, - WindowUpdateFrame, - ContinuationFrame -] -FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} - - -HTTP2_DEFAULT_SETTINGS = { - SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, - SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, -} diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py deleted file mode 100644 index 8e5f5429..00000000 --- a/netlib/http2/protocol.py +++ /dev/null @@ -1,240 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import itertools - -from hpack.hpack import Encoder, Decoder -from .. import utils -from . import frame - - -class HTTP2Protocol(object): - - ERROR_CODES = utils.BiDi( - NO_ERROR=0x0, - PROTOCOL_ERROR=0x1, - INTERNAL_ERROR=0x2, - FLOW_CONTROL_ERROR=0x3, - SETTINGS_TIMEOUT=0x4, - STREAM_CLOSED=0x5, - FRAME_SIZE_ERROR=0x6, - REFUSED_STREAM=0x7, - CANCEL=0x8, - COMPRESSION_ERROR=0x9, - CONNECT_ERROR=0xa, - ENHANCE_YOUR_CALM=0xb, - INADEQUATE_SECURITY=0xc, - HTTP_1_1_REQUIRED=0xd - ) - - # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - CLIENT_CONNECTION_PREFACE =\ - '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') - - ALPN_PROTO_H2 = 'h2' - - def __init__(self, tcp_handler, is_server=False, dump_frames=False): - self.tcp_handler = tcp_handler - self.is_server = is_server - - self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() - self.current_stream_id = None - self.encoder = Encoder() - self.decoder = Decoder() - self.connection_preface_performed = False - self.dump_frames = dump_frames - - def check_alpn(self): - alp = self.tcp_handler.get_alpn_proto_negotiated() - if alp != self.ALPN_PROTO_H2: - raise NotImplementedError( - "HTTP2Protocol can not handle unknown ALP: %s" % alp) - return True - - def _receive_settings(self, hide=False): - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - break - - def _read_settings_ack(self, hide=False): # pragma no cover - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - assert frm.flags & frame.Frame.FLAG_ACK - assert len(frm.settings) == 0 - break - - def perform_server_connection_preface(self, force=False): - if force or not self.connection_preface_performed: - self.connection_preface_performed = True - - magic_length = len(self.CLIENT_CONNECTION_PREFACE) - magic = self.tcp_handler.rfile.safe_read(magic_length) - assert magic == self.CLIENT_CONNECTION_PREFACE - - self.send_frame(frame.SettingsFrame(state=self), hide=True) - self._receive_settings(hide=True) - - def perform_client_connection_preface(self, force=False): - if force or not self.connection_preface_performed: - self.connection_preface_performed = True - - self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) - - self.send_frame(frame.SettingsFrame(state=self), hide=True) - self._receive_settings(hide=True) - - def next_stream_id(self): - if self.current_stream_id is None: - if self.is_server: - # servers must use even stream ids - self.current_stream_id = 2 - else: - # clients must use odd stream ids - self.current_stream_id = 1 - else: - self.current_stream_id += 2 - return self.current_stream_id - - def send_frame(self, frm, hide=False): - raw_bytes = frm.to_bytes() - self.tcp_handler.wfile.write(raw_bytes) - self.tcp_handler.wfile.flush() - if not hide and self.dump_frames: # pragma no cover - print(frm.human_readable(">>")) - - def read_frame(self, hide=False): - frm = frame.Frame.from_file(self.tcp_handler.rfile, self) - if not hide and self.dump_frames: # pragma no cover - print(frm.human_readable("<<")) - if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: - self._apply_settings(frm.settings, hide) - - return frm - - def _apply_settings(self, settings, hide=False): - for setting, value in settings.items(): - old_value = self.http2_settings[setting] - if not old_value: - old_value = '-' - self.http2_settings[setting] = value - - frm = frame.SettingsFrame( - state=self, - flags=frame.Frame.FLAG_ACK) - self.send_frame(frm, hide) - - # be liberal in what we expect from the other end - # to be more strict use: self._read_settings_ack(hide) - - def _create_headers(self, headers, stream_id, end_stream=True): - # TODO: implement max frame size checks and sending in chunks - - flags = frame.Frame.FLAG_END_HEADERS - if end_stream: - flags |= frame.Frame.FLAG_END_STREAM - - header_block_fragment = self.encoder.encode(headers) - - frm = frame.HeadersFrame( - state=self, - flags=flags, - stream_id=stream_id, - header_block_fragment=header_block_fragment) - - if self.dump_frames: # pragma no cover - print(frm.human_readable(">>")) - - return [frm.to_bytes()] - - def _create_body(self, body, stream_id): - if body is None or len(body) == 0: - return b'' - - # TODO: implement max frame size checks and sending in chunks - # TODO: implement flow-control window - - frm = frame.DataFrame( - state=self, - flags=frame.Frame.FLAG_END_STREAM, - stream_id=stream_id, - payload=body) - - if self.dump_frames: # pragma no cover - print(frm.human_readable(">>")) - - return [frm.to_bytes()] - - - def create_request(self, method, path, headers=None, body=None): - if headers is None: - headers = [] - - authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host - if self.tcp_handler.address.port != 443: - authority += ":%d" % self.tcp_handler.address.port - - headers = [ - (b':method', bytes(method)), - (b':path', bytes(path)), - (b':scheme', b'https'), - (b':authority', authority), - ] + headers - - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id))) - - def read_response(self): - stream_id_, headers, body = self._receive_transmission() - return headers[':status'], headers, body - - def read_request(self): - return self._receive_transmission() - - def _receive_transmission(self): - body_expected = True - - stream_id = 0 - header_block_fragment = b'' - body = b'' - - while True: - frm = self.read_frame() - if isinstance(frm, frame.HeadersFrame)\ - or isinstance(frm, frame.ContinuationFrame): - stream_id = frm.stream_id - header_block_fragment += frm.header_block_fragment - if frm.flags & frame.Frame.FLAG_END_STREAM: - body_expected = False - if frm.flags & frame.Frame.FLAG_END_HEADERS: - break - - while body_expected: - frm = self.read_frame() - if isinstance(frm, frame.DataFrame): - body += frm.payload - if frm.flags & frame.Frame.FLAG_END_STREAM: - break - # TODO: implement window update & flow - - headers = {} - for header, value in self.decoder.decode(header_block_fragment): - headers[header] = value - - return stream_id, headers, body - - def create_response(self, code, stream_id=None, headers=None, body=None): - if headers is None: - headers = [] - - headers = [(b':status', bytes(str(code)))] + headers - - if not stream_id: - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id), - )) diff --git a/netlib/http_auth.py b/netlib/http_auth.py deleted file mode 100644 index adab4aed..00000000 --- a/netlib/http_auth.py +++ /dev/null @@ -1,148 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -from argparse import Action, ArgumentTypeError -from . import http - - -class NullProxyAuth(object): - - """ - No proxy auth at all (returns empty challange headers) - """ - - def __init__(self, password_manager): - self.password_manager = password_manager - - def clean(self, headers_): - """ - Clean up authentication headers, so they're not passed upstream. - """ - pass - - def authenticate(self, headers_): - """ - Tests that the user is allowed to use the proxy - """ - return True - - def auth_challenge_headers(self): - """ - Returns a dictionary containing the headers require to challenge the user - """ - return {} - - -class BasicProxyAuth(NullProxyAuth): - CHALLENGE_HEADER = 'Proxy-Authenticate' - AUTH_HEADER = 'Proxy-Authorization' - - def __init__(self, password_manager, realm): - NullProxyAuth.__init__(self, password_manager) - self.realm = realm - - def clean(self, headers): - del headers[self.AUTH_HEADER] - - def authenticate(self, headers): - auth_value = headers.get(self.AUTH_HEADER, []) - if not auth_value: - return False - parts = http.parse_http_basic_auth(auth_value[0]) - if not parts: - return False - scheme, username, password = parts - if scheme.lower() != 'basic': - return False - if not self.password_manager.test(username, password): - return False - self.username = username - return True - - def auth_challenge_headers(self): - return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm} - - -class PassMan(object): - - def test(self, username_, password_token_): - return False - - -class PassManNonAnon(PassMan): - - """ - Ensure the user specifies a username, accept any password. - """ - - def test(self, username, password_token_): - if username: - return True - return False - - -class PassManHtpasswd(PassMan): - - """ - Read usernames and passwords from an htpasswd file - """ - - def __init__(self, path): - """ - Raises ValueError if htpasswd file is invalid. - """ - import passlib.apache - self.htpasswd = passlib.apache.HtpasswdFile(path) - - def test(self, username, password_token): - return bool(self.htpasswd.check_password(username, password_token)) - - -class PassManSingleUser(PassMan): - - def __init__(self, username, password): - self.username, self.password = username, password - - def test(self, username, password_token): - return self.username == username and self.password == password_token - - -class AuthAction(Action): - - """ - Helper class to allow seamless integration int argparse. Example usage: - parser.add_argument( - "--nonanonymous", - action=NonanonymousAuthAction, nargs=0, - help="Allow access to any user long as a credentials are specified." - ) - """ - - def __call__(self, parser, namespace, values, option_string=None): - passman = self.getPasswordManager(values) - authenticator = BasicProxyAuth(passman, "mitmproxy") - setattr(namespace, self.dest, authenticator) - - def getPasswordManager(self, s): # pragma: nocover - raise NotImplementedError() - - -class SingleuserAuthAction(AuthAction): - - def getPasswordManager(self, s): - if len(s.split(':')) != 2: - raise ArgumentTypeError( - "Invalid single-user specification. Please use the format username:password" - ) - username, password = s.split(':') - return PassManSingleUser(username, password) - - -class NonanonymousAuthAction(AuthAction): - - def getPasswordManager(self, s): - return PassManNonAnon() - - -class HtpasswdAuthAction(AuthAction): - - def getPasswordManager(self, s): - return PassManHtpasswd(s) diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py deleted file mode 100644 index e91ee5c0..00000000 --- a/netlib/http_cookies.py +++ /dev/null @@ -1,193 +0,0 @@ -""" -A flexible module for cookie parsing and manipulation. - -This module differs from usual standards-compliant cookie modules in a number -of ways. We try to be as permissive as possible, and to retain even mal-formed -information. Duplicate cookies are preserved in parsing, and can be set in -formatting. We do attempt to escape and quote values where needed, but will not -reject data that violate the specs. - -Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We do -not parse the comma-separated variant of Set-Cookie that allows multiple -cookies to be set in a single header. Technically this should be feasible, but -it turns out that violations of RFC6265 that makes the parsing problem -indeterminate are much more common than genuine occurences of the multi-cookie -variants. Serialization follows RFC6265. - - http://tools.ietf.org/html/rfc6265 - http://tools.ietf.org/html/rfc2109 - http://tools.ietf.org/html/rfc2965 -""" - -# TODO -# - Disallow LHS-only Cookie values - -import re - -import odict - - -def _read_until(s, start, term): - """ - Read until one of the characters in term is reached. - """ - if start == len(s): - return "", start + 1 - for i in range(start, len(s)): - if s[i] in term: - return s[start:i], i - return s[start:i + 1], i + 1 - - -def _read_token(s, start): - """ - Read a token - the LHS of a token/value pair in a cookie. - """ - return _read_until(s, start, ";=") - - -def _read_quoted_string(s, start): - """ - start: offset to the first quote of the string to be read - - A sort of loose super-set of the various quoted string specifications. - - RFC6265 disallows backslashes or double quotes within quoted strings. - Prior RFCs use backslashes to escape. This leaves us free to apply - backslash escaping by default and be compatible with everything. - """ - escaping = False - ret = [] - # Skip the first quote - for i in range(start + 1, len(s)): - if escaping: - ret.append(s[i]) - escaping = False - elif s[i] == '"': - break - elif s[i] == "\\": - escaping = True - else: - ret.append(s[i]) - return "".join(ret), i + 1 - - -def _read_value(s, start, delims): - """ - Reads a value - the RHS of a token/value pair in a cookie. - - special: If the value is special, commas are premitted. Else comma - terminates. This helps us support old and new style values. - """ - if start >= len(s): - return "", start - elif s[start] == '"': - return _read_quoted_string(s, start) - else: - return _read_until(s, start, delims) - - -def _read_pairs(s, off=0): - """ - Read pairs of lhs=rhs values. - - off: start offset - specials: a lower-cased list of keys that may contain commas - """ - vals = [] - while True: - lhs, off = _read_token(s, off) - lhs = lhs.lstrip() - if lhs: - rhs = None - if off < len(s): - if s[off] == "=": - rhs, off = _read_value(s, off + 1, ";") - vals.append([lhs, rhs]) - off += 1 - if not off < len(s): - break - return vals, off - - -def _has_special(s): - for i in s: - if i in '",;\\': - return True - o = ord(i) - if o < 0x21 or o > 0x7e: - return True - return False - - -ESCAPE = re.compile(r"([\"\\])") - - -def _format_pairs(lst, specials=(), sep="; "): - """ - specials: A lower-cased list of keys that will not be quoted. - """ - vals = [] - for k, v in lst: - if v is None: - vals.append(k) - else: - if k.lower() not in specials and _has_special(v): - v = ESCAPE.sub(r"\\\1", v) - v = '"%s"' % v - vals.append("%s=%s" % (k, v)) - return sep.join(vals) - - -def _format_set_cookie_pairs(lst): - return _format_pairs( - lst, - specials=("expires", "path") - ) - - -def _parse_set_cookie_pairs(s): - """ - For Set-Cookie, we support multiple cookies as described in RFC2109. - This function therefore returns a list of lists. - """ - pairs, off_ = _read_pairs(s) - return pairs - - -def parse_set_cookie_header(line): - """ - Parse a Set-Cookie header value - - Returns a (name, value, attrs) tuple, or None, where attrs is an - ODictCaseless set of attributes. No attempt is made to parse attribute - values - they are treated purely as strings. - """ - pairs = _parse_set_cookie_pairs(line) - if pairs: - return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) - - -def format_set_cookie_header(name, value, attrs): - """ - Formats a Set-Cookie header value. - """ - pairs = [[name, value]] - pairs.extend(attrs.lst) - return _format_set_cookie_pairs(pairs) - - -def parse_cookie_header(line): - """ - Parse a Cookie header value. - Returns a (possibly empty) ODict object. - """ - pairs, off_ = _read_pairs(line) - return odict.ODict(pairs) - - -def format_cookie_header(od): - """ - Formats a Cookie header value. - """ - return _format_pairs(od.lst) diff --git a/netlib/http_semantics.py b/netlib/http_semantics.py deleted file mode 100644 index e8313e3c..00000000 --- a/netlib/http_semantics.py +++ /dev/null @@ -1,23 +0,0 @@ -class Response(object): - - def __init__( - self, - httpversion, - status_code, - msg, - headers, - content, - sslinfo=None, - ): - self.httpversion = httpversion - self.status_code = status_code - self.msg = msg - self.headers = headers - self.content = content - self.sslinfo = sslinfo - - def __eq__(self, other): - return self.__dict__ == other.__dict__ - - def __repr__(self): - return "Response(%s - %s)" % (self.status_code, self.msg) diff --git a/netlib/http_status.py b/netlib/http_status.py deleted file mode 100644 index dc09f465..00000000 --- a/netlib/http_status.py +++ /dev/null @@ -1,104 +0,0 @@ -from __future__ import (absolute_import, print_function, division) - -CONTINUE = 100 -SWITCHING = 101 -OK = 200 -CREATED = 201 -ACCEPTED = 202 -NON_AUTHORITATIVE_INFORMATION = 203 -NO_CONTENT = 204 -RESET_CONTENT = 205 -PARTIAL_CONTENT = 206 -MULTI_STATUS = 207 - -MULTIPLE_CHOICE = 300 -MOVED_PERMANENTLY = 301 -FOUND = 302 -SEE_OTHER = 303 -NOT_MODIFIED = 304 -USE_PROXY = 305 -TEMPORARY_REDIRECT = 307 - -BAD_REQUEST = 400 -UNAUTHORIZED = 401 -PAYMENT_REQUIRED = 402 -FORBIDDEN = 403 -NOT_FOUND = 404 -NOT_ALLOWED = 405 -NOT_ACCEPTABLE = 406 -PROXY_AUTH_REQUIRED = 407 -REQUEST_TIMEOUT = 408 -CONFLICT = 409 -GONE = 410 -LENGTH_REQUIRED = 411 -PRECONDITION_FAILED = 412 -REQUEST_ENTITY_TOO_LARGE = 413 -REQUEST_URI_TOO_LONG = 414 -UNSUPPORTED_MEDIA_TYPE = 415 -REQUESTED_RANGE_NOT_SATISFIABLE = 416 -EXPECTATION_FAILED = 417 - -INTERNAL_SERVER_ERROR = 500 -NOT_IMPLEMENTED = 501 -BAD_GATEWAY = 502 -SERVICE_UNAVAILABLE = 503 -GATEWAY_TIMEOUT = 504 -HTTP_VERSION_NOT_SUPPORTED = 505 -INSUFFICIENT_STORAGE_SPACE = 507 -NOT_EXTENDED = 510 - -RESPONSES = { - # 100 - CONTINUE: "Continue", - SWITCHING: "Switching Protocols", - - # 200 - OK: "OK", - CREATED: "Created", - ACCEPTED: "Accepted", - NON_AUTHORITATIVE_INFORMATION: "Non-Authoritative Information", - NO_CONTENT: "No Content", - RESET_CONTENT: "Reset Content.", - PARTIAL_CONTENT: "Partial Content", - MULTI_STATUS: "Multi-Status", - - # 300 - MULTIPLE_CHOICE: "Multiple Choices", - MOVED_PERMANENTLY: "Moved Permanently", - FOUND: "Found", - SEE_OTHER: "See Other", - NOT_MODIFIED: "Not Modified", - USE_PROXY: "Use Proxy", - # 306 not defined?? - TEMPORARY_REDIRECT: "Temporary Redirect", - - # 400 - BAD_REQUEST: "Bad Request", - UNAUTHORIZED: "Unauthorized", - PAYMENT_REQUIRED: "Payment Required", - FORBIDDEN: "Forbidden", - NOT_FOUND: "Not Found", - NOT_ALLOWED: "Method Not Allowed", - NOT_ACCEPTABLE: "Not Acceptable", - PROXY_AUTH_REQUIRED: "Proxy Authentication Required", - REQUEST_TIMEOUT: "Request Time-out", - CONFLICT: "Conflict", - GONE: "Gone", - LENGTH_REQUIRED: "Length Required", - PRECONDITION_FAILED: "Precondition Failed", - REQUEST_ENTITY_TOO_LARGE: "Request Entity Too Large", - REQUEST_URI_TOO_LONG: "Request-URI Too Long", - UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type", - REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range not satisfiable", - EXPECTATION_FAILED: "Expectation Failed", - - # 500 - INTERNAL_SERVER_ERROR: "Internal Server Error", - NOT_IMPLEMENTED: "Not Implemented", - BAD_GATEWAY: "Bad Gateway", - SERVICE_UNAVAILABLE: "Service Unavailable", - GATEWAY_TIMEOUT: "Gateway Time-out", - HTTP_VERSION_NOT_SUPPORTED: "HTTP Version not supported", - INSUFFICIENT_STORAGE_SPACE: "Insufficient Storage Space", - NOT_EXTENDED: "Not Extended" -} diff --git a/netlib/http_uastrings.py b/netlib/http_uastrings.py deleted file mode 100644 index e8681908..00000000 --- a/netlib/http_uastrings.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import (absolute_import, print_function, division) - -""" - A small collection of useful user-agent header strings. These should be - kept reasonably current to reflect common usage. -""" - -# pylint: line-too-long - -# A collection of (name, shortcut, string) tuples. - -UASTRINGS = [ - ("android", - "a", - "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), # noqa - ("blackberry", - "l", - "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), # noqa - ("bingbot", - "b", - "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), # noqa - ("chrome", - "c", - "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), # noqa - ("firefox", - "f", - "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), # noqa - ("googlebot", - "g", - "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), # noqa - ("ie9", - "i", - "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US)"), # noqa - ("ipad", - "p", - "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9B176 Safari/7534.48.3"), # noqa - ("iphone", - "h", - "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5"), # noqa - ("safari", - "s", - "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10"), # noqa -] - - -def get_by_shortcut(s): - """ - Retrieve a user agent entry by shortcut. - """ - for i in UASTRINGS: - if s == i[1]: - return i diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index d41059fa..49d8ee10 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -6,7 +6,7 @@ import struct import io from .protocol import Masker -from .. import utils, odict, tcp +from netlib import utils, odict, tcp DEFAULT = object() diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index dcab53fb..29b4db3d 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -5,7 +5,7 @@ import os import struct import io -from .. import utils, odict, tcp +from netlib import utils, odict, tcp # Colleciton of utility functions that implement small portions of the RFC6455 # WebSockets Protocol Useful for building WebSocket clients and servers. diff --git a/test/http/__init__.py b/test/http/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/http/http1/__init__.py b/test/http/http1/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py new file mode 100644 index 00000000..05e82831 --- /dev/null +++ b/test/http/http1/test_protocol.py @@ -0,0 +1,445 @@ +import cStringIO +import textwrap +import binascii + +from netlib import http, odict, tcp +from netlib.http.http1 import protocol +from ... import tutils, tservers + + +def test_has_chunked_encoding(): + h = odict.ODictCaseless() + assert not protocol.has_chunked_encoding(h) + h["transfer-encoding"] = ["chunked"] + assert protocol.has_chunked_encoding(h) + + +def test_read_chunked(): + + h = odict.ODictCaseless() + h["transfer-encoding"] = ["chunked"] + s = cStringIO.StringIO("1\r\na\r\n0\r\n") + + tutils.raises( + "malformed chunked body", + protocol.read_http_body, + s, h, None, "GET", None, True + ) + + s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") + assert protocol.read_http_body(s, h, None, "GET", None, True) == "a" + + s = cStringIO.StringIO("\r\n\r\n1\r\na\r\n0\r\n\r\n") + assert protocol.read_http_body(s, h, None, "GET", None, True) == "a" + + s = cStringIO.StringIO("\r\n") + tutils.raises( + "closed prematurely", + protocol.read_http_body, + s, h, None, "GET", None, True + ) + + s = cStringIO.StringIO("1\r\nfoo") + tutils.raises( + "malformed chunked body", + protocol.read_http_body, + s, h, None, "GET", None, True + ) + + s = cStringIO.StringIO("foo\r\nfoo") + tutils.raises( + protocol.HttpError, + protocol.read_http_body, + s, h, None, "GET", None, True + ) + + s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") + tutils.raises("too large", protocol.read_http_body, s, h, 2, "GET", None, True) + + +def test_connection_close(): + h = odict.ODictCaseless() + assert protocol.connection_close((1, 0), h) + assert not protocol.connection_close((1, 1), h) + + h["connection"] = ["keep-alive"] + assert not protocol.connection_close((1, 1), h) + + h["connection"] = ["close"] + assert protocol.connection_close((1, 1), h) + + +def test_get_header_tokens(): + h = odict.ODictCaseless() + assert protocol.get_header_tokens(h, "foo") == [] + h["foo"] = ["bar"] + assert protocol.get_header_tokens(h, "foo") == ["bar"] + h["foo"] = ["bar, voing"] + assert protocol.get_header_tokens(h, "foo") == ["bar", "voing"] + h["foo"] = ["bar, voing", "oink"] + assert protocol.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] + + +def test_read_http_body_request(): + h = odict.ODictCaseless() + r = cStringIO.StringIO("testing") + assert protocol.read_http_body(r, h, None, "GET", None, True) == "" + + +def test_read_http_body_response(): + h = odict.ODictCaseless() + s = tcp.Reader(cStringIO.StringIO("testing")) + assert protocol.read_http_body(s, h, None, "GET", 200, False) == "testing" + + +def test_read_http_body(): + # test default case + h = odict.ODictCaseless() + h["content-length"] = [7] + s = cStringIO.StringIO("testing") + assert protocol.read_http_body(s, h, None, "GET", 200, False) == "testing" + + # test content length: invalid header + h["content-length"] = ["foo"] + s = cStringIO.StringIO("testing") + tutils.raises( + protocol.HttpError, + protocol.read_http_body, + s, h, None, "GET", 200, False + ) + + # test content length: invalid header #2 + h["content-length"] = [-1] + s = cStringIO.StringIO("testing") + tutils.raises( + protocol.HttpError, + protocol.read_http_body, + s, h, None, "GET", 200, False + ) + + # test content length: content length > actual content + h["content-length"] = [5] + s = cStringIO.StringIO("testing") + tutils.raises( + protocol.HttpError, + protocol.read_http_body, + s, h, 4, "GET", 200, False + ) + + # test content length: content length < actual content + s = cStringIO.StringIO("testing") + assert len(protocol.read_http_body(s, h, None, "GET", 200, False)) == 5 + + # test no content length: limit > actual content + h = odict.ODictCaseless() + s = tcp.Reader(cStringIO.StringIO("testing")) + assert len(protocol.read_http_body(s, h, 100, "GET", 200, False)) == 7 + + # test no content length: limit < actual content + s = tcp.Reader(cStringIO.StringIO("testing")) + tutils.raises( + protocol.HttpError, + protocol.read_http_body, + s, h, 4, "GET", 200, False + ) + + # test chunked + h = odict.ODictCaseless() + h["transfer-encoding"] = ["chunked"] + s = tcp.Reader(cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")) + assert protocol.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa" + + +def test_expected_http_body_size(): + # gibber in the content-length field + h = odict.ODictCaseless() + h["content-length"] = ["foo"] + assert protocol.expected_http_body_size(h, False, "GET", 200) is None + # negative number in the content-length field + h = odict.ODictCaseless() + h["content-length"] = ["-7"] + assert protocol.expected_http_body_size(h, False, "GET", 200) is None + # explicit length + h = odict.ODictCaseless() + h["content-length"] = ["5"] + assert protocol.expected_http_body_size(h, False, "GET", 200) == 5 + # no length + h = odict.ODictCaseless() + assert protocol.expected_http_body_size(h, False, "GET", 200) == -1 + # no length request + h = odict.ODictCaseless() + assert protocol.expected_http_body_size(h, True, "GET", None) == 0 + + +def test_parse_http_protocol(): + assert protocol.parse_http_protocol("HTTP/1.1") == (1, 1) + assert protocol.parse_http_protocol("HTTP/0.0") == (0, 0) + assert not protocol.parse_http_protocol("HTTP/a.1") + assert not protocol.parse_http_protocol("HTTP/1.a") + assert not protocol.parse_http_protocol("foo/0.0") + assert not protocol.parse_http_protocol("HTTP/x") + + +def test_parse_init_connect(): + assert protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0") + assert not protocol.parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") + assert not protocol.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") + assert not protocol.parse_init_connect("CONNECT host.com:444444 HTTP/1.0") + assert not protocol.parse_init_connect("bogus") + assert not protocol.parse_init_connect("GET host.com:443 HTTP/1.0") + assert not protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0") + assert not protocol.parse_init_connect("CONNECT host.com:443 foo/1.0") + assert not protocol.parse_init_connect("CONNECT host.com:foo HTTP/1.0") + + +def test_parse_init_proxy(): + u = "GET http://foo.com:8888/test HTTP/1.1" + m, s, h, po, pa, httpversion = protocol.parse_init_proxy(u) + assert m == "GET" + assert s == "http" + assert h == "foo.com" + assert po == 8888 + assert pa == "/test" + assert httpversion == (1, 1) + + u = "G\xfeET http://foo.com:8888/test HTTP/1.1" + assert not protocol.parse_init_proxy(u) + + assert not protocol.parse_init_proxy("invalid") + assert not protocol.parse_init_proxy("GET invalid HTTP/1.1") + assert not protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") + + +def test_parse_init_http(): + u = "GET /test HTTP/1.1" + m, u, httpversion = protocol.parse_init_http(u) + assert m == "GET" + assert u == "/test" + assert httpversion == (1, 1) + + u = "G\xfeET /test HTTP/1.1" + assert not protocol.parse_init_http(u) + + assert not protocol.parse_init_http("invalid") + assert not protocol.parse_init_http("GET invalid HTTP/1.1") + assert not protocol.parse_init_http("GET /test foo/1.1") + assert not protocol.parse_init_http("GET /test\xc0 HTTP/1.1") + + +class TestReadHeaders: + + def _read(self, data, verbatim=False): + if not verbatim: + data = textwrap.dedent(data) + data = data.strip() + s = cStringIO.StringIO(data) + return protocol.read_headers(s) + + def test_read_simple(self): + data = """ + Header: one + Header2: two + \r\n + """ + h = self._read(data) + assert h.lst == [["Header", "one"], ["Header2", "two"]] + + def test_read_multi(self): + data = """ + Header: one + Header: two + \r\n + """ + h = self._read(data) + assert h.lst == [["Header", "one"], ["Header", "two"]] + + def test_read_continued(self): + data = """ + Header: one + \ttwo + Header2: three + \r\n + """ + h = self._read(data) + assert h.lst == [["Header", "one\r\n two"], ["Header2", "three"]] + + def test_read_continued_err(self): + data = "\tfoo: bar\r\n" + assert self._read(data, True) is None + + def test_read_err(self): + data = """ + foo + """ + assert self._read(data) is None + + +class NoContentLengthHTTPHandler(tcp.BaseHandler): + + def handle(self): + self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n") + self.wfile.flush() + + +class TestReadResponseNoContentLength(tservers.ServerTestBase): + handler = NoContentLengthHTTPHandler + + def test_no_content_length(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + resp = protocol.read_response(c.rfile, "GET", None) + assert resp.content == "bar\r\n\r\n" + + +def test_read_response(): + def tst(data, method, limit, include_body=True): + data = textwrap.dedent(data) + r = cStringIO.StringIO(data) + return protocol.read_response( + r, method, limit, include_body=include_body + ) + + tutils.raises("server disconnect", tst, "", "GET", None) + tutils.raises("invalid server response", tst, "foo", "GET", None) + data = """ + HTTP/1.1 200 OK + """ + assert tst(data, "GET", None) == http.Response( + (1, 1), 200, 'OK', odict.ODictCaseless(), '' + ) + data = """ + HTTP/1.1 200 + """ + assert tst(data, "GET", None) == http.Response( + (1, 1), 200, '', odict.ODictCaseless(), '' + ) + data = """ + HTTP/x 200 OK + """ + tutils.raises("invalid http version", tst, data, "GET", None) + data = """ + HTTP/1.1 xx OK + """ + tutils.raises("invalid server response", tst, data, "GET", None) + + data = """ + HTTP/1.1 100 CONTINUE + + HTTP/1.1 200 OK + """ + assert tst(data, "GET", None) == http.Response( + (1, 1), 100, 'CONTINUE', odict.ODictCaseless(), '' + ) + + data = """ + HTTP/1.1 200 OK + Content-Length: 3 + + foo + """ + assert tst(data, "GET", None).content == 'foo' + assert tst(data, "HEAD", None).content == '' + + data = """ + HTTP/1.1 200 OK + \tContent-Length: 3 + + foo + """ + tutils.raises("invalid headers", tst, data, "GET", None) + + data = """ + HTTP/1.1 200 OK + Content-Length: 3 + + foo + """ + assert tst(data, "GET", None, include_body=False).content is None + + +def test_parse_http_basic_auth(): + vals = ("basic", "foo", "bar") + assert protocol.parse_http_basic_auth( + protocol.assemble_http_basic_auth(*vals) + ) == vals + assert not protocol.parse_http_basic_auth("") + assert not protocol.parse_http_basic_auth("foo bar") + v = "basic " + binascii.b2a_base64("foo") + assert not protocol.parse_http_basic_auth(v) + + +def test_get_request_line(): + r = cStringIO.StringIO("\nfoo") + assert protocol.get_request_line(r) == "foo" + assert not protocol.get_request_line(r) + + +class TestReadRequest(): + + def tst(self, data, **kwargs): + r = cStringIO.StringIO(data) + return protocol.read_request(r, **kwargs) + + def test_invalid(self): + tutils.raises( + "bad http request", + self.tst, + "xxx" + ) + tutils.raises( + "bad http request line", + self.tst, + "get /\xff HTTP/1.1" + ) + tutils.raises( + "invalid headers", + self.tst, + "get / HTTP/1.1\r\nfoo" + ) + tutils.raises( + tcp.NetLibDisconnect, + self.tst, + "\r\n" + ) + + def test_asterisk_form_in(self): + v = self.tst("OPTIONS * HTTP/1.1") + assert v.form_in == "relative" + assert v.method == "OPTIONS" + + def test_absolute_form_in(self): + tutils.raises( + "Bad HTTP request line", + self.tst, + "GET oops-no-protocol.com HTTP/1.1" + ) + v = self.tst("GET http://address:22/ HTTP/1.1") + assert v.form_in == "absolute" + assert v.port == 22 + assert v.host == "address" + assert v.scheme == "http" + + def test_connect(self): + tutils.raises( + "Bad HTTP request line", + self.tst, + "CONNECT oops-no-port.com HTTP/1.1" + ) + v = self.tst("CONNECT foo.com:443 HTTP/1.1") + assert v.form_in == "authority" + assert v.method == "CONNECT" + assert v.port == 443 + assert v.host == "foo.com" + + def test_expect(self): + w = cStringIO.StringIO() + r = cStringIO.StringIO( + "GET / HTTP/1.1\r\n" + "Content-Length: 3\r\n" + "Expect: 100-continue\r\n\r\n" + "foobar", + ) + v = protocol.read_request(r, wfile=w) + assert w.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" + assert v.content == "foo" + assert r.read(3) == "bar" diff --git a/test/http/http2/__init__.py b/test/http/http2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/http/http2/test_frames.py b/test/http/http2/test_frames.py new file mode 100644 index 00000000..ee2edc39 --- /dev/null +++ b/test/http/http2/test_frames.py @@ -0,0 +1,704 @@ +import cStringIO +from test import tutils +from nose.tools import assert_equal +from netlib import tcp +from netlib.http.http2.frame import * + + +def hex_to_file(data): + data = data.decode('hex') + return tcp.Reader(cStringIO.StringIO(data)) + + +def test_invalid_flags(): + tutils.raises( + ValueError, + DataFrame, + flags=ContinuationFrame.FLAG_END_HEADERS, + stream_id=0x1234567, + payload='foobar') + + +def test_frame_equality(): + a = DataFrame( + length=6, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567, + payload='foobar') + b = DataFrame( + length=6, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567, + payload='foobar') + assert_equal(a, b) + + +def test_too_large_frames(): + f = DataFrame( + length=9000, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567, + payload='foobar' * 3000) + tutils.raises(FrameSizeError, f.to_bytes) + + +def test_data_frame_to_bytes(): + f = DataFrame( + length=6, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567, + payload='foobar') + assert_equal(f.to_bytes().encode('hex'), '000006000101234567666f6f626172') + + f = DataFrame( + length=11, + flags=(Frame.FLAG_END_STREAM | Frame.FLAG_PADDED), + stream_id=0x1234567, + payload='foobar', + pad_length=3) + assert_equal( + f.to_bytes().encode('hex'), + '00000a00090123456703666f6f626172000000') + + f = DataFrame( + length=6, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload='foobar') + tutils.raises(ValueError, f.to_bytes) + + +def test_data_frame_from_bytes(): + f = Frame.from_file(hex_to_file('000006000101234567666f6f626172')) + assert isinstance(f, DataFrame) + assert_equal(f.length, 6) + assert_equal(f.TYPE, DataFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_END_STREAM) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.payload, 'foobar') + + f = Frame.from_file(hex_to_file('00000a00090123456703666f6f626172000000')) + assert isinstance(f, DataFrame) + assert_equal(f.length, 10) + assert_equal(f.TYPE, DataFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_END_STREAM | Frame.FLAG_PADDED) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.payload, 'foobar') + + +def test_data_frame_human_readable(): + f = DataFrame( + length=11, + flags=(Frame.FLAG_END_STREAM | Frame.FLAG_PADDED), + stream_id=0x1234567, + payload='foobar', + pad_length=3) + assert f.human_readable() + + +def test_headers_frame_to_bytes(): + f = HeadersFrame( + length=6, + flags=(Frame.FLAG_NO_FLAGS), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex')) + assert_equal(f.to_bytes().encode('hex'), '000007010001234567668594e75e31d9') + + f = HeadersFrame( + length=10, + flags=(HeadersFrame.FLAG_PADDED), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex'), + pad_length=3) + assert_equal( + f.to_bytes().encode('hex'), + '00000b01080123456703668594e75e31d9000000') + + f = HeadersFrame( + length=10, + flags=(HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex'), + exclusive=True, + stream_dependency=0x7654321, + weight=42) + assert_equal( + f.to_bytes().encode('hex'), + '00000c012001234567876543212a668594e75e31d9') + + f = HeadersFrame( + length=14, + flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex'), + pad_length=3, + exclusive=True, + stream_dependency=0x7654321, + weight=42) + assert_equal( + f.to_bytes().encode('hex'), + '00001001280123456703876543212a668594e75e31d9000000') + + f = HeadersFrame( + length=14, + flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex'), + pad_length=3, + exclusive=False, + stream_dependency=0x7654321, + weight=42) + assert_equal( + f.to_bytes().encode('hex'), + '00001001280123456703076543212a668594e75e31d9000000') + + f = HeadersFrame( + length=6, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment='668594e75e31d9'.decode('hex')) + tutils.raises(ValueError, f.to_bytes) + + +def test_headers_frame_from_bytes(): + f = Frame.from_file(hex_to_file( + '000007010001234567668594e75e31d9')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 7) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + + f = Frame.from_file(hex_to_file( + '00000b01080123456703668594e75e31d9000000')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 11) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, HeadersFrame.FLAG_PADDED) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + + f = Frame.from_file(hex_to_file( + '00000c012001234567876543212a668594e75e31d9')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 12) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, HeadersFrame.FLAG_PRIORITY) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + assert_equal(f.exclusive, True) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 42) + + f = Frame.from_file(hex_to_file( + '00001001280123456703876543212a668594e75e31d9000000')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 16) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + assert_equal(f.exclusive, True) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 42) + + f = Frame.from_file(hex_to_file( + '00001001280123456703076543212a668594e75e31d9000000')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 16) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + assert_equal(f.exclusive, False) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 42) + + +def test_headers_frame_human_readable(): + f = HeadersFrame( + length=7, + flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, + header_block_fragment=b'', + pad_length=3, + exclusive=False, + stream_dependency=0x7654321, + weight=42) + assert f.human_readable() + + f = HeadersFrame( + length=14, + flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex'), + pad_length=3, + exclusive=False, + stream_dependency=0x7654321, + weight=42) + assert f.human_readable() + + +def test_priority_frame_to_bytes(): + f = PriorityFrame( + length=5, + flags=(Frame.FLAG_NO_FLAGS), + stream_id=0x1234567, + exclusive=True, + stream_dependency=0x7654321, + weight=42) + assert_equal(f.to_bytes().encode('hex'), '000005020001234567876543212a') + + f = PriorityFrame( + length=5, + flags=(Frame.FLAG_NO_FLAGS), + stream_id=0x1234567, + exclusive=False, + stream_dependency=0x7654321, + weight=21) + assert_equal(f.to_bytes().encode('hex'), '0000050200012345670765432115') + + f = PriorityFrame( + length=5, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + stream_dependency=0x1234567) + tutils.raises(ValueError, f.to_bytes) + + f = PriorityFrame( + length=5, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + stream_dependency=0x0) + tutils.raises(ValueError, f.to_bytes) + + +def test_priority_frame_from_bytes(): + f = Frame.from_file(hex_to_file('000005020001234567876543212a')) + assert isinstance(f, PriorityFrame) + assert_equal(f.length, 5) + assert_equal(f.TYPE, PriorityFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.exclusive, True) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 42) + + f = Frame.from_file(hex_to_file('0000050200012345670765432115')) + assert isinstance(f, PriorityFrame) + assert_equal(f.length, 5) + assert_equal(f.TYPE, PriorityFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.exclusive, False) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 21) + + +def test_priority_frame_human_readable(): + f = PriorityFrame( + length=5, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + exclusive=False, + stream_dependency=0x7654321, + weight=21) + assert f.human_readable() + + +def test_rst_stream_frame_to_bytes(): + f = RstStreamFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + error_code=0x7654321) + assert_equal(f.to_bytes().encode('hex'), '00000403000123456707654321') + + f = RstStreamFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0) + tutils.raises(ValueError, f.to_bytes) + + +def test_rst_stream_frame_from_bytes(): + f = Frame.from_file(hex_to_file('00000403000123456707654321')) + assert isinstance(f, RstStreamFrame) + assert_equal(f.length, 4) + assert_equal(f.TYPE, RstStreamFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.error_code, 0x07654321) + + +def test_rst_stream_frame_human_readable(): + f = RstStreamFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + error_code=0x7654321) + assert f.human_readable() + + +def test_settings_frame_to_bytes(): + f = SettingsFrame( + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0) + assert_equal(f.to_bytes().encode('hex'), '000000040000000000') + + f = SettingsFrame( + length=0, + flags=SettingsFrame.FLAG_ACK, + stream_id=0x0) + assert_equal(f.to_bytes().encode('hex'), '000000040100000000') + + f = SettingsFrame( + length=6, + flags=SettingsFrame.FLAG_ACK, + stream_id=0x0, + settings={ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1}) + assert_equal(f.to_bytes().encode('hex'), '000006040100000000000200000001') + + f = SettingsFrame( + length=12, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings={ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) + assert_equal( + f.to_bytes().encode('hex'), + '00000c040000000000000200000001000312345678') + + f = SettingsFrame( + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567) + tutils.raises(ValueError, f.to_bytes) + + +def test_settings_frame_from_bytes(): + f = Frame.from_file(hex_to_file('000000040000000000')) + assert isinstance(f, SettingsFrame) + assert_equal(f.length, 0) + assert_equal(f.TYPE, SettingsFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + + f = Frame.from_file(hex_to_file('000000040100000000')) + assert isinstance(f, SettingsFrame) + assert_equal(f.length, 0) + assert_equal(f.TYPE, SettingsFrame.TYPE) + assert_equal(f.flags, SettingsFrame.FLAG_ACK) + assert_equal(f.stream_id, 0x0) + + f = Frame.from_file(hex_to_file('000006040100000000000200000001')) + assert isinstance(f, SettingsFrame) + assert_equal(f.length, 6) + assert_equal(f.TYPE, SettingsFrame.TYPE) + assert_equal(f.flags, SettingsFrame.FLAG_ACK, 0x0) + assert_equal(f.stream_id, 0x0) + assert_equal(len(f.settings), 1) + assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) + + f = Frame.from_file(hex_to_file( + '00000c040000000000000200000001000312345678')) + assert isinstance(f, SettingsFrame) + assert_equal(f.length, 12) + assert_equal(f.TYPE, SettingsFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(len(f.settings), 2) + assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) + assert_equal( + f.settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS], + 0x12345678) + + +def test_settings_frame_human_readable(): + f = SettingsFrame( + length=12, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings={}) + assert f.human_readable() + + f = SettingsFrame( + length=12, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings={ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) + assert f.human_readable() + + +def test_push_promise_frame_to_bytes(): + f = PushPromiseFrame( + length=10, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + promised_stream=0x7654321, + header_block_fragment='foobar') + assert_equal( + f.to_bytes().encode('hex'), + '00000a05000123456707654321666f6f626172') + + f = PushPromiseFrame( + length=14, + flags=HeadersFrame.FLAG_PADDED, + stream_id=0x1234567, + promised_stream=0x7654321, + header_block_fragment='foobar', + pad_length=3) + assert_equal( + f.to_bytes().encode('hex'), + '00000e0508012345670307654321666f6f626172000000') + + f = PushPromiseFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + promised_stream=0x1234567) + tutils.raises(ValueError, f.to_bytes) + + f = PushPromiseFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + promised_stream=0x0) + tutils.raises(ValueError, f.to_bytes) + + +def test_push_promise_frame_from_bytes(): + f = Frame.from_file(hex_to_file('00000a05000123456707654321666f6f626172')) + assert isinstance(f, PushPromiseFrame) + assert_equal(f.length, 10) + assert_equal(f.TYPE, PushPromiseFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, 'foobar') + + f = Frame.from_file(hex_to_file( + '00000e0508012345670307654321666f6f626172000000')) + assert isinstance(f, PushPromiseFrame) + assert_equal(f.length, 14) + assert_equal(f.TYPE, PushPromiseFrame.TYPE) + assert_equal(f.flags, PushPromiseFrame.FLAG_PADDED) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, 'foobar') + + +def test_push_promise_frame_human_readable(): + f = PushPromiseFrame( + length=14, + flags=HeadersFrame.FLAG_PADDED, + stream_id=0x1234567, + promised_stream=0x7654321, + header_block_fragment='foobar', + pad_length=3) + assert f.human_readable() + + +def test_ping_frame_to_bytes(): + f = PingFrame( + length=8, + flags=PingFrame.FLAG_ACK, + stream_id=0x0, + payload=b'foobar') + assert_equal( + f.to_bytes().encode('hex'), + '000008060100000000666f6f6261720000') + + f = PingFrame( + length=8, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b'foobardeadbeef') + assert_equal( + f.to_bytes().encode('hex'), + '000008060000000000666f6f6261726465') + + f = PingFrame( + length=8, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567) + tutils.raises(ValueError, f.to_bytes) + + +def test_ping_frame_from_bytes(): + f = Frame.from_file(hex_to_file('000008060100000000666f6f6261720000')) + assert isinstance(f, PingFrame) + assert_equal(f.length, 8) + assert_equal(f.TYPE, PingFrame.TYPE) + assert_equal(f.flags, PingFrame.FLAG_ACK) + assert_equal(f.stream_id, 0x0) + assert_equal(f.payload, b'foobar\0\0') + + f = Frame.from_file(hex_to_file('000008060000000000666f6f6261726465')) + assert isinstance(f, PingFrame) + assert_equal(f.length, 8) + assert_equal(f.TYPE, PingFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(f.payload, b'foobarde') + + +def test_ping_frame_human_readable(): + f = PingFrame( + length=8, + flags=PingFrame.FLAG_ACK, + stream_id=0x0, + payload=b'foobar') + assert f.human_readable() + + +def test_goaway_frame_to_bytes(): + f = GoAwayFrame( + length=8, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x1234567, + error_code=0x87654321, + data=b'') + assert_equal( + f.to_bytes().encode('hex'), + '0000080700000000000123456787654321') + + f = GoAwayFrame( + length=14, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x1234567, + error_code=0x87654321, + data=b'foobar') + assert_equal( + f.to_bytes().encode('hex'), + '00000e0700000000000123456787654321666f6f626172') + + f = GoAwayFrame( + length=8, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + last_stream=0x1234567, + error_code=0x87654321) + tutils.raises(ValueError, f.to_bytes) + + +def test_goaway_frame_from_bytes(): + f = Frame.from_file(hex_to_file( + '0000080700000000000123456787654321')) + assert isinstance(f, GoAwayFrame) + assert_equal(f.length, 8) + assert_equal(f.TYPE, GoAwayFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(f.last_stream, 0x1234567) + assert_equal(f.error_code, 0x87654321) + assert_equal(f.data, b'') + + f = Frame.from_file(hex_to_file( + '00000e0700000000000123456787654321666f6f626172')) + assert isinstance(f, GoAwayFrame) + assert_equal(f.length, 14) + assert_equal(f.TYPE, GoAwayFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(f.last_stream, 0x1234567) + assert_equal(f.error_code, 0x87654321) + assert_equal(f.data, b'foobar') + + +def test_go_away_frame_human_readable(): + f = GoAwayFrame( + length=14, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x1234567, + error_code=0x87654321, + data=b'foobar') + assert f.human_readable() + + +def test_window_update_frame_to_bytes(): + f = WindowUpdateFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + window_size_increment=0x1234567) + assert_equal(f.to_bytes().encode('hex'), '00000408000000000001234567') + + f = WindowUpdateFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + window_size_increment=0x7654321) + assert_equal(f.to_bytes().encode('hex'), '00000408000123456707654321') + + f = WindowUpdateFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + window_size_increment=0xdeadbeef) + tutils.raises(ValueError, f.to_bytes) + + f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0) + tutils.raises(ValueError, f.to_bytes) + + +def test_window_update_frame_from_bytes(): + f = Frame.from_file(hex_to_file('00000408000000000001234567')) + assert isinstance(f, WindowUpdateFrame) + assert_equal(f.length, 4) + assert_equal(f.TYPE, WindowUpdateFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(f.window_size_increment, 0x1234567) + + +def test_window_update_frame_human_readable(): + f = WindowUpdateFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + window_size_increment=0x7654321) + assert f.human_readable() + + +def test_continuation_frame_to_bytes(): + f = ContinuationFrame( + length=6, + flags=ContinuationFrame.FLAG_END_HEADERS, + stream_id=0x1234567, + header_block_fragment='foobar') + assert_equal(f.to_bytes().encode('hex'), '000006090401234567666f6f626172') + + f = ContinuationFrame( + length=6, + flags=ContinuationFrame.FLAG_END_HEADERS, + stream_id=0x0, + header_block_fragment='foobar') + tutils.raises(ValueError, f.to_bytes) + + +def test_continuation_frame_from_bytes(): + f = Frame.from_file(hex_to_file('000006090401234567666f6f626172')) + assert isinstance(f, ContinuationFrame) + assert_equal(f.length, 6) + assert_equal(f.TYPE, ContinuationFrame.TYPE) + assert_equal(f.flags, ContinuationFrame.FLAG_END_HEADERS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, 'foobar') + + +def test_continuation_frame_human_readable(): + f = ContinuationFrame( + length=6, + flags=ContinuationFrame.FLAG_END_HEADERS, + stream_id=0x1234567, + header_block_fragment='foobar') + assert f.human_readable() diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py new file mode 100644 index 00000000..f607860e --- /dev/null +++ b/test/http/http2/test_protocol.py @@ -0,0 +1,325 @@ +import OpenSSL + +from netlib import tcp +from netlib.http import http2 +from netlib.http.http2.frame import * +from ... import tutils, tservers + + +class EchoHandler(tcp.BaseHandler): + sni = None + + def handle(self): + while True: + v = self.rfile.safe_read(1) + self.wfile.write(v) + self.wfile.flush() + + +class TestCheckALPNMatch(tservers.ServerTestBase): + handler = EchoHandler + ssl = dict( + alpn_select=http2.HTTP2Protocol.ALPN_PROTO_H2, + ) + + if OpenSSL._util.lib.Cryptography_HAS_ALPN: + + def test_check_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) + protocol = http2.HTTP2Protocol(c) + assert protocol.check_alpn() + + +class TestCheckALPNMismatch(tservers.ServerTestBase): + handler = EchoHandler + ssl = dict( + alpn_select=None, + ) + + if OpenSSL._util.lib.Cryptography_HAS_ALPN: + + def test_check_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) + protocol = http2.HTTP2Protocol(c) + tutils.raises(NotImplementedError, protocol.check_alpn) + + +class TestPerformServerConnectionPreface(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + # send magic + self.wfile.write( + '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex')) + self.wfile.flush() + + # send empty settings frame + self.wfile.write('000000040000000000'.decode('hex')) + self.wfile.flush() + + # check empty settings frame + assert self.rfile.read(9) ==\ + '000000040000000000'.decode('hex') + + # check settings acknowledgement + assert self.rfile.read(9) == \ + '000000040100000000'.decode('hex') + + # send settings acknowledgement + self.wfile.write('000000040100000000'.decode('hex')) + self.wfile.flush() + + def test_perform_server_connection_preface(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + protocol = http2.HTTP2Protocol(c) + protocol.perform_server_connection_preface() + + +class TestPerformClientConnectionPreface(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + # check magic + assert self.rfile.read(24) ==\ + '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') + + # check empty settings frame + assert self.rfile.read(9) ==\ + '000000040000000000'.decode('hex') + + # send empty settings frame + self.wfile.write('000000040000000000'.decode('hex')) + self.wfile.flush() + + # check settings acknowledgement + assert self.rfile.read(9) == \ + '000000040100000000'.decode('hex') + + # send settings acknowledgement + self.wfile.write('000000040100000000'.decode('hex')) + self.wfile.flush() + + def test_perform_client_connection_preface(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + protocol = http2.HTTP2Protocol(c) + protocol.perform_client_connection_preface() + + +class TestClientStreamIds(): + c = tcp.TCPClient(("127.0.0.1", 0)) + protocol = http2.HTTP2Protocol(c) + + def test_client_stream_ids(self): + assert self.protocol.current_stream_id is None + assert self.protocol.next_stream_id() == 1 + assert self.protocol.current_stream_id == 1 + assert self.protocol.next_stream_id() == 3 + assert self.protocol.current_stream_id == 3 + assert self.protocol.next_stream_id() == 5 + assert self.protocol.current_stream_id == 5 + + +class TestServerStreamIds(): + c = tcp.TCPClient(("127.0.0.1", 0)) + protocol = http2.HTTP2Protocol(c, is_server=True) + + def test_server_stream_ids(self): + assert self.protocol.current_stream_id is None + assert self.protocol.next_stream_id() == 2 + assert self.protocol.current_stream_id == 2 + assert self.protocol.next_stream_id() == 4 + assert self.protocol.current_stream_id == 4 + assert self.protocol.next_stream_id() == 6 + assert self.protocol.current_stream_id == 6 + + +class TestApplySettings(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + # check settings acknowledgement + assert self.rfile.read(9) == '000000040100000000'.decode('hex') + self.wfile.write("OK") + self.wfile.flush() + + ssl = True + + def test_apply_settings(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c) + + protocol._apply_settings({ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 'foo', + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 'bar', + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 'deadbeef', + }) + + assert c.rfile.safe_read(2) == "OK" + + assert protocol.http2_settings[ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH] == 'foo' + assert protocol.http2_settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS] == 'bar' + assert protocol.http2_settings[ + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE] == 'deadbeef' + + +class TestCreateHeaders(): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_create_headers(self): + headers = [ + (b':method', b'GET'), + (b':path', b'index.html'), + (b':scheme', b'https'), + (b'foo', b'bar')] + + bytes = http2.HTTP2Protocol(self.c)._create_headers( + headers, 1, end_stream=True) + assert b''.join(bytes) ==\ + '000014010500000001824488355217caf3a69a3f87408294e7838c767f'\ + .decode('hex') + + bytes = http2.HTTP2Protocol(self.c)._create_headers( + headers, 1, end_stream=False) + assert b''.join(bytes) ==\ + '000014010400000001824488355217caf3a69a3f87408294e7838c767f'\ + .decode('hex') + + # TODO: add test for too large header_block_fragments + + +class TestCreateBody(): + c = tcp.TCPClient(("127.0.0.1", 0)) + protocol = http2.HTTP2Protocol(c) + + def test_create_body_empty(self): + bytes = self.protocol._create_body(b'', 1) + assert b''.join(bytes) == ''.decode('hex') + + def test_create_body_single_frame(self): + bytes = self.protocol._create_body('foobar', 1) + assert b''.join(bytes) == '000006000100000001666f6f626172'.decode('hex') + + def test_create_body_multiple_frames(self): + pass + # bytes = self.protocol._create_body('foobar' * 3000, 1) + # TODO: add test for too large frames + + +class TestCreateRequest(): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_create_request_simple(self): + bytes = http2.HTTP2Protocol(self.c).create_request('GET', '/') + assert len(bytes) == 1 + assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') + + def test_create_request_with_body(self): + bytes = http2.HTTP2Protocol(self.c).create_request( + 'GET', '/', [(b'foo', b'bar')], 'foobar') + assert len(bytes) == 2 + assert bytes[0] ==\ + '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') + assert bytes[1] ==\ + '000006000100000001666f6f626172'.decode('hex') + + +class TestReadResponse(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + self.wfile.write( + b'00000801040000000188628594e78c767f'.decode('hex')) + self.wfile.write( + b'000006000100000001666f6f626172'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_read_response(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c) + + status, headers, body = protocol.read_response() + + assert headers == {':status': '200', 'etag': 'foobar'} + assert status == "200" + assert body == b'foobar' + + +class TestReadEmptyResponse(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + self.wfile.write( + b'00000801050000000188628594e78c767f'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_read_empty_response(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c) + + status, headers, body = protocol.read_response() + + assert headers == {':status': '200', 'etag': 'foobar'} + assert status == "200" + assert body == b'' + + +class TestReadRequest(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + self.wfile.write( + b'000003010400000001828487'.decode('hex')) + self.wfile.write( + b'000006000100000001666f6f626172'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_read_request(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c, is_server=True) + + stream_id, headers, body = protocol.read_request() + + assert stream_id + assert headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} + assert body == b'foobar' + + +class TestCreateResponse(): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_create_response_simple(self): + bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200) + assert len(bytes) == 1 + assert bytes[0] ==\ + '00000101050000000288'.decode('hex') + + def test_create_response_with_body(self): + bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response( + 200, 1, [(b'foo', b'bar')], 'foobar') + assert len(bytes) == 2 + assert bytes[0] ==\ + '00000901040000000188408294e7838c767f'.decode('hex') + assert bytes[1] ==\ + '000006000100000001666f6f626172'.decode('hex') diff --git a/test/http/test_authentication.py b/test/http/test_authentication.py new file mode 100644 index 00000000..c0dae1a2 --- /dev/null +++ b/test/http/test_authentication.py @@ -0,0 +1,110 @@ +from netlib import odict, http +from netlib.http import authentication +from .. import tutils + + +class TestPassManNonAnon: + + def test_simple(self): + p = authentication.PassManNonAnon() + assert not p.test("", "") + assert p.test("user", "") + + +class TestPassManHtpasswd: + + def test_file_errors(self): + tutils.raises( + "malformed htpasswd file", + authentication.PassManHtpasswd, + tutils.test_data.path("data/server.crt")) + + def test_simple(self): + pm = authentication.PassManHtpasswd(tutils.test_data.path("data/htpasswd")) + + vals = ("basic", "test", "test") + http.http1.assemble_http_basic_auth(*vals) + assert pm.test("test", "test") + assert not pm.test("test", "foo") + assert not pm.test("foo", "test") + assert not pm.test("test", "") + assert not pm.test("", "") + + +class TestPassManSingleUser: + + def test_simple(self): + pm = authentication.PassManSingleUser("test", "test") + assert pm.test("test", "test") + assert not pm.test("test", "foo") + assert not pm.test("foo", "test") + + +class TestNullProxyAuth: + + def test_simple(self): + na = authentication.NullProxyAuth(authentication.PassManNonAnon()) + assert not na.auth_challenge_headers() + assert na.authenticate("foo") + na.clean({}) + + +class TestBasicProxyAuth: + + def test_simple(self): + ba = authentication.BasicProxyAuth(authentication.PassManNonAnon(), "test") + h = odict.ODictCaseless() + assert ba.auth_challenge_headers() + assert not ba.authenticate(h) + + def test_authenticate_clean(self): + ba = authentication.BasicProxyAuth(authentication.PassManNonAnon(), "test") + + hdrs = odict.ODictCaseless() + vals = ("basic", "foo", "bar") + hdrs[ba.AUTH_HEADER] = [http.http1.assemble_http_basic_auth(*vals)] + assert ba.authenticate(hdrs) + + ba.clean(hdrs) + assert not ba.AUTH_HEADER in hdrs + + hdrs[ba.AUTH_HEADER] = [""] + assert not ba.authenticate(hdrs) + + hdrs[ba.AUTH_HEADER] = ["foo"] + assert not ba.authenticate(hdrs) + + vals = ("foo", "foo", "bar") + hdrs[ba.AUTH_HEADER] = [http.http1.assemble_http_basic_auth(*vals)] + assert not ba.authenticate(hdrs) + + ba = authentication.BasicProxyAuth(authentication.PassMan(), "test") + vals = ("basic", "foo", "bar") + hdrs[ba.AUTH_HEADER] = [http.http1.assemble_http_basic_auth(*vals)] + assert not ba.authenticate(hdrs) + + +class Bunch: + pass + + +class TestAuthAction: + + def test_nonanonymous(self): + m = Bunch() + aa = authentication.NonanonymousAuthAction(None, "authenticator") + aa(None, m, None, None) + assert m.authenticator + + def test_singleuser(self): + m = Bunch() + aa = authentication.SingleuserAuthAction(None, "authenticator") + aa(None, m, "foo:bar", None) + assert m.authenticator + tutils.raises("invalid", aa, None, m, "foo", None) + + def test_httppasswd(self): + m = Bunch() + aa = authentication.HtpasswdAuthAction(None, "authenticator") + aa(None, m, tutils.test_data.path("data/htpasswd"), None) + assert m.authenticator diff --git a/test/http/test_cookies.py b/test/http/test_cookies.py new file mode 100644 index 00000000..4f99593a --- /dev/null +++ b/test/http/test_cookies.py @@ -0,0 +1,219 @@ +import nose.tools + +from netlib.http import cookies + + +def test_read_token(): + tokens = [ + [("foo", 0), ("foo", 3)], + [("foo", 1), ("oo", 3)], + [(" foo", 1), ("foo", 4)], + [(" foo;", 1), ("foo", 4)], + [(" foo=", 1), ("foo", 4)], + [(" foo=bar", 1), ("foo", 4)], + ] + for q, a in tokens: + nose.tools.eq_(cookies._read_token(*q), a) + + +def test_read_quoted_string(): + tokens = [ + [('"foo" x', 0), ("foo", 5)], + [('"f\oo" x', 0), ("foo", 6)], + [(r'"f\\o" x', 0), (r"f\o", 6)], + [(r'"f\\" x', 0), (r"f" + '\\', 5)], + [('"fo\\\"" x', 0), ("fo\"", 6)], + ] + for q, a in tokens: + nose.tools.eq_(cookies._read_quoted_string(*q), a) + + +def test_read_pairs(): + vals = [ + [ + "one", + [["one", None]] + ], + [ + "one=two", + [["one", "two"]] + ], + [ + "one=", + [["one", ""]] + ], + [ + 'one="two"', + [["one", "two"]] + ], + [ + 'one="two"; three=four', + [["one", "two"], ["three", "four"]] + ], + [ + 'one="two"; three=four; five', + [["one", "two"], ["three", "four"], ["five", None]] + ], + [ + 'one="\\"two"; three=four', + [["one", '"two'], ["three", "four"]] + ], + ] + for s, lst in vals: + ret, off = cookies._read_pairs(s) + nose.tools.eq_(ret, lst) + + +def test_pairs_roundtrips(): + pairs = [ + [ + "", + [] + ], + [ + "one=uno", + [["one", "uno"]] + ], + [ + "one", + [["one", None]] + ], + [ + "one=uno; two=due", + [["one", "uno"], ["two", "due"]] + ], + [ + 'one="uno"; two="\due"', + [["one", "uno"], ["two", "due"]] + ], + [ + 'one="un\\"o"', + [["one", 'un"o']] + ], + [ + 'one="uno,due"', + [["one", 'uno,due']] + ], + [ + "one=uno; two; three=tre", + [["one", "uno"], ["two", None], ["three", "tre"]] + ], + [ + "_lvs2=zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g=; " + "_rcc2=53VdltWl+Ov6ordflA==;", + [ + ["_lvs2", "zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g="], + ["_rcc2", "53VdltWl+Ov6ordflA=="] + ] + ] + ] + for s, lst in pairs: + ret, off = cookies._read_pairs(s) + nose.tools.eq_(ret, lst) + s2 = cookies._format_pairs(lst) + ret, off = cookies._read_pairs(s2) + nose.tools.eq_(ret, lst) + + +def test_cookie_roundtrips(): + pairs = [ + [ + "one=uno", + [["one", "uno"]] + ], + [ + "one=uno; two=due", + [["one", "uno"], ["two", "due"]] + ], + ] + for s, lst in pairs: + ret = cookies.parse_cookie_header(s) + nose.tools.eq_(ret.lst, lst) + s2 = cookies.format_cookie_header(ret) + ret = cookies.parse_cookie_header(s2) + nose.tools.eq_(ret.lst, lst) + + +def test_parse_set_cookie_pairs(): + pairs = [ + [ + "one=uno", + [ + ["one", "uno"] + ] + ], + [ + "one=un\x20", + [ + ["one", "un\x20"] + ] + ], + [ + "one=uno; foo", + [ + ["one", "uno"], + ["foo", None] + ] + ], + [ + "mun=1.390.f60; " + "expires=sun, 11-oct-2015 12:38:31 gmt; path=/; " + "domain=b.aol.com", + [ + ["mun", "1.390.f60"], + ["expires", "sun, 11-oct-2015 12:38:31 gmt"], + ["path", "/"], + ["domain", "b.aol.com"] + ] + ], + [ + r'rpb=190%3d1%2616726%3d1%2634832%3d1%2634874%3d1; ' + 'domain=.rubiconproject.com; ' + 'expires=mon, 11-may-2015 21:54:57 gmt; ' + 'path=/', + [ + ['rpb', r'190%3d1%2616726%3d1%2634832%3d1%2634874%3d1'], + ['domain', '.rubiconproject.com'], + ['expires', 'mon, 11-may-2015 21:54:57 gmt'], + ['path', '/'] + ] + ], + ] + for s, lst in pairs: + ret = cookies._parse_set_cookie_pairs(s) + nose.tools.eq_(ret, lst) + s2 = cookies._format_set_cookie_pairs(ret) + ret2 = cookies._parse_set_cookie_pairs(s2) + nose.tools.eq_(ret2, lst) + + +def test_parse_set_cookie_header(): + vals = [ + [ + "", None + ], + [ + ";", None + ], + [ + "one=uno", + ("one", "uno", []) + ], + [ + "one=uno; foo=bar", + ("one", "uno", [["foo", "bar"]]) + ] + ] + for s, expected in vals: + ret = cookies.parse_set_cookie_header(s) + if expected: + assert ret[0] == expected[0] + assert ret[1] == expected[1] + nose.tools.eq_(ret[2].lst, expected[2]) + s2 = cookies.format_set_cookie_header(*ret) + ret2 = cookies.parse_set_cookie_header(s2) + assert ret2[0] == expected[0] + assert ret2[1] == expected[1] + nose.tools.eq_(ret2[2].lst, expected[2]) + else: + assert ret is None diff --git a/test/http/test_semantics.py b/test/http/test_semantics.py new file mode 100644 index 00000000..c4605302 --- /dev/null +++ b/test/http/test_semantics.py @@ -0,0 +1,54 @@ +import cStringIO +import textwrap +import binascii + +from netlib import http, odict, tcp +from netlib.http import http1 +from .. import tutils, tservers + +def test_httperror(): + e = http.exceptions.HttpError(404, "Not found") + assert str(e) + + +def test_parse_url(): + assert not http.parse_url("") + + u = "http://foo.com:8888/test" + s, h, po, pa = http.parse_url(u) + assert s == "http" + assert h == "foo.com" + assert po == 8888 + assert pa == "/test" + + s, h, po, pa = http.parse_url("http://foo/bar") + assert s == "http" + assert h == "foo" + assert po == 80 + assert pa == "/bar" + + s, h, po, pa = http.parse_url("http://user:pass@foo/bar") + assert s == "http" + assert h == "foo" + assert po == 80 + assert pa == "/bar" + + s, h, po, pa = http.parse_url("http://foo") + assert pa == "/" + + s, h, po, pa = http.parse_url("https://foo") + assert po == 443 + + assert not http.parse_url("https://foo:bar") + assert not http.parse_url("https://foo:") + + # Invalid IDNA + assert not http.parse_url("http://\xfafoo") + # Invalid PATH + assert not http.parse_url("http:/\xc6/localhost:56121") + # Null byte in host + assert not http.parse_url("http://foo\0") + # Port out of range + assert not http.parse_url("http://foo:999999") + # Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt + assert not http.parse_url('http://lo[calhost') diff --git a/test/http/test_user_agents.py b/test/http/test_user_agents.py new file mode 100644 index 00000000..0bf1bba7 --- /dev/null +++ b/test/http/test_user_agents.py @@ -0,0 +1,6 @@ +from netlib.http import user_agents + + +def test_get_shortcut(): + assert user_agents.get_by_shortcut("c")[0] == "chrome" + assert not user_agents.get_by_shortcut("_") diff --git a/test/http2/__init__.py b/test/http2/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/test/http2/test_frames.py b/test/http2/test_frames.py deleted file mode 100644 index 76a4b712..00000000 --- a/test/http2/test_frames.py +++ /dev/null @@ -1,704 +0,0 @@ -import cStringIO -from test import tutils -from nose.tools import assert_equal -from netlib import tcp -from netlib.http2.frame import * - - -def hex_to_file(data): - data = data.decode('hex') - return tcp.Reader(cStringIO.StringIO(data)) - - -def test_invalid_flags(): - tutils.raises( - ValueError, - DataFrame, - flags=ContinuationFrame.FLAG_END_HEADERS, - stream_id=0x1234567, - payload='foobar') - - -def test_frame_equality(): - a = DataFrame( - length=6, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567, - payload='foobar') - b = DataFrame( - length=6, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567, - payload='foobar') - assert_equal(a, b) - - -def test_too_large_frames(): - f = DataFrame( - length=9000, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567, - payload='foobar' * 3000) - tutils.raises(FrameSizeError, f.to_bytes) - - -def test_data_frame_to_bytes(): - f = DataFrame( - length=6, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567, - payload='foobar') - assert_equal(f.to_bytes().encode('hex'), '000006000101234567666f6f626172') - - f = DataFrame( - length=11, - flags=(Frame.FLAG_END_STREAM | Frame.FLAG_PADDED), - stream_id=0x1234567, - payload='foobar', - pad_length=3) - assert_equal( - f.to_bytes().encode('hex'), - '00000a00090123456703666f6f626172000000') - - f = DataFrame( - length=6, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload='foobar') - tutils.raises(ValueError, f.to_bytes) - - -def test_data_frame_from_bytes(): - f = Frame.from_file(hex_to_file('000006000101234567666f6f626172')) - assert isinstance(f, DataFrame) - assert_equal(f.length, 6) - assert_equal(f.TYPE, DataFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_END_STREAM) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.payload, 'foobar') - - f = Frame.from_file(hex_to_file('00000a00090123456703666f6f626172000000')) - assert isinstance(f, DataFrame) - assert_equal(f.length, 10) - assert_equal(f.TYPE, DataFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_END_STREAM | Frame.FLAG_PADDED) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.payload, 'foobar') - - -def test_data_frame_human_readable(): - f = DataFrame( - length=11, - flags=(Frame.FLAG_END_STREAM | Frame.FLAG_PADDED), - stream_id=0x1234567, - payload='foobar', - pad_length=3) - assert f.human_readable() - - -def test_headers_frame_to_bytes(): - f = HeadersFrame( - length=6, - flags=(Frame.FLAG_NO_FLAGS), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex')) - assert_equal(f.to_bytes().encode('hex'), '000007010001234567668594e75e31d9') - - f = HeadersFrame( - length=10, - flags=(HeadersFrame.FLAG_PADDED), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - pad_length=3) - assert_equal( - f.to_bytes().encode('hex'), - '00000b01080123456703668594e75e31d9000000') - - f = HeadersFrame( - length=10, - flags=(HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - exclusive=True, - stream_dependency=0x7654321, - weight=42) - assert_equal( - f.to_bytes().encode('hex'), - '00000c012001234567876543212a668594e75e31d9') - - f = HeadersFrame( - length=14, - flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - pad_length=3, - exclusive=True, - stream_dependency=0x7654321, - weight=42) - assert_equal( - f.to_bytes().encode('hex'), - '00001001280123456703876543212a668594e75e31d9000000') - - f = HeadersFrame( - length=14, - flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - pad_length=3, - exclusive=False, - stream_dependency=0x7654321, - weight=42) - assert_equal( - f.to_bytes().encode('hex'), - '00001001280123456703076543212a668594e75e31d9000000') - - f = HeadersFrame( - length=6, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment='668594e75e31d9'.decode('hex')) - tutils.raises(ValueError, f.to_bytes) - - -def test_headers_frame_from_bytes(): - f = Frame.from_file(hex_to_file( - '000007010001234567668594e75e31d9')) - assert isinstance(f, HeadersFrame) - assert_equal(f.length, 7) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - - f = Frame.from_file(hex_to_file( - '00000b01080123456703668594e75e31d9000000')) - assert isinstance(f, HeadersFrame) - assert_equal(f.length, 11) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, HeadersFrame.FLAG_PADDED) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - - f = Frame.from_file(hex_to_file( - '00000c012001234567876543212a668594e75e31d9')) - assert isinstance(f, HeadersFrame) - assert_equal(f.length, 12) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, HeadersFrame.FLAG_PRIORITY) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - assert_equal(f.exclusive, True) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 42) - - f = Frame.from_file(hex_to_file( - '00001001280123456703876543212a668594e75e31d9000000')) - assert isinstance(f, HeadersFrame) - assert_equal(f.length, 16) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - assert_equal(f.exclusive, True) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 42) - - f = Frame.from_file(hex_to_file( - '00001001280123456703076543212a668594e75e31d9000000')) - assert isinstance(f, HeadersFrame) - assert_equal(f.length, 16) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - assert_equal(f.exclusive, False) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 42) - - -def test_headers_frame_human_readable(): - f = HeadersFrame( - length=7, - flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment=b'', - pad_length=3, - exclusive=False, - stream_dependency=0x7654321, - weight=42) - assert f.human_readable() - - f = HeadersFrame( - length=14, - flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - pad_length=3, - exclusive=False, - stream_dependency=0x7654321, - weight=42) - assert f.human_readable() - - -def test_priority_frame_to_bytes(): - f = PriorityFrame( - length=5, - flags=(Frame.FLAG_NO_FLAGS), - stream_id=0x1234567, - exclusive=True, - stream_dependency=0x7654321, - weight=42) - assert_equal(f.to_bytes().encode('hex'), '000005020001234567876543212a') - - f = PriorityFrame( - length=5, - flags=(Frame.FLAG_NO_FLAGS), - stream_id=0x1234567, - exclusive=False, - stream_dependency=0x7654321, - weight=21) - assert_equal(f.to_bytes().encode('hex'), '0000050200012345670765432115') - - f = PriorityFrame( - length=5, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - stream_dependency=0x1234567) - tutils.raises(ValueError, f.to_bytes) - - f = PriorityFrame( - length=5, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - stream_dependency=0x0) - tutils.raises(ValueError, f.to_bytes) - - -def test_priority_frame_from_bytes(): - f = Frame.from_file(hex_to_file('000005020001234567876543212a')) - assert isinstance(f, PriorityFrame) - assert_equal(f.length, 5) - assert_equal(f.TYPE, PriorityFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.exclusive, True) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 42) - - f = Frame.from_file(hex_to_file('0000050200012345670765432115')) - assert isinstance(f, PriorityFrame) - assert_equal(f.length, 5) - assert_equal(f.TYPE, PriorityFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.exclusive, False) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 21) - - -def test_priority_frame_human_readable(): - f = PriorityFrame( - length=5, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - exclusive=False, - stream_dependency=0x7654321, - weight=21) - assert f.human_readable() - - -def test_rst_stream_frame_to_bytes(): - f = RstStreamFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - error_code=0x7654321) - assert_equal(f.to_bytes().encode('hex'), '00000403000123456707654321') - - f = RstStreamFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0) - tutils.raises(ValueError, f.to_bytes) - - -def test_rst_stream_frame_from_bytes(): - f = Frame.from_file(hex_to_file('00000403000123456707654321')) - assert isinstance(f, RstStreamFrame) - assert_equal(f.length, 4) - assert_equal(f.TYPE, RstStreamFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.error_code, 0x07654321) - - -def test_rst_stream_frame_human_readable(): - f = RstStreamFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - error_code=0x7654321) - assert f.human_readable() - - -def test_settings_frame_to_bytes(): - f = SettingsFrame( - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0) - assert_equal(f.to_bytes().encode('hex'), '000000040000000000') - - f = SettingsFrame( - length=0, - flags=SettingsFrame.FLAG_ACK, - stream_id=0x0) - assert_equal(f.to_bytes().encode('hex'), '000000040100000000') - - f = SettingsFrame( - length=6, - flags=SettingsFrame.FLAG_ACK, - stream_id=0x0, - settings={ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1}) - assert_equal(f.to_bytes().encode('hex'), '000006040100000000000200000001') - - f = SettingsFrame( - length=12, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings={ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) - assert_equal( - f.to_bytes().encode('hex'), - '00000c040000000000000200000001000312345678') - - f = SettingsFrame( - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567) - tutils.raises(ValueError, f.to_bytes) - - -def test_settings_frame_from_bytes(): - f = Frame.from_file(hex_to_file('000000040000000000')) - assert isinstance(f, SettingsFrame) - assert_equal(f.length, 0) - assert_equal(f.TYPE, SettingsFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - - f = Frame.from_file(hex_to_file('000000040100000000')) - assert isinstance(f, SettingsFrame) - assert_equal(f.length, 0) - assert_equal(f.TYPE, SettingsFrame.TYPE) - assert_equal(f.flags, SettingsFrame.FLAG_ACK) - assert_equal(f.stream_id, 0x0) - - f = Frame.from_file(hex_to_file('000006040100000000000200000001')) - assert isinstance(f, SettingsFrame) - assert_equal(f.length, 6) - assert_equal(f.TYPE, SettingsFrame.TYPE) - assert_equal(f.flags, SettingsFrame.FLAG_ACK, 0x0) - assert_equal(f.stream_id, 0x0) - assert_equal(len(f.settings), 1) - assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) - - f = Frame.from_file(hex_to_file( - '00000c040000000000000200000001000312345678')) - assert isinstance(f, SettingsFrame) - assert_equal(f.length, 12) - assert_equal(f.TYPE, SettingsFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(len(f.settings), 2) - assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) - assert_equal( - f.settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS], - 0x12345678) - - -def test_settings_frame_human_readable(): - f = SettingsFrame( - length=12, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings={}) - assert f.human_readable() - - f = SettingsFrame( - length=12, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings={ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) - assert f.human_readable() - - -def test_push_promise_frame_to_bytes(): - f = PushPromiseFrame( - length=10, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - promised_stream=0x7654321, - header_block_fragment='foobar') - assert_equal( - f.to_bytes().encode('hex'), - '00000a05000123456707654321666f6f626172') - - f = PushPromiseFrame( - length=14, - flags=HeadersFrame.FLAG_PADDED, - stream_id=0x1234567, - promised_stream=0x7654321, - header_block_fragment='foobar', - pad_length=3) - assert_equal( - f.to_bytes().encode('hex'), - '00000e0508012345670307654321666f6f626172000000') - - f = PushPromiseFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - promised_stream=0x1234567) - tutils.raises(ValueError, f.to_bytes) - - f = PushPromiseFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - promised_stream=0x0) - tutils.raises(ValueError, f.to_bytes) - - -def test_push_promise_frame_from_bytes(): - f = Frame.from_file(hex_to_file('00000a05000123456707654321666f6f626172')) - assert isinstance(f, PushPromiseFrame) - assert_equal(f.length, 10) - assert_equal(f.TYPE, PushPromiseFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, 'foobar') - - f = Frame.from_file(hex_to_file( - '00000e0508012345670307654321666f6f626172000000')) - assert isinstance(f, PushPromiseFrame) - assert_equal(f.length, 14) - assert_equal(f.TYPE, PushPromiseFrame.TYPE) - assert_equal(f.flags, PushPromiseFrame.FLAG_PADDED) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, 'foobar') - - -def test_push_promise_frame_human_readable(): - f = PushPromiseFrame( - length=14, - flags=HeadersFrame.FLAG_PADDED, - stream_id=0x1234567, - promised_stream=0x7654321, - header_block_fragment='foobar', - pad_length=3) - assert f.human_readable() - - -def test_ping_frame_to_bytes(): - f = PingFrame( - length=8, - flags=PingFrame.FLAG_ACK, - stream_id=0x0, - payload=b'foobar') - assert_equal( - f.to_bytes().encode('hex'), - '000008060100000000666f6f6261720000') - - f = PingFrame( - length=8, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b'foobardeadbeef') - assert_equal( - f.to_bytes().encode('hex'), - '000008060000000000666f6f6261726465') - - f = PingFrame( - length=8, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567) - tutils.raises(ValueError, f.to_bytes) - - -def test_ping_frame_from_bytes(): - f = Frame.from_file(hex_to_file('000008060100000000666f6f6261720000')) - assert isinstance(f, PingFrame) - assert_equal(f.length, 8) - assert_equal(f.TYPE, PingFrame.TYPE) - assert_equal(f.flags, PingFrame.FLAG_ACK) - assert_equal(f.stream_id, 0x0) - assert_equal(f.payload, b'foobar\0\0') - - f = Frame.from_file(hex_to_file('000008060000000000666f6f6261726465')) - assert isinstance(f, PingFrame) - assert_equal(f.length, 8) - assert_equal(f.TYPE, PingFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(f.payload, b'foobarde') - - -def test_ping_frame_human_readable(): - f = PingFrame( - length=8, - flags=PingFrame.FLAG_ACK, - stream_id=0x0, - payload=b'foobar') - assert f.human_readable() - - -def test_goaway_frame_to_bytes(): - f = GoAwayFrame( - length=8, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x1234567, - error_code=0x87654321, - data=b'') - assert_equal( - f.to_bytes().encode('hex'), - '0000080700000000000123456787654321') - - f = GoAwayFrame( - length=14, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x1234567, - error_code=0x87654321, - data=b'foobar') - assert_equal( - f.to_bytes().encode('hex'), - '00000e0700000000000123456787654321666f6f626172') - - f = GoAwayFrame( - length=8, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - last_stream=0x1234567, - error_code=0x87654321) - tutils.raises(ValueError, f.to_bytes) - - -def test_goaway_frame_from_bytes(): - f = Frame.from_file(hex_to_file( - '0000080700000000000123456787654321')) - assert isinstance(f, GoAwayFrame) - assert_equal(f.length, 8) - assert_equal(f.TYPE, GoAwayFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(f.last_stream, 0x1234567) - assert_equal(f.error_code, 0x87654321) - assert_equal(f.data, b'') - - f = Frame.from_file(hex_to_file( - '00000e0700000000000123456787654321666f6f626172')) - assert isinstance(f, GoAwayFrame) - assert_equal(f.length, 14) - assert_equal(f.TYPE, GoAwayFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(f.last_stream, 0x1234567) - assert_equal(f.error_code, 0x87654321) - assert_equal(f.data, b'foobar') - - -def test_go_away_frame_human_readable(): - f = GoAwayFrame( - length=14, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x1234567, - error_code=0x87654321, - data=b'foobar') - assert f.human_readable() - - -def test_window_update_frame_to_bytes(): - f = WindowUpdateFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0x1234567) - assert_equal(f.to_bytes().encode('hex'), '00000408000000000001234567') - - f = WindowUpdateFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - window_size_increment=0x7654321) - assert_equal(f.to_bytes().encode('hex'), '00000408000123456707654321') - - f = WindowUpdateFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0xdeadbeef) - tutils.raises(ValueError, f.to_bytes) - - f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0) - tutils.raises(ValueError, f.to_bytes) - - -def test_window_update_frame_from_bytes(): - f = Frame.from_file(hex_to_file('00000408000000000001234567')) - assert isinstance(f, WindowUpdateFrame) - assert_equal(f.length, 4) - assert_equal(f.TYPE, WindowUpdateFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(f.window_size_increment, 0x1234567) - - -def test_window_update_frame_human_readable(): - f = WindowUpdateFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - window_size_increment=0x7654321) - assert f.human_readable() - - -def test_continuation_frame_to_bytes(): - f = ContinuationFrame( - length=6, - flags=ContinuationFrame.FLAG_END_HEADERS, - stream_id=0x1234567, - header_block_fragment='foobar') - assert_equal(f.to_bytes().encode('hex'), '000006090401234567666f6f626172') - - f = ContinuationFrame( - length=6, - flags=ContinuationFrame.FLAG_END_HEADERS, - stream_id=0x0, - header_block_fragment='foobar') - tutils.raises(ValueError, f.to_bytes) - - -def test_continuation_frame_from_bytes(): - f = Frame.from_file(hex_to_file('000006090401234567666f6f626172')) - assert isinstance(f, ContinuationFrame) - assert_equal(f.length, 6) - assert_equal(f.TYPE, ContinuationFrame.TYPE) - assert_equal(f.flags, ContinuationFrame.FLAG_END_HEADERS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, 'foobar') - - -def test_continuation_frame_human_readable(): - f = ContinuationFrame( - length=6, - flags=ContinuationFrame.FLAG_END_HEADERS, - stream_id=0x1234567, - header_block_fragment='foobar') - assert f.human_readable() diff --git a/test/http2/test_protocol.py b/test/http2/test_protocol.py deleted file mode 100644 index 5e2af34e..00000000 --- a/test/http2/test_protocol.py +++ /dev/null @@ -1,326 +0,0 @@ -import OpenSSL - -from netlib import http2 -from netlib import tcp -from netlib.http2.frame import * -from test import tutils -from .. import tservers - - -class EchoHandler(tcp.BaseHandler): - sni = None - - def handle(self): - while True: - v = self.rfile.safe_read(1) - self.wfile.write(v) - self.wfile.flush() - - -class TestCheckALPNMatch(tservers.ServerTestBase): - handler = EchoHandler - ssl = dict( - alpn_select=http2.HTTP2Protocol.ALPN_PROTO_H2, - ) - - if OpenSSL._util.lib.Cryptography_HAS_ALPN: - - def test_check_alpn(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) - protocol = http2.HTTP2Protocol(c) - assert protocol.check_alpn() - - -class TestCheckALPNMismatch(tservers.ServerTestBase): - handler = EchoHandler - ssl = dict( - alpn_select=None, - ) - - if OpenSSL._util.lib.Cryptography_HAS_ALPN: - - def test_check_alpn(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) - protocol = http2.HTTP2Protocol(c) - tutils.raises(NotImplementedError, protocol.check_alpn) - - -class TestPerformServerConnectionPreface(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - # send magic - self.wfile.write( - '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex')) - self.wfile.flush() - - # send empty settings frame - self.wfile.write('000000040000000000'.decode('hex')) - self.wfile.flush() - - # check empty settings frame - assert self.rfile.read(9) ==\ - '000000040000000000'.decode('hex') - - # check settings acknowledgement - assert self.rfile.read(9) == \ - '000000040100000000'.decode('hex') - - # send settings acknowledgement - self.wfile.write('000000040100000000'.decode('hex')) - self.wfile.flush() - - def test_perform_server_connection_preface(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - protocol = http2.HTTP2Protocol(c) - protocol.perform_server_connection_preface() - - -class TestPerformClientConnectionPreface(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - # check magic - assert self.rfile.read(24) ==\ - '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') - - # check empty settings frame - assert self.rfile.read(9) ==\ - '000000040000000000'.decode('hex') - - # send empty settings frame - self.wfile.write('000000040000000000'.decode('hex')) - self.wfile.flush() - - # check settings acknowledgement - assert self.rfile.read(9) == \ - '000000040100000000'.decode('hex') - - # send settings acknowledgement - self.wfile.write('000000040100000000'.decode('hex')) - self.wfile.flush() - - def test_perform_client_connection_preface(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - protocol = http2.HTTP2Protocol(c) - protocol.perform_client_connection_preface() - - -class TestClientStreamIds(): - c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c) - - def test_client_stream_ids(self): - assert self.protocol.current_stream_id is None - assert self.protocol.next_stream_id() == 1 - assert self.protocol.current_stream_id == 1 - assert self.protocol.next_stream_id() == 3 - assert self.protocol.current_stream_id == 3 - assert self.protocol.next_stream_id() == 5 - assert self.protocol.current_stream_id == 5 - - -class TestServerStreamIds(): - c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c, is_server=True) - - def test_server_stream_ids(self): - assert self.protocol.current_stream_id is None - assert self.protocol.next_stream_id() == 2 - assert self.protocol.current_stream_id == 2 - assert self.protocol.next_stream_id() == 4 - assert self.protocol.current_stream_id == 4 - assert self.protocol.next_stream_id() == 6 - assert self.protocol.current_stream_id == 6 - - -class TestApplySettings(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - # check settings acknowledgement - assert self.rfile.read(9) == '000000040100000000'.decode('hex') - self.wfile.write("OK") - self.wfile.flush() - - ssl = True - - def test_apply_settings(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) - - protocol._apply_settings({ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 'foo', - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 'bar', - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 'deadbeef', - }) - - assert c.rfile.safe_read(2) == "OK" - - assert protocol.http2_settings[ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH] == 'foo' - assert protocol.http2_settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS] == 'bar' - assert protocol.http2_settings[ - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE] == 'deadbeef' - - -class TestCreateHeaders(): - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_create_headers(self): - headers = [ - (b':method', b'GET'), - (b':path', b'index.html'), - (b':scheme', b'https'), - (b'foo', b'bar')] - - bytes = http2.HTTP2Protocol(self.c)._create_headers( - headers, 1, end_stream=True) - assert b''.join(bytes) ==\ - '000014010500000001824488355217caf3a69a3f87408294e7838c767f'\ - .decode('hex') - - bytes = http2.HTTP2Protocol(self.c)._create_headers( - headers, 1, end_stream=False) - assert b''.join(bytes) ==\ - '000014010400000001824488355217caf3a69a3f87408294e7838c767f'\ - .decode('hex') - - # TODO: add test for too large header_block_fragments - - -class TestCreateBody(): - c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c) - - def test_create_body_empty(self): - bytes = self.protocol._create_body(b'', 1) - assert b''.join(bytes) == ''.decode('hex') - - def test_create_body_single_frame(self): - bytes = self.protocol._create_body('foobar', 1) - assert b''.join(bytes) == '000006000100000001666f6f626172'.decode('hex') - - def test_create_body_multiple_frames(self): - pass - # bytes = self.protocol._create_body('foobar' * 3000, 1) - # TODO: add test for too large frames - - -class TestCreateRequest(): - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_create_request_simple(self): - bytes = http2.HTTP2Protocol(self.c).create_request('GET', '/') - assert len(bytes) == 1 - assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') - - def test_create_request_with_body(self): - bytes = http2.HTTP2Protocol(self.c).create_request( - 'GET', '/', [(b'foo', b'bar')], 'foobar') - assert len(bytes) == 2 - assert bytes[0] ==\ - '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') - assert bytes[1] ==\ - '000006000100000001666f6f626172'.decode('hex') - - -class TestReadResponse(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - self.wfile.write( - b'00000801040000000188628594e78c767f'.decode('hex')) - self.wfile.write( - b'000006000100000001666f6f626172'.decode('hex')) - self.wfile.flush() - - ssl = True - - def test_read_response(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) - - status, headers, body = protocol.read_response() - - assert headers == {':status': '200', 'etag': 'foobar'} - assert status == "200" - assert body == b'foobar' - - -class TestReadEmptyResponse(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - self.wfile.write( - b'00000801050000000188628594e78c767f'.decode('hex')) - self.wfile.flush() - - ssl = True - - def test_read_empty_response(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) - - status, headers, body = protocol.read_response() - - assert headers == {':status': '200', 'etag': 'foobar'} - assert status == "200" - assert body == b'' - - -class TestReadRequest(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - self.wfile.write( - b'000003010400000001828487'.decode('hex')) - self.wfile.write( - b'000006000100000001666f6f626172'.decode('hex')) - self.wfile.flush() - - ssl = True - - def test_read_request(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c, is_server=True) - - stream_id, headers, body = protocol.read_request() - - assert stream_id - assert headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} - assert body == b'foobar' - - -class TestCreateResponse(): - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_create_response_simple(self): - bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200) - assert len(bytes) == 1 - assert bytes[0] ==\ - '00000101050000000288'.decode('hex') - - def test_create_response_with_body(self): - bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response( - 200, 1, [(b'foo', b'bar')], 'foobar') - assert len(bytes) == 2 - assert bytes[0] ==\ - '00000901040000000188408294e7838c767f'.decode('hex') - assert bytes[1] ==\ - '000006000100000001666f6f626172'.decode('hex') diff --git a/test/test_http.py b/test/test_http.py deleted file mode 100644 index bbc78847..00000000 --- a/test/test_http.py +++ /dev/null @@ -1,491 +0,0 @@ -import cStringIO -import textwrap -import binascii -from netlib import http, http_semantics, odict, tcp -from . import tutils, tservers - - -def test_httperror(): - e = http.HttpError(404, "Not found") - assert str(e) - - -def test_has_chunked_encoding(): - h = odict.ODictCaseless() - assert not http.has_chunked_encoding(h) - h["transfer-encoding"] = ["chunked"] - assert http.has_chunked_encoding(h) - - -def test_read_chunked(): - - h = odict.ODictCaseless() - h["transfer-encoding"] = ["chunked"] - s = cStringIO.StringIO("1\r\na\r\n0\r\n") - - tutils.raises( - "malformed chunked body", - http.read_http_body, - s, h, None, "GET", None, True - ) - - s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") - assert http.read_http_body(s, h, None, "GET", None, True) == "a" - - s = cStringIO.StringIO("\r\n\r\n1\r\na\r\n0\r\n\r\n") - assert http.read_http_body(s, h, None, "GET", None, True) == "a" - - s = cStringIO.StringIO("\r\n") - tutils.raises( - "closed prematurely", - http.read_http_body, - s, h, None, "GET", None, True - ) - - s = cStringIO.StringIO("1\r\nfoo") - tutils.raises( - "malformed chunked body", - http.read_http_body, - s, h, None, "GET", None, True - ) - - s = cStringIO.StringIO("foo\r\nfoo") - tutils.raises( - http.HttpError, - http.read_http_body, - s, h, None, "GET", None, True - ) - - s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") - tutils.raises("too large", http.read_http_body, s, h, 2, "GET", None, True) - - -def test_connection_close(): - h = odict.ODictCaseless() - assert http.connection_close((1, 0), h) - assert not http.connection_close((1, 1), h) - - h["connection"] = ["keep-alive"] - assert not http.connection_close((1, 1), h) - - h["connection"] = ["close"] - assert http.connection_close((1, 1), h) - - -def test_get_header_tokens(): - h = odict.ODictCaseless() - assert http.get_header_tokens(h, "foo") == [] - h["foo"] = ["bar"] - assert http.get_header_tokens(h, "foo") == ["bar"] - h["foo"] = ["bar, voing"] - assert http.get_header_tokens(h, "foo") == ["bar", "voing"] - h["foo"] = ["bar, voing", "oink"] - assert http.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] - - -def test_read_http_body_request(): - h = odict.ODictCaseless() - r = cStringIO.StringIO("testing") - assert http.read_http_body(r, h, None, "GET", None, True) == "" - - -def test_read_http_body_response(): - h = odict.ODictCaseless() - s = tcp.Reader(cStringIO.StringIO("testing")) - assert http.read_http_body(s, h, None, "GET", 200, False) == "testing" - - -def test_read_http_body(): - # test default case - h = odict.ODictCaseless() - h["content-length"] = [7] - s = cStringIO.StringIO("testing") - assert http.read_http_body(s, h, None, "GET", 200, False) == "testing" - - # test content length: invalid header - h["content-length"] = ["foo"] - s = cStringIO.StringIO("testing") - tutils.raises( - http.HttpError, - http.read_http_body, - s, h, None, "GET", 200, False - ) - - # test content length: invalid header #2 - h["content-length"] = [-1] - s = cStringIO.StringIO("testing") - tutils.raises( - http.HttpError, - http.read_http_body, - s, h, None, "GET", 200, False - ) - - # test content length: content length > actual content - h["content-length"] = [5] - s = cStringIO.StringIO("testing") - tutils.raises( - http.HttpError, - http.read_http_body, - s, h, 4, "GET", 200, False - ) - - # test content length: content length < actual content - s = cStringIO.StringIO("testing") - assert len(http.read_http_body(s, h, None, "GET", 200, False)) == 5 - - # test no content length: limit > actual content - h = odict.ODictCaseless() - s = tcp.Reader(cStringIO.StringIO("testing")) - assert len(http.read_http_body(s, h, 100, "GET", 200, False)) == 7 - - # test no content length: limit < actual content - s = tcp.Reader(cStringIO.StringIO("testing")) - tutils.raises( - http.HttpError, - http.read_http_body, - s, h, 4, "GET", 200, False - ) - - # test chunked - h = odict.ODictCaseless() - h["transfer-encoding"] = ["chunked"] - s = tcp.Reader(cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")) - assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa" - - -def test_expected_http_body_size(): - # gibber in the content-length field - h = odict.ODictCaseless() - h["content-length"] = ["foo"] - assert http.expected_http_body_size(h, False, "GET", 200) is None - # negative number in the content-length field - h = odict.ODictCaseless() - h["content-length"] = ["-7"] - assert http.expected_http_body_size(h, False, "GET", 200) is None - # explicit length - h = odict.ODictCaseless() - h["content-length"] = ["5"] - assert http.expected_http_body_size(h, False, "GET", 200) == 5 - # no length - h = odict.ODictCaseless() - assert http.expected_http_body_size(h, False, "GET", 200) == -1 - # no length request - h = odict.ODictCaseless() - assert http.expected_http_body_size(h, True, "GET", None) == 0 - - -def test_parse_http_protocol(): - assert http.parse_http_protocol("HTTP/1.1") == (1, 1) - assert http.parse_http_protocol("HTTP/0.0") == (0, 0) - assert not http.parse_http_protocol("HTTP/a.1") - assert not http.parse_http_protocol("HTTP/1.a") - assert not http.parse_http_protocol("foo/0.0") - assert not http.parse_http_protocol("HTTP/x") - - -def test_parse_init_connect(): - assert http.parse_init_connect("CONNECT host.com:443 HTTP/1.0") - assert not http.parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") - assert not http.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") - assert not http.parse_init_connect("CONNECT host.com:444444 HTTP/1.0") - assert not http.parse_init_connect("bogus") - assert not http.parse_init_connect("GET host.com:443 HTTP/1.0") - assert not http.parse_init_connect("CONNECT host.com443 HTTP/1.0") - assert not http.parse_init_connect("CONNECT host.com:443 foo/1.0") - assert not http.parse_init_connect("CONNECT host.com:foo HTTP/1.0") - - -def test_parse_init_proxy(): - u = "GET http://foo.com:8888/test HTTP/1.1" - m, s, h, po, pa, httpversion = http.parse_init_proxy(u) - assert m == "GET" - assert s == "http" - assert h == "foo.com" - assert po == 8888 - assert pa == "/test" - assert httpversion == (1, 1) - - u = "G\xfeET http://foo.com:8888/test HTTP/1.1" - assert not http.parse_init_proxy(u) - - assert not http.parse_init_proxy("invalid") - assert not http.parse_init_proxy("GET invalid HTTP/1.1") - assert not http.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") - - -def test_parse_init_http(): - u = "GET /test HTTP/1.1" - m, u, httpversion = http.parse_init_http(u) - assert m == "GET" - assert u == "/test" - assert httpversion == (1, 1) - - u = "G\xfeET /test HTTP/1.1" - assert not http.parse_init_http(u) - - assert not http.parse_init_http("invalid") - assert not http.parse_init_http("GET invalid HTTP/1.1") - assert not http.parse_init_http("GET /test foo/1.1") - assert not http.parse_init_http("GET /test\xc0 HTTP/1.1") - - -class TestReadHeaders: - - def _read(self, data, verbatim=False): - if not verbatim: - data = textwrap.dedent(data) - data = data.strip() - s = cStringIO.StringIO(data) - return http.read_headers(s) - - def test_read_simple(self): - data = """ - Header: one - Header2: two - \r\n - """ - h = self._read(data) - assert h.lst == [["Header", "one"], ["Header2", "two"]] - - def test_read_multi(self): - data = """ - Header: one - Header: two - \r\n - """ - h = self._read(data) - assert h.lst == [["Header", "one"], ["Header", "two"]] - - def test_read_continued(self): - data = """ - Header: one - \ttwo - Header2: three - \r\n - """ - h = self._read(data) - assert h.lst == [["Header", "one\r\n two"], ["Header2", "three"]] - - def test_read_continued_err(self): - data = "\tfoo: bar\r\n" - assert self._read(data, True) is None - - def test_read_err(self): - data = """ - foo - """ - assert self._read(data) is None - - -class NoContentLengthHTTPHandler(tcp.BaseHandler): - - def handle(self): - self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n") - self.wfile.flush() - - -class TestReadResponseNoContentLength(tservers.ServerTestBase): - handler = NoContentLengthHTTPHandler - - def test_no_content_length(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - resp = http.read_response(c.rfile, "GET", None) - assert resp.content == "bar\r\n\r\n" - - -def test_read_response(): - def tst(data, method, limit, include_body=True): - data = textwrap.dedent(data) - r = cStringIO.StringIO(data) - return http.read_response( - r, method, limit, include_body=include_body - ) - - tutils.raises("server disconnect", tst, "", "GET", None) - tutils.raises("invalid server response", tst, "foo", "GET", None) - data = """ - HTTP/1.1 200 OK - """ - assert tst(data, "GET", None) == http_semantics.Response( - (1, 1), 200, 'OK', odict.ODictCaseless(), '' - ) - data = """ - HTTP/1.1 200 - """ - assert tst(data, "GET", None) == http_semantics.Response( - (1, 1), 200, '', odict.ODictCaseless(), '' - ) - data = """ - HTTP/x 200 OK - """ - tutils.raises("invalid http version", tst, data, "GET", None) - data = """ - HTTP/1.1 xx OK - """ - tutils.raises("invalid server response", tst, data, "GET", None) - - data = """ - HTTP/1.1 100 CONTINUE - - HTTP/1.1 200 OK - """ - assert tst(data, "GET", None) == http_semantics.Response( - (1, 1), 100, 'CONTINUE', odict.ODictCaseless(), '' - ) - - data = """ - HTTP/1.1 200 OK - Content-Length: 3 - - foo - """ - assert tst(data, "GET", None).content == 'foo' - assert tst(data, "HEAD", None).content == '' - - data = """ - HTTP/1.1 200 OK - \tContent-Length: 3 - - foo - """ - tutils.raises("invalid headers", tst, data, "GET", None) - - data = """ - HTTP/1.1 200 OK - Content-Length: 3 - - foo - """ - assert tst(data, "GET", None, include_body=False).content is None - - -def test_parse_url(): - assert not http.parse_url("") - - u = "http://foo.com:8888/test" - s, h, po, pa = http.parse_url(u) - assert s == "http" - assert h == "foo.com" - assert po == 8888 - assert pa == "/test" - - s, h, po, pa = http.parse_url("http://foo/bar") - assert s == "http" - assert h == "foo" - assert po == 80 - assert pa == "/bar" - - s, h, po, pa = http.parse_url("http://user:pass@foo/bar") - assert s == "http" - assert h == "foo" - assert po == 80 - assert pa == "/bar" - - s, h, po, pa = http.parse_url("http://foo") - assert pa == "/" - - s, h, po, pa = http.parse_url("https://foo") - assert po == 443 - - assert not http.parse_url("https://foo:bar") - assert not http.parse_url("https://foo:") - - # Invalid IDNA - assert not http.parse_url("http://\xfafoo") - # Invalid PATH - assert not http.parse_url("http:/\xc6/localhost:56121") - # Null byte in host - assert not http.parse_url("http://foo\0") - # Port out of range - assert not http.parse_url("http://foo:999999") - # Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt - assert not http.parse_url('http://lo[calhost') - - -def test_parse_http_basic_auth(): - vals = ("basic", "foo", "bar") - assert http.parse_http_basic_auth( - http.assemble_http_basic_auth(*vals) - ) == vals - assert not http.parse_http_basic_auth("") - assert not http.parse_http_basic_auth("foo bar") - v = "basic " + binascii.b2a_base64("foo") - assert not http.parse_http_basic_auth(v) - - -def test_get_request_line(): - r = cStringIO.StringIO("\nfoo") - assert http.get_request_line(r) == "foo" - assert not http.get_request_line(r) - - -class TestReadRequest(): - - def tst(self, data, **kwargs): - r = cStringIO.StringIO(data) - return http.read_request(r, **kwargs) - - def test_invalid(self): - tutils.raises( - "bad http request", - self.tst, - "xxx" - ) - tutils.raises( - "bad http request line", - self.tst, - "get /\xff HTTP/1.1" - ) - tutils.raises( - "invalid headers", - self.tst, - "get / HTTP/1.1\r\nfoo" - ) - tutils.raises( - tcp.NetLibDisconnect, - self.tst, - "\r\n" - ) - - def test_asterisk_form_in(self): - v = self.tst("OPTIONS * HTTP/1.1") - assert v.form_in == "relative" - assert v.method == "OPTIONS" - - def test_absolute_form_in(self): - tutils.raises( - "Bad HTTP request line", - self.tst, - "GET oops-no-protocol.com HTTP/1.1" - ) - v = self.tst("GET http://address:22/ HTTP/1.1") - assert v.form_in == "absolute" - assert v.port == 22 - assert v.host == "address" - assert v.scheme == "http" - - def test_connect(self): - tutils.raises( - "Bad HTTP request line", - self.tst, - "CONNECT oops-no-port.com HTTP/1.1" - ) - v = self.tst("CONNECT foo.com:443 HTTP/1.1") - assert v.form_in == "authority" - assert v.method == "CONNECT" - assert v.port == 443 - assert v.host == "foo.com" - - def test_expect(self): - w = cStringIO.StringIO() - r = cStringIO.StringIO( - "GET / HTTP/1.1\r\n" - "Content-Length: 3\r\n" - "Expect: 100-continue\r\n\r\n" - "foobar", - ) - v = http.read_request(r, wfile=w) - assert w.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" - assert v.content == "foo" - assert r.read(3) == "bar" diff --git a/test/test_http_auth.py b/test/test_http_auth.py deleted file mode 100644 index c842925b..00000000 --- a/test/test_http_auth.py +++ /dev/null @@ -1,109 +0,0 @@ -from netlib import odict, http_auth, http -import tutils - - -class TestPassManNonAnon: - - def test_simple(self): - p = http_auth.PassManNonAnon() - assert not p.test("", "") - assert p.test("user", "") - - -class TestPassManHtpasswd: - - def test_file_errors(self): - tutils.raises( - "malformed htpasswd file", - http_auth.PassManHtpasswd, - tutils.test_data.path("data/server.crt")) - - def test_simple(self): - pm = http_auth.PassManHtpasswd(tutils.test_data.path("data/htpasswd")) - - vals = ("basic", "test", "test") - http.assemble_http_basic_auth(*vals) - assert pm.test("test", "test") - assert not pm.test("test", "foo") - assert not pm.test("foo", "test") - assert not pm.test("test", "") - assert not pm.test("", "") - - -class TestPassManSingleUser: - - def test_simple(self): - pm = http_auth.PassManSingleUser("test", "test") - assert pm.test("test", "test") - assert not pm.test("test", "foo") - assert not pm.test("foo", "test") - - -class TestNullProxyAuth: - - def test_simple(self): - na = http_auth.NullProxyAuth(http_auth.PassManNonAnon()) - assert not na.auth_challenge_headers() - assert na.authenticate("foo") - na.clean({}) - - -class TestBasicProxyAuth: - - def test_simple(self): - ba = http_auth.BasicProxyAuth(http_auth.PassManNonAnon(), "test") - h = odict.ODictCaseless() - assert ba.auth_challenge_headers() - assert not ba.authenticate(h) - - def test_authenticate_clean(self): - ba = http_auth.BasicProxyAuth(http_auth.PassManNonAnon(), "test") - - hdrs = odict.ODictCaseless() - vals = ("basic", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.assemble_http_basic_auth(*vals)] - assert ba.authenticate(hdrs) - - ba.clean(hdrs) - assert not ba.AUTH_HEADER in hdrs - - hdrs[ba.AUTH_HEADER] = [""] - assert not ba.authenticate(hdrs) - - hdrs[ba.AUTH_HEADER] = ["foo"] - assert not ba.authenticate(hdrs) - - vals = ("foo", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.assemble_http_basic_auth(*vals)] - assert not ba.authenticate(hdrs) - - ba = http_auth.BasicProxyAuth(http_auth.PassMan(), "test") - vals = ("basic", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.assemble_http_basic_auth(*vals)] - assert not ba.authenticate(hdrs) - - -class Bunch: - pass - - -class TestAuthAction: - - def test_nonanonymous(self): - m = Bunch() - aa = http_auth.NonanonymousAuthAction(None, "authenticator") - aa(None, m, None, None) - assert m.authenticator - - def test_singleuser(self): - m = Bunch() - aa = http_auth.SingleuserAuthAction(None, "authenticator") - aa(None, m, "foo:bar", None) - assert m.authenticator - tutils.raises("invalid", aa, None, m, "foo", None) - - def test_httppasswd(self): - m = Bunch() - aa = http_auth.HtpasswdAuthAction(None, "authenticator") - aa(None, m, tutils.test_data.path("data/htpasswd"), None) - assert m.authenticator diff --git a/test/test_http_cookies.py b/test/test_http_cookies.py deleted file mode 100644 index 070849cf..00000000 --- a/test/test_http_cookies.py +++ /dev/null @@ -1,219 +0,0 @@ -import nose.tools - -from netlib import http_cookies - - -def test_read_token(): - tokens = [ - [("foo", 0), ("foo", 3)], - [("foo", 1), ("oo", 3)], - [(" foo", 1), ("foo", 4)], - [(" foo;", 1), ("foo", 4)], - [(" foo=", 1), ("foo", 4)], - [(" foo=bar", 1), ("foo", 4)], - ] - for q, a in tokens: - nose.tools.eq_(http_cookies._read_token(*q), a) - - -def test_read_quoted_string(): - tokens = [ - [('"foo" x', 0), ("foo", 5)], - [('"f\oo" x', 0), ("foo", 6)], - [(r'"f\\o" x', 0), (r"f\o", 6)], - [(r'"f\\" x', 0), (r"f" + '\\', 5)], - [('"fo\\\"" x', 0), ("fo\"", 6)], - ] - for q, a in tokens: - nose.tools.eq_(http_cookies._read_quoted_string(*q), a) - - -def test_read_pairs(): - vals = [ - [ - "one", - [["one", None]] - ], - [ - "one=two", - [["one", "two"]] - ], - [ - "one=", - [["one", ""]] - ], - [ - 'one="two"', - [["one", "two"]] - ], - [ - 'one="two"; three=four', - [["one", "two"], ["three", "four"]] - ], - [ - 'one="two"; three=four; five', - [["one", "two"], ["three", "four"], ["five", None]] - ], - [ - 'one="\\"two"; three=four', - [["one", '"two'], ["three", "four"]] - ], - ] - for s, lst in vals: - ret, off = http_cookies._read_pairs(s) - nose.tools.eq_(ret, lst) - - -def test_pairs_roundtrips(): - pairs = [ - [ - "", - [] - ], - [ - "one=uno", - [["one", "uno"]] - ], - [ - "one", - [["one", None]] - ], - [ - "one=uno; two=due", - [["one", "uno"], ["two", "due"]] - ], - [ - 'one="uno"; two="\due"', - [["one", "uno"], ["two", "due"]] - ], - [ - 'one="un\\"o"', - [["one", 'un"o']] - ], - [ - 'one="uno,due"', - [["one", 'uno,due']] - ], - [ - "one=uno; two; three=tre", - [["one", "uno"], ["two", None], ["three", "tre"]] - ], - [ - "_lvs2=zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g=; " - "_rcc2=53VdltWl+Ov6ordflA==;", - [ - ["_lvs2", "zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g="], - ["_rcc2", "53VdltWl+Ov6ordflA=="] - ] - ] - ] - for s, lst in pairs: - ret, off = http_cookies._read_pairs(s) - nose.tools.eq_(ret, lst) - s2 = http_cookies._format_pairs(lst) - ret, off = http_cookies._read_pairs(s2) - nose.tools.eq_(ret, lst) - - -def test_cookie_roundtrips(): - pairs = [ - [ - "one=uno", - [["one", "uno"]] - ], - [ - "one=uno; two=due", - [["one", "uno"], ["two", "due"]] - ], - ] - for s, lst in pairs: - ret = http_cookies.parse_cookie_header(s) - nose.tools.eq_(ret.lst, lst) - s2 = http_cookies.format_cookie_header(ret) - ret = http_cookies.parse_cookie_header(s2) - nose.tools.eq_(ret.lst, lst) - - -def test_parse_set_cookie_pairs(): - pairs = [ - [ - "one=uno", - [ - ["one", "uno"] - ] - ], - [ - "one=un\x20", - [ - ["one", "un\x20"] - ] - ], - [ - "one=uno; foo", - [ - ["one", "uno"], - ["foo", None] - ] - ], - [ - "mun=1.390.f60; " - "expires=sun, 11-oct-2015 12:38:31 gmt; path=/; " - "domain=b.aol.com", - [ - ["mun", "1.390.f60"], - ["expires", "sun, 11-oct-2015 12:38:31 gmt"], - ["path", "/"], - ["domain", "b.aol.com"] - ] - ], - [ - r'rpb=190%3d1%2616726%3d1%2634832%3d1%2634874%3d1; ' - 'domain=.rubiconproject.com; ' - 'expires=mon, 11-may-2015 21:54:57 gmt; ' - 'path=/', - [ - ['rpb', r'190%3d1%2616726%3d1%2634832%3d1%2634874%3d1'], - ['domain', '.rubiconproject.com'], - ['expires', 'mon, 11-may-2015 21:54:57 gmt'], - ['path', '/'] - ] - ], - ] - for s, lst in pairs: - ret = http_cookies._parse_set_cookie_pairs(s) - nose.tools.eq_(ret, lst) - s2 = http_cookies._format_set_cookie_pairs(ret) - ret2 = http_cookies._parse_set_cookie_pairs(s2) - nose.tools.eq_(ret2, lst) - - -def test_parse_set_cookie_header(): - vals = [ - [ - "", None - ], - [ - ";", None - ], - [ - "one=uno", - ("one", "uno", []) - ], - [ - "one=uno; foo=bar", - ("one", "uno", [["foo", "bar"]]) - ] - ] - for s, expected in vals: - ret = http_cookies.parse_set_cookie_header(s) - if expected: - assert ret[0] == expected[0] - assert ret[1] == expected[1] - nose.tools.eq_(ret[2].lst, expected[2]) - s2 = http_cookies.format_set_cookie_header(*ret) - ret2 = http_cookies.parse_set_cookie_header(s2) - assert ret2[0] == expected[0] - assert ret2[1] == expected[1] - nose.tools.eq_(ret2[2].lst, expected[2]) - else: - assert ret is None diff --git a/test/test_http_uastrings.py b/test/test_http_uastrings.py deleted file mode 100644 index 3fa4f359..00000000 --- a/test/test_http_uastrings.py +++ /dev/null @@ -1,6 +0,0 @@ -from netlib import http_uastrings - - -def test_get_shortcut(): - assert http_uastrings.get_by_shortcut("c")[0] == "chrome" - assert not http_uastrings.get_by_shortcut("_") diff --git a/test/test_websockets.py b/test/test_websockets.py deleted file mode 100644 index ae0a5e33..00000000 --- a/test/test_websockets.py +++ /dev/null @@ -1,261 +0,0 @@ -import os - -from nose.tools import raises - -from netlib import tcp, websockets, http -from . import tutils, tservers - - -class WebSocketsEchoHandler(tcp.BaseHandler): - - def __init__(self, connection, address, server): - super(WebSocketsEchoHandler, self).__init__( - connection, address, server - ) - self.protocol = websockets.WebsocketsProtocol() - self.handshake_done = False - - def handle(self): - while True: - if not self.handshake_done: - self.handshake() - else: - self.read_next_message() - - def read_next_message(self): - frame = websockets.Frame.from_file(self.rfile) - self.on_message(frame.payload) - - def send_message(self, message): - frame = websockets.Frame.default(message, from_client=False) - frame.to_file(self.wfile) - - def handshake(self): - req = http.read_request(self.rfile) - key = self.protocol.check_client_handshake(req.headers) - - self.wfile.write(http.response_preamble(101) + "\r\n") - headers = self.protocol.server_handshake_headers(key) - self.wfile.write(headers.format() + "\r\n") - self.wfile.flush() - self.handshake_done = True - - def on_message(self, message): - if message is not None: - self.send_message(message) - - -class WebSocketsClient(tcp.TCPClient): - - def __init__(self, address, source_address=None): - super(WebSocketsClient, self).__init__(address, source_address) - self.protocol = websockets.WebsocketsProtocol() - self.client_nonce = None - - def connect(self): - super(WebSocketsClient, self).connect() - - preamble = http.request_preamble("GET", "/") - self.wfile.write(preamble + "\r\n") - headers = self.protocol.client_handshake_headers() - self.client_nonce = headers.get_first("sec-websocket-key") - self.wfile.write(headers.format() + "\r\n") - self.wfile.flush() - - resp = http.read_response(self.rfile, "get", None) - server_nonce = self.protocol.check_server_handshake(resp.headers) - - if not server_nonce == self.protocol.create_server_nonce( - self.client_nonce): - self.close() - - def read_next_message(self): - return websockets.Frame.from_file(self.rfile).payload - - def send_message(self, message): - frame = websockets.Frame.default(message, from_client=True) - frame.to_file(self.wfile) - - -class TestWebSockets(tservers.ServerTestBase): - handler = WebSocketsEchoHandler - - def __init__(self): - self.protocol = websockets.WebsocketsProtocol() - - def random_bytes(self, n=100): - return os.urandom(n) - - def echo(self, msg): - client = WebSocketsClient(("127.0.0.1", self.port)) - client.connect() - client.send_message(msg) - response = client.read_next_message() - assert response == msg - - def test_simple_echo(self): - self.echo("hello I'm the client") - - def test_frame_sizes(self): - # length can fit in the the 7 bit payload length - small_msg = self.random_bytes(100) - # 50kb, sligthly larger than can fit in a 7 bit int - medium_msg = self.random_bytes(50000) - # 150kb, slightly larger than can fit in a 16 bit int - large_msg = self.random_bytes(150000) - - self.echo(small_msg) - self.echo(medium_msg) - self.echo(large_msg) - - def test_default_builder(self): - """ - default builder should always generate valid frames - """ - msg = self.random_bytes() - client_frame = websockets.Frame.default(msg, from_client=True) - server_frame = websockets.Frame.default(msg, from_client=False) - - def test_serialization_bijection(self): - """ - Ensure that various frame types can be serialized/deserialized back - and forth between to_bytes() and from_bytes() - """ - for is_client in [True, False]: - for num_bytes in [100, 50000, 150000]: - frame = websockets.Frame.default( - self.random_bytes(num_bytes), is_client - ) - frame2 = websockets.Frame.from_bytes( - frame.to_bytes() - ) - assert frame == frame2 - - bytes = b'\x81\x03cba' - assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes - - def test_check_server_handshake(self): - headers = self.protocol.server_handshake_headers("key") - assert self.protocol.check_server_handshake(headers) - headers["Upgrade"] = ["not_websocket"] - assert not self.protocol.check_server_handshake(headers) - - def test_check_client_handshake(self): - headers = self.protocol.client_handshake_headers("key") - assert self.protocol.check_client_handshake(headers) == "key" - headers["Upgrade"] = ["not_websocket"] - assert not self.protocol.check_client_handshake(headers) - - -class BadHandshakeHandler(WebSocketsEchoHandler): - - def handshake(self): - client_hs = http.read_request(self.rfile) - self.protocol.check_client_handshake(client_hs.headers) - - self.wfile.write(http.response_preamble(101) + "\r\n") - headers = self.protocol.server_handshake_headers("malformed key") - self.wfile.write(headers.format() + "\r\n") - self.wfile.flush() - self.handshake_done = True - - -class TestBadHandshake(tservers.ServerTestBase): - - """ - Ensure that the client disconnects if the server handshake is malformed - """ - handler = BadHandshakeHandler - - @raises(tcp.NetLibDisconnect) - def test(self): - client = WebSocketsClient(("127.0.0.1", self.port)) - client.connect() - client.send_message("hello") - - -class TestFrameHeader: - - def test_roundtrip(self): - def round(*args, **kwargs): - f = websockets.FrameHeader(*args, **kwargs) - bytes = f.to_bytes() - f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) - assert f == f2 - round() - round(fin=1) - round(rsv1=1) - round(rsv2=1) - round(rsv3=1) - round(payload_length=1) - round(payload_length=100) - round(payload_length=1000) - round(payload_length=10000) - round(opcode=websockets.OPCODE.PING) - round(masking_key="test") - - def test_human_readable(self): - f = websockets.FrameHeader( - masking_key="test", - fin=True, - payload_length=10 - ) - assert f.human_readable() - f = websockets.FrameHeader() - assert f.human_readable() - - def test_funky(self): - f = websockets.FrameHeader(masking_key="test", mask=False) - bytes = f.to_bytes() - f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) - assert not f2.mask - - def test_violations(self): - tutils.raises("opcode", websockets.FrameHeader, opcode=17) - tutils.raises("masking key", websockets.FrameHeader, masking_key="x") - - def test_automask(self): - f = websockets.FrameHeader(mask=True) - assert f.masking_key - - f = websockets.FrameHeader(masking_key="foob") - assert f.mask - - f = websockets.FrameHeader(masking_key="foob", mask=0) - assert not f.mask - assert f.masking_key - - -class TestFrame: - - def test_roundtrip(self): - def round(*args, **kwargs): - f = websockets.Frame(*args, **kwargs) - bytes = f.to_bytes() - f2 = websockets.Frame.from_file(tutils.treader(bytes)) - assert f == f2 - round("test") - round("test", fin=1) - round("test", rsv1=1) - round("test", opcode=websockets.OPCODE.PING) - round("test", masking_key="test") - - def test_human_readable(self): - f = websockets.Frame() - assert f.human_readable() - - -def test_masker(): - tests = [ - ["a"], - ["four"], - ["fourf"], - ["fourfive"], - ["a", "aasdfasdfa", "asdf"], - ["a" * 50, "aasdfasdfa", "asdf"], - ] - for i in tests: - m = websockets.Masker("abcd") - data = "".join([m(t) for t in i]) - data2 = websockets.Masker("abcd")(data) - assert data2 == "".join(i) diff --git a/test/websockets/__init__.py b/test/websockets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py new file mode 100644 index 00000000..07ad0452 --- /dev/null +++ b/test/websockets/test_websockets.py @@ -0,0 +1,262 @@ +import os + +from nose.tools import raises + +from netlib import tcp, http, websockets +from netlib.http.exceptions import * +from .. import tutils, tservers + + +class WebSocketsEchoHandler(tcp.BaseHandler): + + def __init__(self, connection, address, server): + super(WebSocketsEchoHandler, self).__init__( + connection, address, server + ) + self.protocol = websockets.WebsocketsProtocol() + self.handshake_done = False + + def handle(self): + while True: + if not self.handshake_done: + self.handshake() + else: + self.read_next_message() + + def read_next_message(self): + frame = websockets.Frame.from_file(self.rfile) + self.on_message(frame.payload) + + def send_message(self, message): + frame = websockets.Frame.default(message, from_client=False) + frame.to_file(self.wfile) + + def handshake(self): + req = http.http1.read_request(self.rfile) + key = self.protocol.check_client_handshake(req.headers) + + self.wfile.write(http.http1.response_preamble(101) + "\r\n") + headers = self.protocol.server_handshake_headers(key) + self.wfile.write(headers.format() + "\r\n") + self.wfile.flush() + self.handshake_done = True + + def on_message(self, message): + if message is not None: + self.send_message(message) + + +class WebSocketsClient(tcp.TCPClient): + + def __init__(self, address, source_address=None): + super(WebSocketsClient, self).__init__(address, source_address) + self.protocol = websockets.WebsocketsProtocol() + self.client_nonce = None + + def connect(self): + super(WebSocketsClient, self).connect() + + preamble = http.http1.protocol.request_preamble("GET", "/") + self.wfile.write(preamble + "\r\n") + headers = self.protocol.client_handshake_headers() + self.client_nonce = headers.get_first("sec-websocket-key") + self.wfile.write(headers.format() + "\r\n") + self.wfile.flush() + + resp = http.http1.protocol.read_response(self.rfile, "get", None) + server_nonce = self.protocol.check_server_handshake(resp.headers) + + if not server_nonce == self.protocol.create_server_nonce( + self.client_nonce): + self.close() + + def read_next_message(self): + return websockets.Frame.from_file(self.rfile).payload + + def send_message(self, message): + frame = websockets.Frame.default(message, from_client=True) + frame.to_file(self.wfile) + + +class TestWebSockets(tservers.ServerTestBase): + handler = WebSocketsEchoHandler + + def __init__(self): + self.protocol = websockets.WebsocketsProtocol() + + def random_bytes(self, n=100): + return os.urandom(n) + + def echo(self, msg): + client = WebSocketsClient(("127.0.0.1", self.port)) + client.connect() + client.send_message(msg) + response = client.read_next_message() + assert response == msg + + def test_simple_echo(self): + self.echo("hello I'm the client") + + def test_frame_sizes(self): + # length can fit in the the 7 bit payload length + small_msg = self.random_bytes(100) + # 50kb, sligthly larger than can fit in a 7 bit int + medium_msg = self.random_bytes(50000) + # 150kb, slightly larger than can fit in a 16 bit int + large_msg = self.random_bytes(150000) + + self.echo(small_msg) + self.echo(medium_msg) + self.echo(large_msg) + + def test_default_builder(self): + """ + default builder should always generate valid frames + """ + msg = self.random_bytes() + client_frame = websockets.Frame.default(msg, from_client=True) + server_frame = websockets.Frame.default(msg, from_client=False) + + def test_serialization_bijection(self): + """ + Ensure that various frame types can be serialized/deserialized back + and forth between to_bytes() and from_bytes() + """ + for is_client in [True, False]: + for num_bytes in [100, 50000, 150000]: + frame = websockets.Frame.default( + self.random_bytes(num_bytes), is_client + ) + frame2 = websockets.Frame.from_bytes( + frame.to_bytes() + ) + assert frame == frame2 + + bytes = b'\x81\x03cba' + assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes + + def test_check_server_handshake(self): + headers = self.protocol.server_handshake_headers("key") + assert self.protocol.check_server_handshake(headers) + headers["Upgrade"] = ["not_websocket"] + assert not self.protocol.check_server_handshake(headers) + + def test_check_client_handshake(self): + headers = self.protocol.client_handshake_headers("key") + assert self.protocol.check_client_handshake(headers) == "key" + headers["Upgrade"] = ["not_websocket"] + assert not self.protocol.check_client_handshake(headers) + + +class BadHandshakeHandler(WebSocketsEchoHandler): + + def handshake(self): + client_hs = http.http1.protocol.read_request(self.rfile) + self.protocol.check_client_handshake(client_hs.headers) + + self.wfile.write(http.http1.protocol.response_preamble(101) + "\r\n") + headers = self.protocol.server_handshake_headers("malformed key") + self.wfile.write(headers.format() + "\r\n") + self.wfile.flush() + self.handshake_done = True + + +class TestBadHandshake(tservers.ServerTestBase): + + """ + Ensure that the client disconnects if the server handshake is malformed + """ + handler = BadHandshakeHandler + + @raises(tcp.NetLibDisconnect) + def test(self): + client = WebSocketsClient(("127.0.0.1", self.port)) + client.connect() + client.send_message("hello") + + +class TestFrameHeader: + + def test_roundtrip(self): + def round(*args, **kwargs): + f = websockets.FrameHeader(*args, **kwargs) + bytes = f.to_bytes() + f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) + assert f == f2 + round() + round(fin=1) + round(rsv1=1) + round(rsv2=1) + round(rsv3=1) + round(payload_length=1) + round(payload_length=100) + round(payload_length=1000) + round(payload_length=10000) + round(opcode=websockets.OPCODE.PING) + round(masking_key="test") + + def test_human_readable(self): + f = websockets.FrameHeader( + masking_key="test", + fin=True, + payload_length=10 + ) + assert f.human_readable() + f = websockets.FrameHeader() + assert f.human_readable() + + def test_funky(self): + f = websockets.FrameHeader(masking_key="test", mask=False) + bytes = f.to_bytes() + f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) + assert not f2.mask + + def test_violations(self): + tutils.raises("opcode", websockets.FrameHeader, opcode=17) + tutils.raises("masking key", websockets.FrameHeader, masking_key="x") + + def test_automask(self): + f = websockets.FrameHeader(mask=True) + assert f.masking_key + + f = websockets.FrameHeader(masking_key="foob") + assert f.mask + + f = websockets.FrameHeader(masking_key="foob", mask=0) + assert not f.mask + assert f.masking_key + + +class TestFrame: + + def test_roundtrip(self): + def round(*args, **kwargs): + f = websockets.Frame(*args, **kwargs) + bytes = f.to_bytes() + f2 = websockets.Frame.from_file(tutils.treader(bytes)) + assert f == f2 + round("test") + round("test", fin=1) + round("test", rsv1=1) + round("test", opcode=websockets.OPCODE.PING) + round("test", masking_key="test") + + def test_human_readable(self): + f = websockets.Frame() + assert f.human_readable() + + +def test_masker(): + tests = [ + ["a"], + ["four"], + ["fourf"], + ["fourfive"], + ["a", "aasdfasdfa", "asdf"], + ["a" * 50, "aasdfasdfa", "asdf"], + ] + for i in tests: + m = websockets.Masker("abcd") + data = "".join([m(t) for t in i]) + data2 = websockets.Masker("abcd")(data) + assert data2 == "".join(i) -- cgit v1.2.3 From bab6cbff1e5444aea72a188d57812130c375e0f0 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 15 Jul 2015 22:32:14 +0200 Subject: extract authentication methods from protocol --- netlib/http/authentication.py | 22 +++++++++++++++++++++- netlib/http/http1/protocol.py | 39 ++------------------------------------- netlib/http/semantics.py | 14 +++++++++++++- test/http/http1/test_protocol.py | 19 ++++--------------- test/http/test_authentication.py | 21 +++++++++++++++++---- 5 files changed, 57 insertions(+), 58 deletions(-) diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index 26e3c2c4..9a227010 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -1,8 +1,28 @@ from __future__ import (absolute_import, print_function, division) from argparse import Action, ArgumentTypeError +import binascii from .. import http +def parse_http_basic_auth(s): + words = s.split() + if len(words) != 2: + return None + scheme = words[0] + try: + user = binascii.a2b_base64(words[1]) + except binascii.Error: + return None + parts = user.split(':') + if len(parts) != 2: + return None + return scheme, parts[0], parts[1] + + +def assemble_http_basic_auth(scheme, username, password): + v = binascii.b2a_base64(username + ":" + password) + return scheme + " " + v + class NullProxyAuth(object): @@ -47,7 +67,7 @@ class BasicProxyAuth(NullProxyAuth): auth_value = headers.get(self.AUTH_HEADER, []) if not auth_value: return False - parts = http.http1.parse_http_basic_auth(auth_value[0]) + parts = parse_http_basic_auth(auth_value[0]) if not parts: return False scheme, username, password = parts diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 0f7a0bd3..97c119a9 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -85,22 +85,9 @@ def read_chunked(fp, limit, is_request): return -def get_header_tokens(headers, key): - """ - Retrieve all tokens for a header key. A number of different headers - follow a pattern where each header line can containe comma-separated - tokens, and headers can be set multiple times. - """ - toks = [] - for i in headers[key]: - for j in i.split(","): - toks.append(j.strip()) - return toks - - def has_chunked_encoding(headers): return "chunked" in [ - i.lower() for i in get_header_tokens(headers, "transfer-encoding") + i.lower() for i in http.get_header_tokens(headers, "transfer-encoding") ] @@ -123,28 +110,6 @@ def parse_http_protocol(s): return major, minor -def parse_http_basic_auth(s): - # TODO: check if this is HTTP/1 only - otherwise move it to netlib.http.semantics - words = s.split() - if len(words) != 2: - return None - scheme = words[0] - try: - user = binascii.a2b_base64(words[1]) - except binascii.Error: - return None - parts = user.split(':') - if len(parts) != 2: - return None - return scheme, parts[0], parts[1] - - -def assemble_http_basic_auth(scheme, username, password): - # TODO: check if this is HTTP/1 only - otherwise move it to netlib.http.semantics - v = binascii.b2a_base64(username + ":" + password) - return scheme + " " + v - - def parse_init(line): try: method, url, protocol = string.split(line) @@ -221,7 +186,7 @@ def connection_close(httpversion, headers): """ # At first, check if we have an explicit Connection header. if "connection" in headers: - toks = get_header_tokens(headers, "connection") + toks = http.get_header_tokens(headers, "connection") if "close" in toks: return True elif "keep-alive" in toks: diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index e7e84fe3..a62c93e3 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -49,7 +49,6 @@ def is_valid_host(host): return True - def parse_url(url): """ Returns a (scheme, host, port, path) tuple, or None on error. @@ -92,3 +91,16 @@ def parse_url(url): if not is_valid_port(port): return None return scheme, host, port, path + + +def get_header_tokens(headers, key): + """ + Retrieve all tokens for a header key. A number of different headers + follow a pattern where each header line can containe comma-separated + tokens, and headers can be set multiple times. + """ + toks = [] + for i in headers[key]: + for j in i.split(","): + toks.append(j.strip()) + return toks diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index 05e82831..d0a2ee02 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -71,13 +71,13 @@ def test_connection_close(): def test_get_header_tokens(): h = odict.ODictCaseless() - assert protocol.get_header_tokens(h, "foo") == [] + assert http.get_header_tokens(h, "foo") == [] h["foo"] = ["bar"] - assert protocol.get_header_tokens(h, "foo") == ["bar"] + assert http.get_header_tokens(h, "foo") == ["bar"] h["foo"] = ["bar, voing"] - assert protocol.get_header_tokens(h, "foo") == ["bar", "voing"] + assert http.get_header_tokens(h, "foo") == ["bar", "voing"] h["foo"] = ["bar, voing", "oink"] - assert protocol.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] + assert http.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] def test_read_http_body_request(): @@ -357,17 +357,6 @@ def test_read_response(): assert tst(data, "GET", None, include_body=False).content is None -def test_parse_http_basic_auth(): - vals = ("basic", "foo", "bar") - assert protocol.parse_http_basic_auth( - protocol.assemble_http_basic_auth(*vals) - ) == vals - assert not protocol.parse_http_basic_auth("") - assert not protocol.parse_http_basic_auth("foo bar") - v = "basic " + binascii.b2a_base64("foo") - assert not protocol.parse_http_basic_auth(v) - - def test_get_request_line(): r = cStringIO.StringIO("\nfoo") assert protocol.get_request_line(r) == "foo" diff --git a/test/http/test_authentication.py b/test/http/test_authentication.py index c0dae1a2..8f231643 100644 --- a/test/http/test_authentication.py +++ b/test/http/test_authentication.py @@ -1,8 +1,21 @@ +import binascii + from netlib import odict, http from netlib.http import authentication from .. import tutils +def test_parse_http_basic_auth(): + vals = ("basic", "foo", "bar") + assert http.authentication.parse_http_basic_auth( + http.authentication.assemble_http_basic_auth(*vals) + ) == vals + assert not http.authentication.parse_http_basic_auth("") + assert not http.authentication.parse_http_basic_auth("foo bar") + v = "basic " + binascii.b2a_base64("foo") + assert not http.authentication.parse_http_basic_auth(v) + + class TestPassManNonAnon: def test_simple(self): @@ -23,7 +36,7 @@ class TestPassManHtpasswd: pm = authentication.PassManHtpasswd(tutils.test_data.path("data/htpasswd")) vals = ("basic", "test", "test") - http.http1.assemble_http_basic_auth(*vals) + authentication.assemble_http_basic_auth(*vals) assert pm.test("test", "test") assert not pm.test("test", "foo") assert not pm.test("foo", "test") @@ -62,7 +75,7 @@ class TestBasicProxyAuth: hdrs = odict.ODictCaseless() vals = ("basic", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.http1.assemble_http_basic_auth(*vals)] + hdrs[ba.AUTH_HEADER] = [authentication.assemble_http_basic_auth(*vals)] assert ba.authenticate(hdrs) ba.clean(hdrs) @@ -75,12 +88,12 @@ class TestBasicProxyAuth: assert not ba.authenticate(hdrs) vals = ("foo", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.http1.assemble_http_basic_auth(*vals)] + hdrs[ba.AUTH_HEADER] = [authentication.assemble_http_basic_auth(*vals)] assert not ba.authenticate(hdrs) ba = authentication.BasicProxyAuth(authentication.PassMan(), "test") vals = ("basic", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.http1.assemble_http_basic_auth(*vals)] + hdrs[ba.AUTH_HEADER] = [authentication.assemble_http_basic_auth(*vals)] assert not ba.authenticate(hdrs) -- cgit v1.2.3 From 230c16122b06f5c6af60e6ddc2d8e2e83cd75273 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 16 Jul 2015 22:50:24 +0200 Subject: change HTTP2 interface to match HTTP1 --- netlib/http/http2/protocol.py | 6 +++--- test/http/http2/test_protocol.py | 20 ++++++++++++-------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 8e5f5429..0d6eac85 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -2,7 +2,7 @@ from __future__ import (absolute_import, print_function, division) import itertools from hpack.hpack import Encoder, Decoder -from .. import utils +from netlib import http, utils from . import frame @@ -186,9 +186,9 @@ class HTTP2Protocol(object): self._create_headers(headers, stream_id, end_stream=(body is None)), self._create_body(body, stream_id))) - def read_response(self): + def read_response(self, *args): stream_id_, headers, body = self._receive_transmission() - return headers[':status'], headers, body + return http.Response("HTTP/2", headers[':status'], "", headers, body) def read_request(self): return self._receive_transmission() diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index f607860e..403a2589 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -251,11 +251,13 @@ class TestReadResponse(tservers.ServerTestBase): c.convert_to_ssl() protocol = http2.HTTP2Protocol(c) - status, headers, body = protocol.read_response() + resp = protocol.read_response() - assert headers == {':status': '200', 'etag': 'foobar'} - assert status == "200" - assert body == b'foobar' + assert resp.httpversion == "HTTP/2" + assert resp.status_code == "200" + assert resp.msg == "" + assert resp.headers == {':status': '200', 'etag': 'foobar'} + assert resp.content == b'foobar' class TestReadEmptyResponse(tservers.ServerTestBase): @@ -274,11 +276,13 @@ class TestReadEmptyResponse(tservers.ServerTestBase): c.convert_to_ssl() protocol = http2.HTTP2Protocol(c) - status, headers, body = protocol.read_response() + resp = protocol.read_response() - assert headers == {':status': '200', 'etag': 'foobar'} - assert status == "200" - assert body == b'' + assert resp.httpversion == "HTTP/2" + assert resp.status_code == "200" + assert resp.msg == "" + assert resp.headers == {':status': '200', 'etag': 'foobar'} + assert resp.content == b'' class TestReadRequest(tservers.ServerTestBase): -- cgit v1.2.3 From 808b294865257fc3f52b33ed2a796009658b126f Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 16 Jul 2015 22:56:34 +0200 Subject: refactor HTTP/1 as protocol --- netlib/http/http1/protocol.py | 901 +++++++++++++++++++------------------ test/http/http1/test_protocol.py | 214 ++++----- test/websockets/test_websockets.py | 21 +- 3 files changed, 583 insertions(+), 553 deletions(-) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 97c119a9..401654c1 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -9,475 +9,488 @@ from netlib import odict, utils, tcp, http from .. import status_codes from ..exceptions import * +class HTTP1Protocol(object): + + # TODO: make this a regular class - just like Response + Request = collections.namedtuple( + "Request", + [ + "form_in", + "method", + "scheme", + "host", + "port", + "path", + "httpversion", + "headers", + "content" + ] + ) -def get_request_line(fp): - """ - Get a line, possibly preceded by a blank. - """ - line = fp.readline() - if line == "\r\n" or line == "\n": - # Possible leftover from previous message - line = fp.readline() - return line - -def read_headers(fp): - """ - Read a set of headers from a file pointer. Stop once a blank line is - reached. Return a ODictCaseless object, or None if headers are invalid. - """ - ret = [] - name = '' - while True: - line = fp.readline() - if not line or line == '\r\n' or line == '\n': - break - if line[0] in ' \t': - if not ret: - return None - # continued header - ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() - else: - i = line.find(':') - # We're being liberal in what we accept, here. - if i > 0: - name = line[:i] - value = line[i + 1:].strip() - ret.append([name, value]) + def __init__(self, tcp_handler): + self.tcp_handler = tcp_handler + + def get_request_line(self): + """ + Get a line, possibly preceded by a blank. + """ + line = self.tcp_handler.rfile.readline() + if line == "\r\n" or line == "\n": + # Possible leftover from previous message + line = self.tcp_handler.rfile.readline() + return line + + def read_headers(self): + """ + Read a set of headers. + Stop once a blank line is reached. + + Return a ODictCaseless object, or None if headers are invalid. + """ + ret = [] + name = '' + while True: + line = self.tcp_handler.rfile.readline() + if not line or line == '\r\n' or line == '\n': + break + if line[0] in ' \t': + if not ret: + return None + # continued header + ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() else: - return None - return odict.ODictCaseless(ret) - - -def read_chunked(fp, limit, is_request): - """ - Read a chunked HTTP body. - - May raise HttpError. - """ - # FIXME: Should check if chunked is the final encoding in the headers - # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 - # 3.3 2. - total = 0 - code = 400 if is_request else 502 - while True: - line = fp.readline(128) - if line == "": - raise HttpErrorConnClosed(code, "Connection closed prematurely") - if line != '\r\n' and line != '\n': - try: - length = int(line, 16) - except ValueError: - raise HttpError( - code, - "Invalid chunked encoding length: %s" % line - ) - total += length - if limit is not None and total > limit: - msg = "HTTP Body too large. Limit is %s," \ - " chunked content longer than %s" % (limit, total) - raise HttpError(code, msg) - chunk = fp.read(length) - suffix = fp.readline(5) - if suffix != '\r\n': - raise HttpError(code, "Malformed chunked body") - yield line, chunk, '\r\n' - if length == 0: - return - - -def has_chunked_encoding(headers): - return "chunked" in [ - i.lower() for i in http.get_header_tokens(headers, "transfer-encoding") - ] - - -def parse_http_protocol(s): - """ - Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or - None. - """ - if not s.startswith("HTTP/"): - return None - _, version = s.split('/', 1) - if "." not in version: - return None - major, minor = version.split('.', 1) - try: - major = int(major) - minor = int(minor) - except ValueError: - return None - return major, minor - - -def parse_init(line): - try: - method, url, protocol = string.split(line) - except ValueError: - return None - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None - if not utils.isascii(method): - return None - return method, url, httpversion - - -def parse_init_connect(line): - """ - Returns (host, port, httpversion) if line is a valid CONNECT line. - http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 - """ - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - - if method.upper() != 'CONNECT': - return None - try: - host, port = url.split(":") - except ValueError: - return None - try: - port = int(port) - except ValueError: - return None - if not http.is_valid_port(port): - return None - if not http.is_valid_host(host): - return None - return host, port, httpversion - - -def parse_init_proxy(line): - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - - parts = http.parse_url(url) - if not parts: - return None - scheme, host, port, path = parts - return method, scheme, host, port, path, httpversion - - -def parse_init_http(line): - """ - Returns (method, url, httpversion) - """ - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - if not utils.isascii(url): - return None - if not (url.startswith("/") or url == "*"): - return None - return method, url, httpversion - - -def connection_close(httpversion, headers): - """ - Checks the message to see if the client connection should be closed - according to RFC 2616 Section 8.1 Note that a connection should be - closed as well if the response has been read until end of the stream. - """ - # At first, check if we have an explicit Connection header. - if "connection" in headers: - toks = http.get_header_tokens(headers, "connection") - if "close" in toks: - return True - elif "keep-alive" in toks: - return False - # If we don't have a Connection header, HTTP 1.1 connections are assumed to - # be persistent - if httpversion == (1, 1): - return False - return True - - -def parse_response_line(line): - parts = line.strip().split(" ", 2) - if len(parts) == 2: # handle missing message gracefully - parts.append("") - if len(parts) != 3: - return None - proto, code, msg = parts - try: - code = int(code) - except ValueError: - return None - return (proto, code, msg) - - -def read_http_body(*args, **kwargs): - return "".join( - content for _, content, _ in read_http_body_chunked(*args, **kwargs) - ) + i = line.find(':') + # We're being liberal in what we accept, here. + if i > 0: + name = line[:i] + value = line[i + 1:].strip() + ret.append([name, value]) + else: + return None + return odict.ODictCaseless(ret) + + + def read_chunked(self, limit, is_request): + """ + Read a chunked HTTP body. + + May raise HttpError. + """ + # FIXME: Should check if chunked is the final encoding in the headers + # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 + # 3.3 2. + total = 0 + code = 400 if is_request else 502 + while True: + line = self.tcp_handler.rfile.readline(128) + if line == "": + raise HttpErrorConnClosed(code, "Connection closed prematurely") + if line != '\r\n' and line != '\n': + try: + length = int(line, 16) + except ValueError: + raise HttpError( + code, + "Invalid chunked encoding length: %s" % line + ) + total += length + if limit is not None and total > limit: + msg = "HTTP Body too large. Limit is %s," \ + " chunked content longer than %s" % (limit, total) + raise HttpError(code, msg) + chunk = self.tcp_handler.rfile.read(length) + suffix = self.tcp_handler.rfile.readline(5) + if suffix != '\r\n': + raise HttpError(code, "Malformed chunked body") + yield line, chunk, '\r\n' + if length == 0: + return + + + @classmethod + def has_chunked_encoding(self, headers): + return "chunked" in [ + i.lower() for i in http.get_header_tokens(headers, "transfer-encoding") + ] + + + @classmethod + def parse_http_protocol(self, line): + """ + Parse an HTTP protocol declaration. + Returns a (major, minor) tuple, or None. + """ + if not line.startswith("HTTP/"): + return None + _, version = line.split('/', 1) + if "." not in version: + return None + major, minor = version.split('.', 1) + try: + major = int(major) + minor = int(minor) + except ValueError: + return None + return major, minor -def read_http_body_chunked( - rfile, - headers, - limit, - request_method, - response_code, - is_request, - max_chunk_size=None -): - """ - Read an HTTP message body: - - rfile: A file descriptor to read from - headers: An ODictCaseless object - limit: Size limit. - is_request: True if the body to read belongs to a request, False - otherwise - """ - if max_chunk_size is None: - max_chunk_size = limit or sys.maxsize - - expected_size = expected_http_body_size( - headers, is_request, request_method, response_code - ) + @classmethod + def parse_init(self, line): + try: + method, url, protocol = string.split(line) + except ValueError: + return None + httpversion = self.parse_http_protocol(protocol) + if not httpversion: + return None + if not utils.isascii(method): + return None + return method, url, httpversion - if expected_size is None: - if has_chunked_encoding(headers): - # Python 3: yield from - for x in read_chunked(rfile, limit, is_request): - yield x - else: # pragma: nocover - raise HttpError( - 400 if is_request else 502, - "Content-Length unknown but no chunked encoding" - ) - elif expected_size >= 0: - if limit is not None and expected_size > limit: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s, content-length was %s" % ( - limit, expected_size - ) - ) - bytes_left = expected_size - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - yield "", rfile.read(chunk_size), "" - bytes_left -= chunk_size - else: - bytes_left = limit or -1 - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = rfile.read(chunk_size) - if not content: - return - yield "", content, "" - bytes_left -= chunk_size - not_done = rfile.read(1) - if not_done: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s," % limit - ) + @classmethod + def parse_init_connect(self, line): + """ + Returns (host, port, httpversion) if line is a valid CONNECT line. + http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 + """ + v = self.parse_init(line) + if not v: + return None + method, url, httpversion = v -def expected_http_body_size(headers, is_request, request_method, response_code): - """ - Returns the expected body length: - - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding or invalid - data) - - -1, if all data should be read until end of stream. - - May raise HttpError. - """ - # Determine response size according to - # http://tools.ietf.org/html/rfc7230#section-3.3 - if request_method: - request_method = request_method.upper() - - if (not is_request and ( - request_method == "HEAD" or - (request_method == "CONNECT" and response_code == 200) or - response_code in [204, 304] or - 100 <= response_code <= 199)): - return 0 - if has_chunked_encoding(headers): - return None - if "content-length" in headers: + if method.upper() != 'CONNECT': + return None try: - size = int(headers["content-length"][0]) - if size < 0: - raise ValueError() - return size + host, port = url.split(":") except ValueError: return None - if is_request: - return 0 - return -1 - - -# TODO: make this a regular class - just like Response -Request = collections.namedtuple( - "Request", - [ - "form_in", - "method", - "scheme", - "host", - "port", - "path", - "httpversion", - "headers", - "content" - ] -) - - -def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): - """ - Parse an HTTP request from a file stream - - Args: - rfile (file): Input file to read from - include_body (bool): Read response body as well - body_size_limit (bool): Maximum body size - wfile (file): If specified, HTTP Expect headers are handled - automatically, by writing a HTTP 100 CONTINUE response to the stream. - - Returns: - Request: The HTTP request - - Raises: - HttpError: If the input is invalid. - """ - httpversion, host, port, scheme, method, path, headers, content = ( - None, None, None, None, None, None, None, None) - - request_line = get_request_line(rfile) - if not request_line: - raise tcp.NetLibDisconnect() - - request_line_parts = parse_init(request_line) - if not request_line_parts: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) + try: + port = int(port) + except ValueError: + return None + if not http.is_valid_port(port): + return None + if not http.is_valid_host(host): + return None + return host, port, httpversion + + @classmethod + def parse_init_proxy(self, line): + v = self.parse_init(line) + if not v: + return None + method, url, httpversion = v + + parts = http.parse_url(url) + if not parts: + return None + scheme, host, port, path = parts + return method, scheme, host, port, path, httpversion + + @classmethod + def parse_init_http(self, line): + """ + Returns (method, url, httpversion) + """ + v = self.parse_init(line) + if not v: + return None + method, url, httpversion = v + if not utils.isascii(url): + return None + if not (url.startswith("/") or url == "*"): + return None + return method, url, httpversion + + + @classmethod + def connection_close(self, httpversion, headers): + """ + Checks the message to see if the client connection should be closed + according to RFC 2616 Section 8.1 Note that a connection should be + closed as well if the response has been read until end of the stream. + """ + # At first, check if we have an explicit Connection header. + if "connection" in headers: + toks = http.get_header_tokens(headers, "connection") + if "close" in toks: + return True + elif "keep-alive" in toks: + return False + + # If we don't have a Connection header, HTTP 1.1 connections are assumed to + # be persistent + return httpversion != (1, 1) + + + @classmethod + def parse_response_line(self, line): + parts = line.strip().split(" ", 2) + if len(parts) == 2: # handle missing message gracefully + parts.append("") + if len(parts) != 3: + return None + proto, code, msg = parts + try: + code = int(code) + except ValueError: + return None + return (proto, code, msg) + + + def read_http_body(self, *args, **kwargs): + return "".join( + content for _, content, _ in self.read_http_body_chunked(*args, **kwargs) + ) + + + def read_http_body_chunked( + self, + headers, + limit, + request_method, + response_code, + is_request, + max_chunk_size=None + ): + """ + Read an HTTP message body: + headers: An ODictCaseless object + limit: Size limit. + is_request: True if the body to read belongs to a request, False + otherwise + """ + if max_chunk_size is None: + max_chunk_size = limit or sys.maxsize + + expected_size = self.expected_http_body_size( + headers, is_request, request_method, response_code ) - method, path, httpversion = request_line_parts - if path == '*' or path.startswith("/"): - form_in = "relative" - if not utils.isascii(path): + if expected_size is None: + if self.has_chunked_encoding(headers): + # Python 3: yield from + for x in self.read_chunked(limit, is_request): + yield x + else: # pragma: nocover + raise HttpError( + 400 if is_request else 502, + "Content-Length unknown but no chunked encoding" + ) + elif expected_size >= 0: + if limit is not None and expected_size > limit: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s, content-length was %s" % ( + limit, expected_size + ) + ) + bytes_left = expected_size + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + yield "", self.tcp_handler.rfile.read(chunk_size), "" + bytes_left -= chunk_size + else: + bytes_left = limit or -1 + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = self.tcp_handler.rfile.read(chunk_size) + if not content: + return + yield "", content, "" + bytes_left -= chunk_size + not_done = self.tcp_handler.rfile.read(1) + if not_done: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s," % limit + ) + + + @classmethod + def expected_http_body_size(self, headers, is_request, request_method, response_code): + """ + Returns the expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding or invalid + data) + - -1, if all data should be read until end of stream. + + May raise HttpError. + """ + # Determine response size according to + # http://tools.ietf.org/html/rfc7230#section-3.3 + if request_method: + request_method = request_method.upper() + + if (not is_request and ( + request_method == "HEAD" or + (request_method == "CONNECT" and response_code == 200) or + response_code in [204, 304] or + 100 <= response_code <= 199)): + return 0 + if self.has_chunked_encoding(headers): + return None + if "content-length" in headers: + try: + size = int(headers["content-length"][0]) + if size < 0: + raise ValueError() + return size + except ValueError: + return None + if is_request: + return 0 + return -1 + + + def read_request(self, include_body=True, body_size_limit=None): + """ + Parse an HTTP request from a file stream + + Args: + include_body (bool): Read response body as well + body_size_limit (bool): Maximum body size + wfile (file): If specified, HTTP Expect headers are handled + automatically, by writing a HTTP 100 CONTINUE response to the stream. + + Returns: + Request: The HTTP request + + Raises: + HttpError: If the input is invalid. + """ + httpversion, host, port, scheme, method, path, headers, content = ( + None, None, None, None, None, None, None, None) + + request_line = self.get_request_line() + if not request_line: + raise tcp.NetLibDisconnect() + + request_line_parts = self.parse_init(request_line) + if not request_line_parts: raise HttpError( 400, "Bad HTTP request line: %s" % repr(request_line) ) - elif method.upper() == 'CONNECT': - form_in = "authority" - r = parse_init_connect(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) + method, path, httpversion = request_line_parts + + if path == '*' or path.startswith("/"): + form_in = "relative" + if not utils.isascii(path): + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + elif method.upper() == 'CONNECT': + form_in = "authority" + r = self.parse_init_connect(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + host, port, _ = r + path = None + else: + form_in = "absolute" + r = self.parse_init_proxy(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + _, scheme, host, port, path, _ = r + + headers = self.read_headers() + if headers is None: + raise HttpError(400, "Invalid headers") + + expect_header = headers.get_first("expect", "").lower() + if expect_header == "100-continue" and httpversion >= (1, 1): + self.tcp_handler.wfile.write( + 'HTTP/1.1 100 Continue\r\n' + '\r\n' ) - host, port, _ = r - path = None - else: - form_in = "absolute" - r = parse_init_proxy(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) + self.tcp_handler.wfile.flush() + del headers['expect'] + + if include_body: + content = self.read_http_body( + headers, + body_size_limit, + method, + None, + True ) - _, scheme, host, port, path, _ = r - headers = read_headers(rfile) - if headers is None: - raise HttpError(400, "Invalid headers") - - expect_header = headers.get_first("expect", "").lower() - if expect_header == "100-continue" and httpversion >= (1, 1): - wfile.write( - 'HTTP/1.1 100 Continue\r\n' - '\r\n' + return self.Request( + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + content ) - wfile.flush() - del headers['expect'] - if include_body: - content = read_http_body( - rfile, headers, body_size_limit, method, None, True - ) - return Request( - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers, - content - ) + def read_response(self, request_method, body_size_limit, include_body=True): + """ + Returns an http.Response + By default, both response header and body are read. + If include_body=False is specified, content may be one of the + following: + - None, if the response is technically allowed to have a response body + - "", if the response must not have a response body (e.g. it's a + response to a HEAD request) + """ -def read_response(rfile, request_method, body_size_limit, include_body=True): - """ - Returns an http.Response - - By default, both response header and body are read. - If include_body=False is specified, content may be one of the - following: - - None, if the response is technically allowed to have a response body - - "", if the response must not have a response body (e.g. it's a - response to a HEAD request) - """ - - line = rfile.readline() - # Possible leftover from previous message - if line == "\r\n" or line == "\n": - line = rfile.readline() - if not line: - raise HttpErrorConnClosed(502, "Server disconnect.") - parts = parse_response_line(line) - if not parts: - raise HttpError(502, "Invalid server response: %s" % repr(line)) - proto, code, msg = parts - httpversion = parse_http_protocol(proto) - if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) - headers = read_headers(rfile) - if headers is None: - raise HttpError(502, "Invalid headers.") - - if include_body: - content = read_http_body( - rfile, - headers, - body_size_limit, - request_method, - code, - False - ) - else: - # if include_body==False then a None content means the body should be - # read separately - content = None - return http.Response(httpversion, code, msg, headers, content) + line = self.tcp_handler.rfile.readline() + # Possible leftover from previous message + if line == "\r\n" or line == "\n": + line = self.tcp_handler.rfile.readline() + if not line: + raise HttpErrorConnClosed(502, "Server disconnect.") + parts = self.parse_response_line(line) + if not parts: + raise HttpError(502, "Invalid server response: %s" % repr(line)) + proto, code, msg = parts + httpversion = self.parse_http_protocol(proto) + if httpversion is None: + raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) + headers = self.read_headers() + if headers is None: + raise HttpError(502, "Invalid headers.") + + if include_body: + content = self.read_http_body( + headers, + body_size_limit, + request_method, + code, + False + ) + else: + # if include_body==False then a None content means the body should be + # read separately + content = None + return http.Response(httpversion, code, msg, headers, content) -def request_preamble(method, resource, http_major="1", http_minor="1"): - return '%s %s HTTP/%s.%s' % ( - method, resource, http_major, http_minor - ) + @classmethod + def request_preamble(self, method, resource, http_major="1", http_minor="1"): + return '%s %s HTTP/%s.%s' % ( + method, resource, http_major, http_minor + ) -def response_preamble(code, message=None, http_major="1", http_minor="1"): - if message is None: - message = status_codes.RESPONSES.get(code) - return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) + @classmethod + def response_preamble(self, code, message=None, http_major="1", http_minor="1"): + if message is None: + message = status_codes.RESPONSES.get(code) + return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index d0a2ee02..6b8a884c 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -3,70 +3,79 @@ import textwrap import binascii from netlib import http, odict, tcp -from netlib.http.http1 import protocol +from netlib.http.http1 import HTTP1Protocol from ... import tutils, tservers +def mock_protocol(data='', chunked=False): + class TCPHandlerMock(object): + pass + tcp_handler = TCPHandlerMock() + tcp_handler.rfile = cStringIO.StringIO(data) + tcp_handler.wfile = cStringIO.StringIO() + return HTTP1Protocol(tcp_handler) + + + def test_has_chunked_encoding(): h = odict.ODictCaseless() - assert not protocol.has_chunked_encoding(h) + assert not HTTP1Protocol.has_chunked_encoding(h) h["transfer-encoding"] = ["chunked"] - assert protocol.has_chunked_encoding(h) + assert HTTP1Protocol.has_chunked_encoding(h) def test_read_chunked(): - h = odict.ODictCaseless() h["transfer-encoding"] = ["chunked"] - s = cStringIO.StringIO("1\r\na\r\n0\r\n") + data = "1\r\na\r\n0\r\n" tutils.raises( "malformed chunked body", - protocol.read_http_body, - s, h, None, "GET", None, True + mock_protocol(data).read_http_body, + h, None, "GET", None, True ) - s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") - assert protocol.read_http_body(s, h, None, "GET", None, True) == "a" + data = "1\r\na\r\n0\r\n\r\n" + assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "a" - s = cStringIO.StringIO("\r\n\r\n1\r\na\r\n0\r\n\r\n") - assert protocol.read_http_body(s, h, None, "GET", None, True) == "a" + data = "\r\n\r\n1\r\na\r\n0\r\n\r\n" + assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "a" - s = cStringIO.StringIO("\r\n") + data = "\r\n" tutils.raises( "closed prematurely", - protocol.read_http_body, - s, h, None, "GET", None, True + mock_protocol(data).read_http_body, + h, None, "GET", None, True ) - s = cStringIO.StringIO("1\r\nfoo") + data = "1\r\nfoo" tutils.raises( "malformed chunked body", - protocol.read_http_body, - s, h, None, "GET", None, True + mock_protocol(data).read_http_body, + h, None, "GET", None, True ) - s = cStringIO.StringIO("foo\r\nfoo") + data = "foo\r\nfoo" tutils.raises( - protocol.HttpError, - protocol.read_http_body, - s, h, None, "GET", None, True + http.HttpError, + mock_protocol(data).read_http_body, + h, None, "GET", None, True ) - s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") - tutils.raises("too large", protocol.read_http_body, s, h, 2, "GET", None, True) + data = "5\r\naaaaa\r\n0\r\n\r\n" + tutils.raises("too large", mock_protocol(data).read_http_body, h, 2, "GET", None, True) def test_connection_close(): h = odict.ODictCaseless() - assert protocol.connection_close((1, 0), h) - assert not protocol.connection_close((1, 1), h) + assert HTTP1Protocol.connection_close((1, 0), h) + assert not HTTP1Protocol.connection_close((1, 1), h) h["connection"] = ["keep-alive"] - assert not protocol.connection_close((1, 1), h) + assert not HTTP1Protocol.connection_close((1, 1), h) h["connection"] = ["close"] - assert protocol.connection_close((1, 1), h) + assert HTTP1Protocol.connection_close((1, 1), h) def test_get_header_tokens(): @@ -82,119 +91,119 @@ def test_get_header_tokens(): def test_read_http_body_request(): h = odict.ODictCaseless() - r = cStringIO.StringIO("testing") - assert protocol.read_http_body(r, h, None, "GET", None, True) == "" + data = "testing" + assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "" def test_read_http_body_response(): h = odict.ODictCaseless() - s = tcp.Reader(cStringIO.StringIO("testing")) - assert protocol.read_http_body(s, h, None, "GET", 200, False) == "testing" + data = "testing" + assert mock_protocol(data, chunked=True).read_http_body(h, None, "GET", 200, False) == "testing" def test_read_http_body(): # test default case h = odict.ODictCaseless() h["content-length"] = [7] - s = cStringIO.StringIO("testing") - assert protocol.read_http_body(s, h, None, "GET", 200, False) == "testing" + data = "testing" + assert mock_protocol(data).read_http_body(h, None, "GET", 200, False) == "testing" # test content length: invalid header h["content-length"] = ["foo"] - s = cStringIO.StringIO("testing") + data = "testing" tutils.raises( - protocol.HttpError, - protocol.read_http_body, - s, h, None, "GET", 200, False + http.HttpError, + mock_protocol(data).read_http_body, + h, None, "GET", 200, False ) # test content length: invalid header #2 h["content-length"] = [-1] - s = cStringIO.StringIO("testing") + data = "testing" tutils.raises( - protocol.HttpError, - protocol.read_http_body, - s, h, None, "GET", 200, False + http.HttpError, + mock_protocol(data).read_http_body, + h, None, "GET", 200, False ) # test content length: content length > actual content h["content-length"] = [5] - s = cStringIO.StringIO("testing") + data = "testing" tutils.raises( - protocol.HttpError, - protocol.read_http_body, - s, h, 4, "GET", 200, False + http.HttpError, + mock_protocol(data).read_http_body, + h, 4, "GET", 200, False ) # test content length: content length < actual content - s = cStringIO.StringIO("testing") - assert len(protocol.read_http_body(s, h, None, "GET", 200, False)) == 5 + data = "testing" + assert len(mock_protocol(data).read_http_body(h, None, "GET", 200, False)) == 5 # test no content length: limit > actual content h = odict.ODictCaseless() - s = tcp.Reader(cStringIO.StringIO("testing")) - assert len(protocol.read_http_body(s, h, 100, "GET", 200, False)) == 7 + data = "testing" + assert len(mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False)) == 7 # test no content length: limit < actual content - s = tcp.Reader(cStringIO.StringIO("testing")) + data = "testing" tutils.raises( - protocol.HttpError, - protocol.read_http_body, - s, h, 4, "GET", 200, False + http.HttpError, + mock_protocol(data, chunked=True).read_http_body, + h, 4, "GET", 200, False ) # test chunked h = odict.ODictCaseless() h["transfer-encoding"] = ["chunked"] - s = tcp.Reader(cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")) - assert protocol.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa" + data = "5\r\naaaaa\r\n0\r\n\r\n" + assert mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False) == "aaaaa" def test_expected_http_body_size(): # gibber in the content-length field h = odict.ODictCaseless() h["content-length"] = ["foo"] - assert protocol.expected_http_body_size(h, False, "GET", 200) is None + assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) is None # negative number in the content-length field h = odict.ODictCaseless() h["content-length"] = ["-7"] - assert protocol.expected_http_body_size(h, False, "GET", 200) is None + assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) is None # explicit length h = odict.ODictCaseless() h["content-length"] = ["5"] - assert protocol.expected_http_body_size(h, False, "GET", 200) == 5 + assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) == 5 # no length h = odict.ODictCaseless() - assert protocol.expected_http_body_size(h, False, "GET", 200) == -1 + assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) == -1 # no length request h = odict.ODictCaseless() - assert protocol.expected_http_body_size(h, True, "GET", None) == 0 + assert HTTP1Protocol.expected_http_body_size(h, True, "GET", None) == 0 def test_parse_http_protocol(): - assert protocol.parse_http_protocol("HTTP/1.1") == (1, 1) - assert protocol.parse_http_protocol("HTTP/0.0") == (0, 0) - assert not protocol.parse_http_protocol("HTTP/a.1") - assert not protocol.parse_http_protocol("HTTP/1.a") - assert not protocol.parse_http_protocol("foo/0.0") - assert not protocol.parse_http_protocol("HTTP/x") + assert HTTP1Protocol.parse_http_protocol("HTTP/1.1") == (1, 1) + assert HTTP1Protocol.parse_http_protocol("HTTP/0.0") == (0, 0) + assert not HTTP1Protocol.parse_http_protocol("HTTP/a.1") + assert not HTTP1Protocol.parse_http_protocol("HTTP/1.a") + assert not HTTP1Protocol.parse_http_protocol("foo/0.0") + assert not HTTP1Protocol.parse_http_protocol("HTTP/x") def test_parse_init_connect(): - assert protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0") - assert not protocol.parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") - assert not protocol.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") - assert not protocol.parse_init_connect("CONNECT host.com:444444 HTTP/1.0") - assert not protocol.parse_init_connect("bogus") - assert not protocol.parse_init_connect("GET host.com:443 HTTP/1.0") - assert not protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0") - assert not protocol.parse_init_connect("CONNECT host.com:443 foo/1.0") - assert not protocol.parse_init_connect("CONNECT host.com:foo HTTP/1.0") + assert HTTP1Protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0") + assert not HTTP1Protocol.parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") + assert not HTTP1Protocol.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") + assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:444444 HTTP/1.0") + assert not HTTP1Protocol.parse_init_connect("bogus") + assert not HTTP1Protocol.parse_init_connect("GET host.com:443 HTTP/1.0") + assert not HTTP1Protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0") + assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:443 foo/1.0") + assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:foo HTTP/1.0") def test_parse_init_proxy(): u = "GET http://foo.com:8888/test HTTP/1.1" - m, s, h, po, pa, httpversion = protocol.parse_init_proxy(u) + m, s, h, po, pa, httpversion = HTTP1Protocol.parse_init_proxy(u) assert m == "GET" assert s == "http" assert h == "foo.com" @@ -203,27 +212,27 @@ def test_parse_init_proxy(): assert httpversion == (1, 1) u = "G\xfeET http://foo.com:8888/test HTTP/1.1" - assert not protocol.parse_init_proxy(u) + assert not HTTP1Protocol.parse_init_proxy(u) - assert not protocol.parse_init_proxy("invalid") - assert not protocol.parse_init_proxy("GET invalid HTTP/1.1") - assert not protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") + assert not HTTP1Protocol.parse_init_proxy("invalid") + assert not HTTP1Protocol.parse_init_proxy("GET invalid HTTP/1.1") + assert not HTTP1Protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") def test_parse_init_http(): u = "GET /test HTTP/1.1" - m, u, httpversion = protocol.parse_init_http(u) + m, u, httpversion = HTTP1Protocol.parse_init_http(u) assert m == "GET" assert u == "/test" assert httpversion == (1, 1) u = "G\xfeET /test HTTP/1.1" - assert not protocol.parse_init_http(u) + assert not HTTP1Protocol.parse_init_http(u) - assert not protocol.parse_init_http("invalid") - assert not protocol.parse_init_http("GET invalid HTTP/1.1") - assert not protocol.parse_init_http("GET /test foo/1.1") - assert not protocol.parse_init_http("GET /test\xc0 HTTP/1.1") + assert not HTTP1Protocol.parse_init_http("invalid") + assert not HTTP1Protocol.parse_init_http("GET invalid HTTP/1.1") + assert not HTTP1Protocol.parse_init_http("GET /test foo/1.1") + assert not HTTP1Protocol.parse_init_http("GET /test\xc0 HTTP/1.1") class TestReadHeaders: @@ -232,8 +241,7 @@ class TestReadHeaders: if not verbatim: data = textwrap.dedent(data) data = data.strip() - s = cStringIO.StringIO(data) - return protocol.read_headers(s) + return mock_protocol(data).read_headers() def test_read_simple(self): data = """ @@ -287,16 +295,15 @@ class TestReadResponseNoContentLength(tservers.ServerTestBase): def test_no_content_length(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - resp = protocol.read_response(c.rfile, "GET", None) + resp = HTTP1Protocol(c).read_response("GET", None) assert resp.content == "bar\r\n\r\n" def test_read_response(): def tst(data, method, limit, include_body=True): data = textwrap.dedent(data) - r = cStringIO.StringIO(data) - return protocol.read_response( - r, method, limit, include_body=include_body + return mock_protocol(data).read_response( + method, limit, include_body=include_body ) tutils.raises("server disconnect", tst, "", "GET", None) @@ -358,16 +365,16 @@ def test_read_response(): def test_get_request_line(): - r = cStringIO.StringIO("\nfoo") - assert protocol.get_request_line(r) == "foo" - assert not protocol.get_request_line(r) + data = "\nfoo" + p = mock_protocol(data) + assert p.get_request_line() == "foo" + assert not p.get_request_line() class TestReadRequest(): def tst(self, data, **kwargs): - r = cStringIO.StringIO(data) - return protocol.read_request(r, **kwargs) + return mock_protocol(data).read_request(**kwargs) def test_invalid(self): tutils.raises( @@ -421,14 +428,15 @@ class TestReadRequest(): assert v.host == "foo.com" def test_expect(self): - w = cStringIO.StringIO() - r = cStringIO.StringIO( + data = "".join( "GET / HTTP/1.1\r\n" "Content-Length: 3\r\n" "Expect: 100-continue\r\n\r\n" - "foobar", + "foobar" ) - v = protocol.read_request(r, wfile=w) - assert w.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" + + p = mock_protocol(data) + v = p.read_request() + assert p.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" assert v.content == "foo" - assert r.read(3) == "bar" + assert p.tcp_handler.rfile.read(3) == "bar" diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 07ad0452..fb7ba39a 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -4,6 +4,7 @@ from nose.tools import raises from netlib import tcp, http, websockets from netlib.http.exceptions import * +from netlib.http.http1 import HTTP1Protocol from .. import tutils, tservers @@ -32,10 +33,13 @@ class WebSocketsEchoHandler(tcp.BaseHandler): frame.to_file(self.wfile) def handshake(self): - req = http.http1.read_request(self.rfile) + http1_protocol = HTTP1Protocol(self) + + req = http1_protocol.read_request() key = self.protocol.check_client_handshake(req.headers) - self.wfile.write(http.http1.response_preamble(101) + "\r\n") + preamble = http1_protocol.response_preamble(101) + self.wfile.write(preamble + "\r\n") headers = self.protocol.server_handshake_headers(key) self.wfile.write(headers.format() + "\r\n") self.wfile.flush() @@ -56,14 +60,16 @@ class WebSocketsClient(tcp.TCPClient): def connect(self): super(WebSocketsClient, self).connect() - preamble = http.http1.protocol.request_preamble("GET", "/") + http1_protocol = HTTP1Protocol(self) + + preamble = http1_protocol.request_preamble("GET", "/") self.wfile.write(preamble + "\r\n") headers = self.protocol.client_handshake_headers() self.client_nonce = headers.get_first("sec-websocket-key") self.wfile.write(headers.format() + "\r\n") self.wfile.flush() - resp = http.http1.protocol.read_response(self.rfile, "get", None) + resp = http1_protocol.read_response("get", None) server_nonce = self.protocol.check_server_handshake(resp.headers) if not server_nonce == self.protocol.create_server_nonce( @@ -151,10 +157,13 @@ class TestWebSockets(tservers.ServerTestBase): class BadHandshakeHandler(WebSocketsEchoHandler): def handshake(self): - client_hs = http.http1.protocol.read_request(self.rfile) + http1_protocol = HTTP1Protocol(self) + + client_hs = http1_protocol.read_request() self.protocol.check_client_handshake(client_hs.headers) - self.wfile.write(http.http1.protocol.response_preamble(101) + "\r\n") + preamble = http1_protocol.response_preamble(101) + self.wfile.write(preamble + "\r\n") headers = self.protocol.server_handshake_headers("malformed key") self.wfile.write(headers.format() + "\r\n") self.wfile.flush() -- cgit v1.2.3 From 4617ab8a3a981f3abd8d62b561c80f9ad141e57b Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 17 Jul 2015 09:37:57 +0200 Subject: add Request class and unify read_request interface --- netlib/http/__init__.py | 1 + netlib/http/http1/protocol.py | 22 +++++----------------- netlib/http/http2/protocol.py | 20 +++++++++++++++++--- netlib/http/semantics.py | 31 +++++++++++++++++++++++++++++++ test/http/http2/test_protocol.py | 9 +++++---- 5 files changed, 59 insertions(+), 24 deletions(-) diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 9b4b0e6b..b01afc6d 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,2 +1,3 @@ +from . import * from exceptions import * from semantics import * diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 401654c1..8d631a13 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -11,25 +11,10 @@ from ..exceptions import * class HTTP1Protocol(object): - # TODO: make this a regular class - just like Response - Request = collections.namedtuple( - "Request", - [ - "form_in", - "method", - "scheme", - "host", - "port", - "path", - "httpversion", - "headers", - "content" - ] - ) - def __init__(self, tcp_handler): self.tcp_handler = tcp_handler + def get_request_line(self): """ Get a line, possibly preceded by a blank. @@ -40,6 +25,7 @@ class HTTP1Protocol(object): line = self.tcp_handler.rfile.readline() return line + def read_headers(self): """ Read a set of headers. @@ -175,6 +161,7 @@ class HTTP1Protocol(object): return None return host, port, httpversion + @classmethod def parse_init_proxy(self, line): v = self.parse_init(line) @@ -188,6 +175,7 @@ class HTTP1Protocol(object): scheme, host, port, path = parts return method, scheme, host, port, path, httpversion + @classmethod def parse_init_http(self, line): """ @@ -425,7 +413,7 @@ class HTTP1Protocol(object): True ) - return self.Request( + return http.Request( form_in, method, scheme, diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 0d6eac85..1dfdda21 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -187,11 +187,25 @@ class HTTP2Protocol(object): self._create_body(body, stream_id))) def read_response(self, *args): - stream_id_, headers, body = self._receive_transmission() - return http.Response("HTTP/2", headers[':status'], "", headers, body) + stream_id, headers, body = self._receive_transmission() + + response = http.Response("HTTP/2", headers[':status'], "", headers, body) + response.stream_id = stream_id + return response def read_request(self): - return self._receive_transmission() + stream_id, headers, body = self._receive_transmission() + + form_in = "" + method = headers.get(':method', '') + scheme = headers.get(':scheme', '') + host = headers.get(':host', '') + port = '' # TODO: parse port number? + path = headers.get(':path', '') + + request = http.Request(form_in, method, scheme, host, port, path, "HTTP/2", headers, body) + request.stream_id = stream_id + return request def _receive_transmission(self): body_expected = True diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index a62c93e3..9a010318 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -7,6 +7,37 @@ import urlparse from .. import utils +class Request(object): + + def __init__( + self, + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + content, + ): + self.form_in = form_in + self.method = method + self.scheme = scheme + self.host = host + self.port = port + self.path = path + self.httpversion = httpversion + self.headers = headers + self.content = content + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return "Request(%s - %s, %s)" % (self.method, self.host, self.path) + + class Response(object): def __init__( diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 403a2589..f41b9565 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -278,6 +278,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): resp = protocol.read_response() + assert resp.stream_id assert resp.httpversion == "HTTP/2" assert resp.status_code == "200" assert resp.msg == "" @@ -303,11 +304,11 @@ class TestReadRequest(tservers.ServerTestBase): c.convert_to_ssl() protocol = http2.HTTP2Protocol(c, is_server=True) - stream_id, headers, body = protocol.read_request() + resp = protocol.read_request() - assert stream_id - assert headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} - assert body == b'foobar' + assert resp.stream_id + assert resp.headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} + assert resp.content == b'foobar' class TestCreateResponse(): -- cgit v1.2.3 From 37a0cb858cda255bac8f06749a81859c82c5177f Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 19 Jul 2015 17:52:10 +0200 Subject: introduce ConnectRequest class --- netlib/http/http1/protocol.py | 2 +- netlib/http/semantics.py | 24 +++++++++++++++++++----- netlib/odict.py | 2 ++ 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 8d631a13..257efb19 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -380,7 +380,7 @@ class HTTP1Protocol(object): "Bad HTTP request line: %s" % repr(request_line) ) host, port, _ = r - path = None + return http.ConnectRequest(host, port) else: form_in = "absolute" r = self.parse_init_proxy(request_line) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 9a010318..664f9def 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -19,7 +19,7 @@ class Request(object): path, httpversion, headers, - content, + body, ): self.form_in = form_in self.method = method @@ -29,7 +29,7 @@ class Request(object): self.path = path self.httpversion = httpversion self.headers = headers - self.content = content + self.body = body def __eq__(self, other): return self.__dict__ == other.__dict__ @@ -38,6 +38,21 @@ class Request(object): return "Request(%s - %s, %s)" % (self.method, self.host, self.path) +class ConnectRequest(Request): + def __init__(self, host, port): + super(ConnectRequest, self).__init__( + form_in="authority", + method="CONNECT", + scheme="", + host=host, + port=port, + path="", + httpversion="", + headers="", + body="", + ) + + class Response(object): def __init__( @@ -46,14 +61,14 @@ class Response(object): status_code, msg, headers, - content, + body, sslinfo=None, ): self.httpversion = httpversion self.status_code = status_code self.msg = msg self.headers = headers - self.content = content + self.body = body self.sslinfo = sslinfo def __eq__(self, other): @@ -63,7 +78,6 @@ class Response(object): return "Response(%s - %s)" % (self.status_code, self.msg) - def is_valid_port(port): if not 0 <= port <= 65535: return False diff --git a/netlib/odict.py b/netlib/odict.py index f52acd50..ee1e6938 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -20,6 +20,8 @@ class ODict(object): """ def __init__(self, lst=None): + if isinstance(lst, ODict): + lst = lst.items() self.lst = lst or [] def _kconv(self, s): -- cgit v1.2.3 From d62dbee0f6cd47b4cad1ee7cc731b413600c0add Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 19 Jul 2015 18:17:30 +0200 Subject: rename content -> body --- netlib/wsgi.py | 6 +++--- test/http/http1/test_protocol.py | 10 +++++----- test/http/http2/test_protocol.py | 6 +++--- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index ad43dc19..99afe00e 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -21,9 +21,9 @@ class Flow(object): class Request(object): - def __init__(self, scheme, method, path, headers, content): + def __init__(self, scheme, method, path, headers, body): self.scheme, self.method, self.path = scheme, method, path - self.headers, self.content = headers, content + self.headers, self.body = headers, body def date_time_string(): @@ -58,7 +58,7 @@ class WSGIAdaptor(object): environ = { 'wsgi.version': (1, 0), 'wsgi.url_scheme': flow.request.scheme, - 'wsgi.input': cStringIO.StringIO(flow.request.content), + 'wsgi.input': cStringIO.StringIO(flow.request.body or ""), 'wsgi.errors': errsoc, 'wsgi.multithread': True, 'wsgi.multiprocess': False, diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index 6b8a884c..936fe20d 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -296,7 +296,7 @@ class TestReadResponseNoContentLength(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() resp = HTTP1Protocol(c).read_response("GET", None) - assert resp.content == "bar\r\n\r\n" + assert resp.body == "bar\r\n\r\n" def test_read_response(): @@ -344,8 +344,8 @@ def test_read_response(): foo """ - assert tst(data, "GET", None).content == 'foo' - assert tst(data, "HEAD", None).content == '' + assert tst(data, "GET", None).body == 'foo' + assert tst(data, "HEAD", None).body == '' data = """ HTTP/1.1 200 OK @@ -361,7 +361,7 @@ def test_read_response(): foo """ - assert tst(data, "GET", None, include_body=False).content is None + assert tst(data, "GET", None, include_body=False).body is None def test_get_request_line(): @@ -438,5 +438,5 @@ class TestReadRequest(): p = mock_protocol(data) v = p.read_request() assert p.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" - assert v.content == "foo" + assert v.body == "foo" assert p.tcp_handler.rfile.read(3) == "bar" diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index f41b9565..34e4ef50 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -257,7 +257,7 @@ class TestReadResponse(tservers.ServerTestBase): assert resp.status_code == "200" assert resp.msg == "" assert resp.headers == {':status': '200', 'etag': 'foobar'} - assert resp.content == b'foobar' + assert resp.body == b'foobar' class TestReadEmptyResponse(tservers.ServerTestBase): @@ -283,7 +283,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): assert resp.status_code == "200" assert resp.msg == "" assert resp.headers == {':status': '200', 'etag': 'foobar'} - assert resp.content == b'' + assert resp.body == b'' class TestReadRequest(tservers.ServerTestBase): @@ -308,7 +308,7 @@ class TestReadRequest(tservers.ServerTestBase): assert resp.stream_id assert resp.headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} - assert resp.content == b'foobar' + assert resp.body == b'foobar' class TestCreateResponse(): -- cgit v1.2.3 From 83f013fca13c7395ca4e3da3fac60c8d907172b6 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 19 Jul 2015 20:46:26 +0200 Subject: introduce EmptyRequest class --- netlib/http/http1/protocol.py | 7 +++++-- netlib/http/semantics.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 257efb19..d2a77399 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -333,7 +333,7 @@ class HTTP1Protocol(object): return -1 - def read_request(self, include_body=True, body_size_limit=None): + def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): """ Parse an HTTP request from a file stream @@ -354,7 +354,10 @@ class HTTP1Protocol(object): request_line = self.get_request_line() if not request_line: - raise tcp.NetLibDisconnect() + if allow_empty: + return http.EmptyRequest() + else: + raise tcp.NetLibDisconnect() request_line_parts = self.parse_init(request_line) if not request_line_parts: diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 664f9def..355906dd 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -38,6 +38,20 @@ class Request(object): return "Request(%s - %s, %s)" % (self.method, self.host, self.path) +class EmptyRequest(Request): + def __init__(self): + super(EmptyRequest, self).__init__( + form_in="", + method="", + scheme="", + host="", + port="", + path="", + httpversion="", + headers="", + body="", + ) + class ConnectRequest(Request): def __init__(self, host, port): super(ConnectRequest, self).__init__( -- cgit v1.2.3 From ecc7ffe9282ae9d1b652a88946d6edc550dc9633 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 19 Jul 2015 23:25:15 +0200 Subject: reduce public interface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit use private indicator pattern “_methodname” --- netlib/http/http1/protocol.py | 569 ++++++++++++++++++++------------------- test/http/http1/test_protocol.py | 56 ++-- 2 files changed, 313 insertions(+), 312 deletions(-) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index d2a77399..e7727e00 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -15,15 +15,144 @@ class HTTP1Protocol(object): self.tcp_handler = tcp_handler - def get_request_line(self): + def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): """ - Get a line, possibly preceded by a blank. + Parse an HTTP request from a file stream + + Args: + include_body (bool): Read response body as well + body_size_limit (bool): Maximum body size + wfile (file): If specified, HTTP Expect headers are handled + automatically, by writing a HTTP 100 CONTINUE response to the stream. + + Returns: + Request: The HTTP request + + Raises: + HttpError: If the input is invalid. """ + httpversion, host, port, scheme, method, path, headers, content = ( + None, None, None, None, None, None, None, None) + + request_line = self._get_request_line() + if not request_line: + if allow_empty: + return http.EmptyRequest() + else: + raise tcp.NetLibDisconnect() + + request_line_parts = self._parse_init(request_line) + if not request_line_parts: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + method, path, httpversion = request_line_parts + + if path == '*' or path.startswith("/"): + form_in = "relative" + if not utils.isascii(path): + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + elif method.upper() == 'CONNECT': + form_in = "authority" + r = self._parse_init_connect(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + host, port, _ = r + return http.ConnectRequest(host, port) + else: + form_in = "absolute" + r = self._parse_init_proxy(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + _, scheme, host, port, path, _ = r + + headers = self.read_headers() + if headers is None: + raise HttpError(400, "Invalid headers") + + expect_header = headers.get_first("expect", "").lower() + if expect_header == "100-continue" and httpversion >= (1, 1): + self.tcp_handler.wfile.write( + 'HTTP/1.1 100 Continue\r\n' + '\r\n' + ) + self.tcp_handler.wfile.flush() + del headers['expect'] + + if include_body: + content = self.read_http_body( + headers, + body_size_limit, + method, + None, + True + ) + + return http.Request( + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + content + ) + + + def read_response(self, request_method, body_size_limit, include_body=True): + """ + Returns an http.Response + + By default, both response header and body are read. + If include_body=False is specified, content may be one of the + following: + - None, if the response is technically allowed to have a response body + - "", if the response must not have a response body (e.g. it's a + response to a HEAD request) + """ + line = self.tcp_handler.rfile.readline() + # Possible leftover from previous message if line == "\r\n" or line == "\n": - # Possible leftover from previous message line = self.tcp_handler.rfile.readline() - return line + if not line: + raise HttpErrorConnClosed(502, "Server disconnect.") + parts = self.parse_response_line(line) + if not parts: + raise HttpError(502, "Invalid server response: %s" % repr(line)) + proto, code, msg = parts + httpversion = self._parse_http_protocol(proto) + if httpversion is None: + raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) + headers = self.read_headers() + if headers is None: + raise HttpError(502, "Invalid headers.") + + if include_body: + content = self.read_http_body( + headers, + body_size_limit, + request_method, + code, + False + ) + else: + # if include_body==False then a None content means the body should be + # read separately + content = None + return http.Response(httpversion, code, msg, headers, content) def read_headers(self): @@ -56,7 +185,146 @@ class HTTP1Protocol(object): return odict.ODictCaseless(ret) - def read_chunked(self, limit, is_request): + def read_http_body(self, *args, **kwargs): + return "".join( + content for _, content, _ in self.read_http_body_chunked(*args, **kwargs) + ) + + + def read_http_body_chunked( + self, + headers, + limit, + request_method, + response_code, + is_request, + max_chunk_size=None + ): + """ + Read an HTTP message body: + headers: An ODictCaseless object + limit: Size limit. + is_request: True if the body to read belongs to a request, False + otherwise + """ + if max_chunk_size is None: + max_chunk_size = limit or sys.maxsize + + expected_size = self.expected_http_body_size( + headers, is_request, request_method, response_code + ) + + if expected_size is None: + if self.has_chunked_encoding(headers): + # Python 3: yield from + for x in self._read_chunked(limit, is_request): + yield x + else: # pragma: nocover + raise HttpError( + 400 if is_request else 502, + "Content-Length unknown but no chunked encoding" + ) + elif expected_size >= 0: + if limit is not None and expected_size > limit: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s, content-length was %s" % ( + limit, expected_size + ) + ) + bytes_left = expected_size + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + yield "", self.tcp_handler.rfile.read(chunk_size), "" + bytes_left -= chunk_size + else: + bytes_left = limit or -1 + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = self.tcp_handler.rfile.read(chunk_size) + if not content: + return + yield "", content, "" + bytes_left -= chunk_size + not_done = self.tcp_handler.rfile.read(1) + if not_done: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s," % limit + ) + + + @classmethod + def expected_http_body_size(self, headers, is_request, request_method, response_code): + """ + Returns the expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding or invalid + data) + - -1, if all data should be read until end of stream. + + May raise HttpError. + """ + # Determine response size according to + # http://tools.ietf.org/html/rfc7230#section-3.3 + if request_method: + request_method = request_method.upper() + + if (not is_request and ( + request_method == "HEAD" or + (request_method == "CONNECT" and response_code == 200) or + response_code in [204, 304] or + 100 <= response_code <= 199)): + return 0 + if self.has_chunked_encoding(headers): + return None + if "content-length" in headers: + try: + size = int(headers["content-length"][0]) + if size < 0: + raise ValueError() + return size + except ValueError: + return None + if is_request: + return 0 + return -1 + + + @classmethod + def request_preamble(self, method, resource, http_major="1", http_minor="1"): + return '%s %s HTTP/%s.%s' % ( + method, resource, http_major, http_minor + ) + + + @classmethod + def response_preamble(self, code, message=None, http_major="1", http_minor="1"): + if message is None: + message = status_codes.RESPONSES.get(code) + return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) + + + @classmethod + def has_chunked_encoding(self, headers): + return "chunked" in [ + i.lower() for i in http.get_header_tokens(headers, "transfer-encoding") + ] + + + def _get_request_line(self): + """ + Get a line, possibly preceded by a blank. + """ + line = self.tcp_handler.rfile.readline() + if line == "\r\n" or line == "\n": + # Possible leftover from previous message + line = self.tcp_handler.rfile.readline() + return line + + + + def _read_chunked(self, limit, is_request): """ Read a chunked HTTP body. @@ -88,20 +356,13 @@ class HTTP1Protocol(object): suffix = self.tcp_handler.rfile.readline(5) if suffix != '\r\n': raise HttpError(code, "Malformed chunked body") - yield line, chunk, '\r\n' - if length == 0: - return - - - @classmethod - def has_chunked_encoding(self, headers): - return "chunked" in [ - i.lower() for i in http.get_header_tokens(headers, "transfer-encoding") - ] + yield line, chunk, '\r\n' + if length == 0: + return @classmethod - def parse_http_protocol(self, line): + def _parse_http_protocol(self, line): """ Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or None. @@ -121,12 +382,12 @@ class HTTP1Protocol(object): @classmethod - def parse_init(self, line): + def _parse_init(self, line): try: method, url, protocol = string.split(line) except ValueError: return None - httpversion = self.parse_http_protocol(protocol) + httpversion = self._parse_http_protocol(protocol) if not httpversion: return None if not utils.isascii(method): @@ -135,12 +396,12 @@ class HTTP1Protocol(object): @classmethod - def parse_init_connect(self, line): + def _parse_init_connect(self, line): """ Returns (host, port, httpversion) if line is a valid CONNECT line. http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 """ - v = self.parse_init(line) + v = self._parse_init(line) if not v: return None method, url, httpversion = v @@ -163,8 +424,8 @@ class HTTP1Protocol(object): @classmethod - def parse_init_proxy(self, line): - v = self.parse_init(line) + def _parse_init_proxy(self, line): + v = self._parse_init(line) if not v: return None method, url, httpversion = v @@ -177,11 +438,11 @@ class HTTP1Protocol(object): @classmethod - def parse_init_http(self, line): + def _parse_init_http(self, line): """ Returns (method, url, httpversion) """ - v = self.parse_init(line) + v = self._parse_init(line) if not v: return None method, url, httpversion = v @@ -225,263 +486,3 @@ class HTTP1Protocol(object): except ValueError: return None return (proto, code, msg) - - - def read_http_body(self, *args, **kwargs): - return "".join( - content for _, content, _ in self.read_http_body_chunked(*args, **kwargs) - ) - - - def read_http_body_chunked( - self, - headers, - limit, - request_method, - response_code, - is_request, - max_chunk_size=None - ): - """ - Read an HTTP message body: - headers: An ODictCaseless object - limit: Size limit. - is_request: True if the body to read belongs to a request, False - otherwise - """ - if max_chunk_size is None: - max_chunk_size = limit or sys.maxsize - - expected_size = self.expected_http_body_size( - headers, is_request, request_method, response_code - ) - - if expected_size is None: - if self.has_chunked_encoding(headers): - # Python 3: yield from - for x in self.read_chunked(limit, is_request): - yield x - else: # pragma: nocover - raise HttpError( - 400 if is_request else 502, - "Content-Length unknown but no chunked encoding" - ) - elif expected_size >= 0: - if limit is not None and expected_size > limit: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s, content-length was %s" % ( - limit, expected_size - ) - ) - bytes_left = expected_size - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - yield "", self.tcp_handler.rfile.read(chunk_size), "" - bytes_left -= chunk_size - else: - bytes_left = limit or -1 - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = self.tcp_handler.rfile.read(chunk_size) - if not content: - return - yield "", content, "" - bytes_left -= chunk_size - not_done = self.tcp_handler.rfile.read(1) - if not_done: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s," % limit - ) - - - @classmethod - def expected_http_body_size(self, headers, is_request, request_method, response_code): - """ - Returns the expected body length: - - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding or invalid - data) - - -1, if all data should be read until end of stream. - - May raise HttpError. - """ - # Determine response size according to - # http://tools.ietf.org/html/rfc7230#section-3.3 - if request_method: - request_method = request_method.upper() - - if (not is_request and ( - request_method == "HEAD" or - (request_method == "CONNECT" and response_code == 200) or - response_code in [204, 304] or - 100 <= response_code <= 199)): - return 0 - if self.has_chunked_encoding(headers): - return None - if "content-length" in headers: - try: - size = int(headers["content-length"][0]) - if size < 0: - raise ValueError() - return size - except ValueError: - return None - if is_request: - return 0 - return -1 - - - def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): - """ - Parse an HTTP request from a file stream - - Args: - include_body (bool): Read response body as well - body_size_limit (bool): Maximum body size - wfile (file): If specified, HTTP Expect headers are handled - automatically, by writing a HTTP 100 CONTINUE response to the stream. - - Returns: - Request: The HTTP request - - Raises: - HttpError: If the input is invalid. - """ - httpversion, host, port, scheme, method, path, headers, content = ( - None, None, None, None, None, None, None, None) - - request_line = self.get_request_line() - if not request_line: - if allow_empty: - return http.EmptyRequest() - else: - raise tcp.NetLibDisconnect() - - request_line_parts = self.parse_init(request_line) - if not request_line_parts: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - method, path, httpversion = request_line_parts - - if path == '*' or path.startswith("/"): - form_in = "relative" - if not utils.isascii(path): - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - elif method.upper() == 'CONNECT': - form_in = "authority" - r = self.parse_init_connect(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - host, port, _ = r - return http.ConnectRequest(host, port) - else: - form_in = "absolute" - r = self.parse_init_proxy(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - _, scheme, host, port, path, _ = r - - headers = self.read_headers() - if headers is None: - raise HttpError(400, "Invalid headers") - - expect_header = headers.get_first("expect", "").lower() - if expect_header == "100-continue" and httpversion >= (1, 1): - self.tcp_handler.wfile.write( - 'HTTP/1.1 100 Continue\r\n' - '\r\n' - ) - self.tcp_handler.wfile.flush() - del headers['expect'] - - if include_body: - content = self.read_http_body( - headers, - body_size_limit, - method, - None, - True - ) - - return http.Request( - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers, - content - ) - - - def read_response(self, request_method, body_size_limit, include_body=True): - """ - Returns an http.Response - - By default, both response header and body are read. - If include_body=False is specified, content may be one of the - following: - - None, if the response is technically allowed to have a response body - - "", if the response must not have a response body (e.g. it's a - response to a HEAD request) - """ - - line = self.tcp_handler.rfile.readline() - # Possible leftover from previous message - if line == "\r\n" or line == "\n": - line = self.tcp_handler.rfile.readline() - if not line: - raise HttpErrorConnClosed(502, "Server disconnect.") - parts = self.parse_response_line(line) - if not parts: - raise HttpError(502, "Invalid server response: %s" % repr(line)) - proto, code, msg = parts - httpversion = self.parse_http_protocol(proto) - if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) - headers = self.read_headers() - if headers is None: - raise HttpError(502, "Invalid headers.") - - if include_body: - content = self.read_http_body( - headers, - body_size_limit, - request_method, - code, - False - ) - else: - # if include_body==False then a None content means the body should be - # read separately - content = None - return http.Response(httpversion, code, msg, headers, content) - - - @classmethod - def request_preamble(self, method, resource, http_major="1", http_minor="1"): - return '%s %s HTTP/%s.%s' % ( - method, resource, http_major, http_minor - ) - - - @classmethod - def response_preamble(self, code, message=None, http_major="1", http_minor="1"): - if message is None: - message = status_codes.RESPONSES.get(code) - return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index 936fe20d..8d05b31f 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -181,29 +181,29 @@ def test_expected_http_body_size(): def test_parse_http_protocol(): - assert HTTP1Protocol.parse_http_protocol("HTTP/1.1") == (1, 1) - assert HTTP1Protocol.parse_http_protocol("HTTP/0.0") == (0, 0) - assert not HTTP1Protocol.parse_http_protocol("HTTP/a.1") - assert not HTTP1Protocol.parse_http_protocol("HTTP/1.a") - assert not HTTP1Protocol.parse_http_protocol("foo/0.0") - assert not HTTP1Protocol.parse_http_protocol("HTTP/x") + assert HTTP1Protocol._parse_http_protocol("HTTP/1.1") == (1, 1) + assert HTTP1Protocol._parse_http_protocol("HTTP/0.0") == (0, 0) + assert not HTTP1Protocol._parse_http_protocol("HTTP/a.1") + assert not HTTP1Protocol._parse_http_protocol("HTTP/1.a") + assert not HTTP1Protocol._parse_http_protocol("foo/0.0") + assert not HTTP1Protocol._parse_http_protocol("HTTP/x") def test_parse_init_connect(): - assert HTTP1Protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0") - assert not HTTP1Protocol.parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") - assert not HTTP1Protocol.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") - assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:444444 HTTP/1.0") - assert not HTTP1Protocol.parse_init_connect("bogus") - assert not HTTP1Protocol.parse_init_connect("GET host.com:443 HTTP/1.0") - assert not HTTP1Protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0") - assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:443 foo/1.0") - assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:foo HTTP/1.0") + assert HTTP1Protocol._parse_init_connect("CONNECT host.com:443 HTTP/1.0") + assert not HTTP1Protocol._parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") + assert not HTTP1Protocol._parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") + assert not HTTP1Protocol._parse_init_connect("CONNECT host.com:444444 HTTP/1.0") + assert not HTTP1Protocol._parse_init_connect("bogus") + assert not HTTP1Protocol._parse_init_connect("GET host.com:443 HTTP/1.0") + assert not HTTP1Protocol._parse_init_connect("CONNECT host.com443 HTTP/1.0") + assert not HTTP1Protocol._parse_init_connect("CONNECT host.com:443 foo/1.0") + assert not HTTP1Protocol._parse_init_connect("CONNECT host.com:foo HTTP/1.0") def test_parse_init_proxy(): u = "GET http://foo.com:8888/test HTTP/1.1" - m, s, h, po, pa, httpversion = HTTP1Protocol.parse_init_proxy(u) + m, s, h, po, pa, httpversion = HTTP1Protocol._parse_init_proxy(u) assert m == "GET" assert s == "http" assert h == "foo.com" @@ -212,27 +212,27 @@ def test_parse_init_proxy(): assert httpversion == (1, 1) u = "G\xfeET http://foo.com:8888/test HTTP/1.1" - assert not HTTP1Protocol.parse_init_proxy(u) + assert not HTTP1Protocol._parse_init_proxy(u) - assert not HTTP1Protocol.parse_init_proxy("invalid") - assert not HTTP1Protocol.parse_init_proxy("GET invalid HTTP/1.1") - assert not HTTP1Protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") + assert not HTTP1Protocol._parse_init_proxy("invalid") + assert not HTTP1Protocol._parse_init_proxy("GET invalid HTTP/1.1") + assert not HTTP1Protocol._parse_init_proxy("GET http://foo.com:8888/test foo/1.1") def test_parse_init_http(): u = "GET /test HTTP/1.1" - m, u, httpversion = HTTP1Protocol.parse_init_http(u) + m, u, httpversion = HTTP1Protocol._parse_init_http(u) assert m == "GET" assert u == "/test" assert httpversion == (1, 1) u = "G\xfeET /test HTTP/1.1" - assert not HTTP1Protocol.parse_init_http(u) + assert not HTTP1Protocol._parse_init_http(u) - assert not HTTP1Protocol.parse_init_http("invalid") - assert not HTTP1Protocol.parse_init_http("GET invalid HTTP/1.1") - assert not HTTP1Protocol.parse_init_http("GET /test foo/1.1") - assert not HTTP1Protocol.parse_init_http("GET /test\xc0 HTTP/1.1") + assert not HTTP1Protocol._parse_init_http("invalid") + assert not HTTP1Protocol._parse_init_http("GET invalid HTTP/1.1") + assert not HTTP1Protocol._parse_init_http("GET /test foo/1.1") + assert not HTTP1Protocol._parse_init_http("GET /test\xc0 HTTP/1.1") class TestReadHeaders: @@ -367,8 +367,8 @@ def test_read_response(): def test_get_request_line(): data = "\nfoo" p = mock_protocol(data) - assert p.get_request_line() == "foo" - assert not p.get_request_line() + assert p._get_request_line() == "foo" + assert not p._get_request_line() class TestReadRequest(): -- cgit v1.2.3 From faf17d3d60e658d0cd1df30a10be4f11035502f8 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 20 Jul 2015 16:33:00 +0200 Subject: http2: make proper use of odict --- netlib/http/http2/protocol.py | 19 +++++++++++-------- netlib/odict.py | 2 -- test/http/http2/test_protocol.py | 8 ++++---- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 1dfdda21..55b5ca76 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -2,7 +2,7 @@ from __future__ import (absolute_import, print_function, division) import itertools from hpack.hpack import Encoder, Decoder -from netlib import http, utils +from netlib import http, utils, odict from . import frame @@ -189,7 +189,8 @@ class HTTP2Protocol(object): def read_response(self, *args): stream_id, headers, body = self._receive_transmission() - response = http.Response("HTTP/2", headers[':status'], "", headers, body) + status = headers[':status'][0] + response = http.Response("HTTP/2", status, "", headers, body) response.stream_id = stream_id return response @@ -197,11 +198,11 @@ class HTTP2Protocol(object): stream_id, headers, body = self._receive_transmission() form_in = "" - method = headers.get(':method', '') - scheme = headers.get(':scheme', '') - host = headers.get(':host', '') + method = headers.get(':method', [''])[0] + scheme = headers.get(':scheme', [''])[0] + host = headers.get(':host', [''])[0] port = '' # TODO: parse port number? - path = headers.get(':path', '') + path = headers.get(':path', [''])[0] request = http.Request(form_in, method, scheme, host, port, path, "HTTP/2", headers, body) request.stream_id = stream_id @@ -233,15 +234,17 @@ class HTTP2Protocol(object): break # TODO: implement window update & flow - headers = {} + headers = odict.ODictCaseless() for header, value in self.decoder.decode(header_block_fragment): - headers[header] = value + headers.add(header, value) return stream_id, headers, body def create_response(self, code, stream_id=None, headers=None, body=None): if headers is None: headers = [] + if isinstance(headers, odict.ODict): + headers = headers.items() headers = [(b':status', bytes(str(code)))] + headers diff --git a/netlib/odict.py b/netlib/odict.py index ee1e6938..f52acd50 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -20,8 +20,6 @@ class ODict(object): """ def __init__(self, lst=None): - if isinstance(lst, ODict): - lst = lst.items() self.lst = lst or [] def _kconv(self, s): diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 34e4ef50..d3040266 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -1,6 +1,6 @@ import OpenSSL -from netlib import tcp +from netlib import tcp, odict from netlib.http import http2 from netlib.http.http2.frame import * from ... import tutils, tservers @@ -256,7 +256,7 @@ class TestReadResponse(tservers.ServerTestBase): assert resp.httpversion == "HTTP/2" assert resp.status_code == "200" assert resp.msg == "" - assert resp.headers == {':status': '200', 'etag': 'foobar'} + assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] assert resp.body == b'foobar' @@ -282,7 +282,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): assert resp.httpversion == "HTTP/2" assert resp.status_code == "200" assert resp.msg == "" - assert resp.headers == {':status': '200', 'etag': 'foobar'} + assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] assert resp.body == b'' @@ -307,7 +307,7 @@ class TestReadRequest(tservers.ServerTestBase): resp = protocol.read_request() assert resp.stream_id - assert resp.headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} + assert resp.headers.lst == [[u':method', u'GET'], [u':path', u'/'], [u':scheme', u'https']] assert resp.body == b'foobar' -- cgit v1.2.3 From 657973eca3b091cdf07a65f8363affd3d36f0d0f Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 22 Jul 2015 13:01:24 +0200 Subject: fix bugs --- netlib/http/http1/protocol.py | 26 +++++++++++++++++--------- netlib/http/semantics.py | 28 +++++++++++----------------- test/http/http1/test_protocol.py | 9 +++------ 3 files changed, 31 insertions(+), 32 deletions(-) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index e7727e00..e46ad7ab 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -9,10 +9,18 @@ from netlib import odict, utils, tcp, http from .. import status_codes from ..exceptions import * +class TCPHandler(object): + def __init__(self, rfile, wfile=None): + self.rfile = rfile + self.wfile = wfile + class HTTP1Protocol(object): - def __init__(self, tcp_handler): - self.tcp_handler = tcp_handler + def __init__(self, tcp_handler=None, rfile=None, wfile=None): + if tcp_handler: + self.tcp_handler = tcp_handler + else: + self.tcp_handler = TCPHandler(rfile, wfile) def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): @@ -31,7 +39,7 @@ class HTTP1Protocol(object): Raises: HttpError: If the input is invalid. """ - httpversion, host, port, scheme, method, path, headers, content = ( + httpversion, host, port, scheme, method, path, headers, body = ( None, None, None, None, None, None, None, None) request_line = self._get_request_line() @@ -56,7 +64,7 @@ class HTTP1Protocol(object): 400, "Bad HTTP request line: %s" % repr(request_line) ) - elif method.upper() == 'CONNECT': + elif method == 'CONNECT': form_in = "authority" r = self._parse_init_connect(request_line) if not r: @@ -64,8 +72,8 @@ class HTTP1Protocol(object): 400, "Bad HTTP request line: %s" % repr(request_line) ) - host, port, _ = r - return http.ConnectRequest(host, port) + host, port, httpversion = r + path = None else: form_in = "absolute" r = self._parse_init_proxy(request_line) @@ -81,7 +89,7 @@ class HTTP1Protocol(object): raise HttpError(400, "Invalid headers") expect_header = headers.get_first("expect", "").lower() - if expect_header == "100-continue" and httpversion >= (1, 1): + if expect_header == "100-continue" and httpversion == (1, 1): self.tcp_handler.wfile.write( 'HTTP/1.1 100 Continue\r\n' '\r\n' @@ -90,7 +98,7 @@ class HTTP1Protocol(object): del headers['expect'] if include_body: - content = self.read_http_body( + body = self.read_http_body( headers, body_size_limit, method, @@ -107,7 +115,7 @@ class HTTP1Protocol(object): path, httpversion, headers, - content + body ) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 355906dd..9e13edaa 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -5,7 +5,7 @@ import string import sys import urlparse -from .. import utils +from .. import utils, odict class Request(object): @@ -37,6 +37,10 @@ class Request(object): def __repr__(self): return "Request(%s - %s, %s)" % (self.method, self.host, self.path) + @property + def content(self): + return self.body + class EmptyRequest(Request): def __init__(self): @@ -47,22 +51,8 @@ class EmptyRequest(Request): host="", port="", path="", - httpversion="", - headers="", - body="", - ) - -class ConnectRequest(Request): - def __init__(self, host, port): - super(ConnectRequest, self).__init__( - form_in="authority", - method="CONNECT", - scheme="", - host=host, - port=port, - path="", - httpversion="", - headers="", + httpversion=(0, 0), + headers=odict.ODictCaseless(), body="", ) @@ -91,6 +81,10 @@ class Response(object): def __repr__(self): return "Response(%s - %s)" % (self.status_code, self.msg) + @property + def content(self): + return self.body + def is_valid_port(port): if not 0 <= port <= 65535: diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index 8d05b31f..dcebbd5e 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -8,12 +8,9 @@ from ... import tutils, tservers def mock_protocol(data='', chunked=False): - class TCPHandlerMock(object): - pass - tcp_handler = TCPHandlerMock() - tcp_handler.rfile = cStringIO.StringIO(data) - tcp_handler.wfile = cStringIO.StringIO() - return HTTP1Protocol(tcp_handler) + rfile = cStringIO.StringIO(data) + wfile = cStringIO.StringIO() + return HTTP1Protocol(rfile=rfile, wfile=wfile) -- cgit v1.2.3 From 1b261613826565dc5453b2846904c23773243921 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 24 Jul 2015 16:47:28 +0200 Subject: add distinct error for cert verification issues --- netlib/certutils.py | 2 -- netlib/tcp.py | 11 +++++++++-- test/test_tcp.py | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index c699af00..cc143a50 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -304,8 +304,6 @@ class CertStore(object): valid, plain-ASCII, IDNA-encoded domain name. sans: A list of Subject Alternate Names. - - Return None if the certificate could not be found or generated. """ potential_keys = self.asterisk_forms(commonname) diff --git a/netlib/tcp.py b/netlib/tcp.py index 5c4094d7..77c2a531 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -65,6 +65,10 @@ class NetLibSSLError(NetLibError): pass +class NetLibInvalidCertificateError(NetLibSSLError): + pass + + class SSLKeyLogger(object): def __init__(self, filename): @@ -517,13 +521,16 @@ class TCPClient(_Connection): try: self.connection.do_handshake() except SSL.Error as v: - raise NetLibError("SSL handshake error: %s" % repr(v)) + if self.ssl_verification_error: + raise NetLibInvalidCertificateError("SSL handshake error: %s" % repr(v)) + else: + raise NetLibError("SSL handshake error: %s" % repr(v)) # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on # certificate validation failure verification_mode = sslctx_kwargs.get('verify_options', None) if self.ssl_verification_error is not None and verification_mode == SSL.VERIFY_PEER: - raise NetLibError("SSL handshake error: certificate verify failed") + raise NetLibInvalidCertificateError("SSL handshake error: certificate verify failed") self.ssl_established = True self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) diff --git a/test/test_tcp.py b/test/test_tcp.py index 8a3299b6..289ed72f 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -224,7 +224,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): c.connect() tutils.raises( - tcp.NetLibError, + tcp.NetLibInvalidCertificateError, c.convert_to_ssl, verify_options=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted.pem")) -- cgit v1.2.3 From fb482172241b6235da083f6dbf154b641772a4fc Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 25 Jul 2015 13:30:25 +0200 Subject: improve pyopenssl version check --- netlib/version_check.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/version_check.py b/netlib/version_check.py index 2081c410..5465c901 100644 --- a/netlib/version_check.py +++ b/netlib/version_check.py @@ -29,7 +29,7 @@ def version_check( file=fp ) sys.exit(1) - v = tuple([int(x) for x in OpenSSL.__version__.split(".")][:2]) + v = tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) if v < pyopenssl_min_version: print( "You are using an outdated version of pyOpenSSL:" -- cgit v1.2.3 From 827fe824d97d96779512c8a4032d9b30d516d63f Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 27 Jul 2015 09:36:50 +0200 Subject: move code from mitmproxy to netlib --- netlib/http/http1/protocol.py | 52 ++++++++++++++++++----- netlib/http/http2/protocol.py | 92 ++++++++++++++++++++++++++++++++-------- netlib/http/semantics.py | 49 ++++++++++++++++++++- test/http/http1/test_protocol.py | 4 +- test/http/http2/test_protocol.py | 4 +- 5 files changed, 167 insertions(+), 34 deletions(-) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index e46ad7ab..af9882e8 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -4,6 +4,7 @@ import collections import string import sys import urlparse +import time from netlib import odict, utils, tcp, http from .. import status_codes @@ -17,10 +18,7 @@ class TCPHandler(object): class HTTP1Protocol(object): def __init__(self, tcp_handler=None, rfile=None, wfile=None): - if tcp_handler: - self.tcp_handler = tcp_handler - else: - self.tcp_handler = TCPHandler(rfile, wfile) + self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): @@ -39,6 +37,10 @@ class HTTP1Protocol(object): Raises: HttpError: If the input is invalid. """ + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + httpversion, host, port, scheme, method, path, headers, body = ( None, None, None, None, None, None, None, None) @@ -106,6 +108,12 @@ class HTTP1Protocol(object): True ) + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + timestamp_end = time.time() + return http.Request( form_in, method, @@ -115,7 +123,9 @@ class HTTP1Protocol(object): path, httpversion, headers, - body + body, + timestamp_start, + timestamp_end, ) @@ -124,12 +134,15 @@ class HTTP1Protocol(object): Returns an http.Response By default, both response header and body are read. - If include_body=False is specified, content may be one of the + If include_body=False is specified, body may be one of the following: - None, if the response is technically allowed to have a response body - "", if the response must not have a response body (e.g. it's a response to a HEAD request) """ + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() line = self.tcp_handler.rfile.readline() # Possible leftover from previous message @@ -149,7 +162,7 @@ class HTTP1Protocol(object): raise HttpError(502, "Invalid headers.") if include_body: - content = self.read_http_body( + body = self.read_http_body( headers, body_size_limit, request_method, @@ -157,10 +170,29 @@ class HTTP1Protocol(object): False ) else: - # if include_body==False then a None content means the body should be + # if include_body==False then a None body means the body should be # read separately - content = None - return http.Response(httpversion, code, msg, headers, content) + body = None + + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + if include_body: + timestamp_end = time.time() + else: + timestamp_end = None + + return http.Response( + httpversion, + code, + msg, + headers, + body, + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + ) def read_headers(self): diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 55b5ca76..41321fdc 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -1,11 +1,18 @@ from __future__ import (absolute_import, print_function, division) import itertools +import time from hpack.hpack import Encoder, Decoder from netlib import http, utils, odict from . import frame +class TCPHandler(object): + def __init__(self, rfile, wfile=None): + self.rfile = rfile + self.wfile = wfile + + class HTTP2Protocol(object): ERROR_CODES = utils.BiDi( @@ -31,16 +38,26 @@ class HTTP2Protocol(object): ALPN_PROTO_H2 = 'h2' - def __init__(self, tcp_handler, is_server=False, dump_frames=False): - self.tcp_handler = tcp_handler + + def __init__( + self, + tcp_handler=None, + rfile=None, + wfile=None, + is_server=False, + dump_frames=False, + encoder=None, + decoder=None, + ): + self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) self.is_server = is_server + self.dump_frames = dump_frames + self.encoder = encoder or Encoder() + self.decoder = decoder or Decoder() self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() self.current_stream_id = None - self.encoder = Encoder() - self.decoder = Decoder() self.connection_preface_performed = False - self.dump_frames = dump_frames def check_alpn(self): alp = self.tcp_handler.get_alpn_proto_negotiated() @@ -186,29 +203,68 @@ class HTTP2Protocol(object): self._create_headers(headers, stream_id, end_stream=(body is None)), self._create_body(body, stream_id))) - def read_response(self, *args): - stream_id, headers, body = self._receive_transmission() + def read_response(self, request_method_='', body_size_limit_=None, include_body=True): + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + + stream_id, headers, body = self._receive_transmission(include_body) + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - status = headers[':status'][0] - response = http.Response("HTTP/2", status, "", headers, body) + if include_body: + timestamp_end = time.time() + else: + timestamp_end = None + + response = http.Response( + (2, 0), + headers[':status'][0], + "", + headers, + body, + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + ) response.stream_id = stream_id + return response - def read_request(self): - stream_id, headers, body = self._receive_transmission() + def read_request(self, include_body=True, body_size_limit_=None, allow_empty_=False): + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + + stream_id, headers, body = self._receive_transmission(include_body) + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + timestamp_end = time.time() - form_in = "" - method = headers.get(':method', [''])[0] - scheme = headers.get(':scheme', [''])[0] - host = headers.get(':host', [''])[0] port = '' # TODO: parse port number? - path = headers.get(':path', [''])[0] - request = http.Request(form_in, method, scheme, host, port, path, "HTTP/2", headers, body) + request = http.Request( + "", + headers.get_first(':method', ['']), + headers.get_first(':scheme', ['']), + headers.get_first(':host', ['']), + port, + headers.get_first(':path', ['']), + (2, 0), + headers, + body, + timestamp_start, + timestamp_end, + ) request.stream_id = stream_id + return request - def _receive_transmission(self): + def _receive_transmission(self, include_body=True): body_expected = True stream_id = 0 diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 9e13edaa..63b6beb9 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -20,7 +20,11 @@ class Request(object): httpversion, headers, body, + timestamp_start=None, + timestamp_end=None, ): + assert isinstance(headers, odict.ODictCaseless) or not headers + self.form_in = form_in self.method = method self.scheme = scheme @@ -30,17 +34,30 @@ class Request(object): self.httpversion = httpversion self.headers = headers self.body = body + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end def __eq__(self, other): - return self.__dict__ == other.__dict__ + try: + self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + return self_d == other_d + except: + return False def __repr__(self): return "Request(%s - %s, %s)" % (self.method, self.host, self.path) @property def content(self): + # TODO: remove deprecated getter return self.body + @content.setter + def content(self, content): + # TODO: remove deprecated setter + self.body = content + class EmptyRequest(Request): def __init__(self): @@ -67,24 +84,52 @@ class Response(object): headers, body, sslinfo=None, + timestamp_start=None, + timestamp_end=None, ): + assert isinstance(headers, odict.ODictCaseless) or not headers + self.httpversion = httpversion self.status_code = status_code self.msg = msg self.headers = headers self.body = body self.sslinfo = sslinfo + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end def __eq__(self, other): - return self.__dict__ == other.__dict__ + try: + self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + return self_d == other_d + except: + return False def __repr__(self): return "Response(%s - %s)" % (self.status_code, self.msg) @property def content(self): + # TODO: remove deprecated getter return self.body + @content.setter + def content(self, content): + # TODO: remove deprecated setter + self.body = content + + @property + def code(self): + # TODO: remove deprecated getter + return self.status_code + + @code.setter + def code(self, code): + # TODO: remove deprecated setter + self.status_code = code + + def is_valid_port(port): if not 0 <= port <= 65535: diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index dcebbd5e..b196b7a3 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -297,10 +297,10 @@ class TestReadResponseNoContentLength(tservers.ServerTestBase): def test_read_response(): - def tst(data, method, limit, include_body=True): + def tst(data, method, body_size_limit, include_body=True): data = textwrap.dedent(data) return mock_protocol(data).read_response( - method, limit, include_body=include_body + method, body_size_limit, include_body=include_body ) tutils.raises("server disconnect", tst, "", "GET", None) diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index d3040266..0216128f 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -253,7 +253,7 @@ class TestReadResponse(tservers.ServerTestBase): resp = protocol.read_response() - assert resp.httpversion == "HTTP/2" + assert resp.httpversion == (2, 0) assert resp.status_code == "200" assert resp.msg == "" assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] @@ -279,7 +279,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): resp = protocol.read_response() assert resp.stream_id - assert resp.httpversion == "HTTP/2" + assert resp.httpversion == (2, 0) assert resp.status_code == "200" assert resp.msg == "" assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] -- cgit v1.2.3 From c7fcc2cca5ff85641febbb908d11d22336bbd81c Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 29 Jul 2015 11:27:43 +0200 Subject: add on-the-wire representation methods --- netlib/http/http1/protocol.py | 101 ++++++++++++++- netlib/http/http2/protocol.py | 261 ++++++++++++++++++++------------------- netlib/http/semantics.py | 46 +++++-- netlib/utils.py | 10 ++ test/http/http2/test_protocol.py | 63 +++++++--- 5 files changed, 324 insertions(+), 157 deletions(-) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index af9882e8..b098110a 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -7,6 +7,7 @@ import urlparse import time from netlib import odict, utils, tcp, http +from netlib.http import semantics from .. import status_codes from ..exceptions import * @@ -15,7 +16,7 @@ class TCPHandler(object): self.rfile = rfile self.wfile = wfile -class HTTP1Protocol(object): +class HTTP1Protocol(semantics.ProtocolMixin): def __init__(self, tcp_handler=None, rfile=None, wfile=None): self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) @@ -195,6 +196,32 @@ class HTTP1Protocol(object): ) + def assemble_request(self, request): + assert isinstance(request, semantics.Request) + + if request.body == semantics.CONTENT_MISSING: + raise http.HttpError( + 502, + "Cannot assemble flow with CONTENT_MISSING" + ) + first_line = self._assemble_request_first_line(request) + headers = self._assemble_request_headers(request) + return "%s\r\n%s\r\n%s" % (first_line, headers, request.body) + + + def assemble_response(self, response): + assert isinstance(response, semantics.Response) + + if response.body == semantics.CONTENT_MISSING: + raise http.HttpError( + 502, + "Cannot assemble flow with CONTENT_MISSING" + ) + first_line = self._assemble_response_first_line(response) + headers = self._assemble_response_headers(response) + return "%s\r\n%s\r\n%s" % (first_line, headers, response.body) + + def read_headers(self): """ Read a set of headers. @@ -363,7 +390,6 @@ class HTTP1Protocol(object): return line - def _read_chunked(self, limit, is_request): """ Read a chunked HTTP body. @@ -526,3 +552,74 @@ class HTTP1Protocol(object): except ValueError: return None return (proto, code, msg) + + + @classmethod + def _assemble_request_first_line(self, request): + if request.form_in == "relative": + request_line = '%s %s HTTP/%s.%s' % ( + request.method, + request.path, + request.httpversion[0], + request.httpversion[1], + ) + elif request.form_in == "authority": + request_line = '%s %s:%s HTTP/%s.%s' % ( + request.method, + request.host, + request.port, + request.httpversion[0], + request.httpversion[1], + ) + elif request.form_in == "absolute": + request_line = '%s %s://%s:%s%s HTTP/%s.%s' % ( + request.method, + request.scheme, + request.host, + request.port, + request.path, + request.httpversion[0], + request.httpversion[1], + ) + else: + raise http.HttpError(400, "Invalid request form") + return request_line + + def _assemble_request_headers(self, request): + headers = request.headers.copy() + for k in request._headers_to_strip_off: + del headers[k] + if 'host' not in headers and request.scheme and request.host and request.port: + headers["Host"] = [utils.hostport(request.scheme, + request.host, + request.port)] + + # If content is defined (i.e. not None or CONTENT_MISSING), we always + # add a content-length header. + if request.body or request.body == "": + headers["Content-Length"] = [str(len(request.body))] + + return headers.format() + + + def _assemble_response_first_line(self, response): + return 'HTTP/%s.%s %s %s' % ( + response.httpversion[0], + response.httpversion[1], + response.status_code, + response.msg, + ) + + def _assemble_response_headers(self, response, preserve_transfer_encoding=False): + headers = response.headers.copy() + for k in response._headers_to_strip_off: + del headers[k] + if not preserve_transfer_encoding: + del headers['Transfer-Encoding'] + + # If body is defined (i.e. not None or CONTENT_MISSING), we always + # add a content-length header. + if response.body or response.body == "": + headers["Content-Length"] = [str(len(response.body))] + + return headers.format() diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 41321fdc..618476e2 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -4,6 +4,7 @@ import time from hpack.hpack import Encoder, Decoder from netlib import http, utils, odict +from netlib.http import semantics from . import frame @@ -13,7 +14,7 @@ class TCPHandler(object): self.wfile = wfile -class HTTP2Protocol(object): +class HTTP2Protocol(semantics.ProtocolMixin): ERROR_CODES = utils.BiDi( NO_ERROR=0x0, @@ -59,26 +60,104 @@ class HTTP2Protocol(object): self.current_stream_id = None self.connection_preface_performed = False - def check_alpn(self): - alp = self.tcp_handler.get_alpn_proto_negotiated() - if alp != self.ALPN_PROTO_H2: - raise NotImplementedError( - "HTTP2Protocol can not handle unknown ALP: %s" % alp) - return True + def read_request(self, include_body=True, body_size_limit_=None, allow_empty_=False): + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() - def _receive_settings(self, hide=False): - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - break + stream_id, headers, body = self._receive_transmission(include_body) - def _read_settings_ack(self, hide=False): # pragma no cover - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - assert frm.flags & frame.Frame.FLAG_ACK - assert len(frm.settings) == 0 - break + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + timestamp_end = time.time() + + port = '' # TODO: parse port number? + + request = http.Request( + "", + headers.get_first(':method', ['']), + headers.get_first(':scheme', ['']), + headers.get_first(':host', ['']), + port, + headers.get_first(':path', ['']), + (2, 0), + headers, + body, + timestamp_start, + timestamp_end, + ) + request.stream_id = stream_id + + return request + + def read_response(self, request_method_='', body_size_limit_=None, include_body=True): + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + + stream_id, headers, body = self._receive_transmission(include_body) + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + if include_body: + timestamp_end = time.time() + else: + timestamp_end = None + + response = http.Response( + (2, 0), + headers[':status'][0], + "", + headers, + body, + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + ) + response.stream_id = stream_id + + return response + + def assemble_request(self, request): + assert isinstance(request, semantics.Request) + + authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host + if self.tcp_handler.address.port != 443: + authority += ":%d" % self.tcp_handler.address.port + + headers = [ + (b':method', bytes(request.method)), + (b':path', bytes(request.path)), + (b':scheme', b'https'), + (b':authority', authority), + ] + request.headers.items() + + if hasattr(request, 'stream_id'): + stream_id = request.stream_id + else: + stream_id = self._next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(request.body is None)), + self._create_body(request.body, stream_id))) + + def assemble_response(self, response): + assert isinstance(response, semantics.Response) + + headers = [(b':status', bytes(str(response.status_code)))] + response.headers.items() + + if hasattr(response, 'stream_id'): + stream_id = response.stream_id + else: + stream_id = self._next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(response.body is None)), + self._create_body(response.body, stream_id), + )) def perform_server_connection_preface(self, force=False): if force or not self.connection_preface_performed: @@ -100,18 +179,6 @@ class HTTP2Protocol(object): self.send_frame(frame.SettingsFrame(state=self), hide=True) self._receive_settings(hide=True) - def next_stream_id(self): - if self.current_stream_id is None: - if self.is_server: - # servers must use even stream ids - self.current_stream_id = 2 - else: - # clients must use odd stream ids - self.current_stream_id = 1 - else: - self.current_stream_id += 2 - return self.current_stream_id - def send_frame(self, frm, hide=False): raw_bytes = frm.to_bytes() self.tcp_handler.wfile.write(raw_bytes) @@ -128,6 +195,39 @@ class HTTP2Protocol(object): return frm + def check_alpn(self): + alp = self.tcp_handler.get_alpn_proto_negotiated() + if alp != self.ALPN_PROTO_H2: + raise NotImplementedError( + "HTTP2Protocol can not handle unknown ALP: %s" % alp) + return True + + def _receive_settings(self, hide=False): + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + break + + def _read_settings_ack(self, hide=False): # pragma no cover + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + assert frm.flags & frame.Frame.FLAG_ACK + assert len(frm.settings) == 0 + break + + def _next_stream_id(self): + if self.current_stream_id is None: + if self.is_server: + # servers must use even stream ids + self.current_stream_id = 2 + else: + # clients must use odd stream ids + self.current_stream_id = 1 + else: + self.current_stream_id += 2 + return self.current_stream_id + def _apply_settings(self, settings, hide=False): for setting, value in settings.items(): old_value = self.http2_settings[setting] @@ -181,89 +281,6 @@ class HTTP2Protocol(object): return [frm.to_bytes()] - - def create_request(self, method, path, headers=None, body=None): - if headers is None: - headers = [] - - authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host - if self.tcp_handler.address.port != 443: - authority += ":%d" % self.tcp_handler.address.port - - headers = [ - (b':method', bytes(method)), - (b':path', bytes(path)), - (b':scheme', b'https'), - (b':authority', authority), - ] + headers - - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id))) - - def read_response(self, request_method_='', body_size_limit_=None, include_body=True): - timestamp_start = time.time() - if hasattr(self.tcp_handler.rfile, "reset_timestamps"): - self.tcp_handler.rfile.reset_timestamps() - - stream_id, headers, body = self._receive_transmission(include_body) - - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - - if include_body: - timestamp_end = time.time() - else: - timestamp_end = None - - response = http.Response( - (2, 0), - headers[':status'][0], - "", - headers, - body, - timestamp_start=timestamp_start, - timestamp_end=timestamp_end, - ) - response.stream_id = stream_id - - return response - - def read_request(self, include_body=True, body_size_limit_=None, allow_empty_=False): - timestamp_start = time.time() - if hasattr(self.tcp_handler.rfile, "reset_timestamps"): - self.tcp_handler.rfile.reset_timestamps() - - stream_id, headers, body = self._receive_transmission(include_body) - - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - - timestamp_end = time.time() - - port = '' # TODO: parse port number? - - request = http.Request( - "", - headers.get_first(':method', ['']), - headers.get_first(':scheme', ['']), - headers.get_first(':host', ['']), - port, - headers.get_first(':path', ['']), - (2, 0), - headers, - body, - timestamp_start, - timestamp_end, - ) - request.stream_id = stream_id - - return request - def _receive_transmission(self, include_body=True): body_expected = True @@ -295,19 +312,3 @@ class HTTP2Protocol(object): headers.add(header, value) return stream_id, headers, body - - def create_response(self, code, stream_id=None, headers=None, body=None): - if headers is None: - headers = [] - if isinstance(headers, odict.ODict): - headers = headers.items() - - headers = [(b':status', bytes(str(code)))] + headers - - if not stream_id: - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id), - )) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 63b6beb9..54bf83d2 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -7,6 +7,32 @@ import urlparse from .. import utils, odict +CONTENT_MISSING = 0 + + +class ProtocolMixin(object): + + def read_request(self): + raise NotImplemented + + def read_response(self): + raise NotImplemented + + def assemble(self, message): + if isinstance(message, Request): + return self.assemble_request(message) + elif isinstance(message, Response): + return self.assemble_response(message) + else: + raise ValueError("HTTP message not supported.") + + def assemble_request(self, request): + raise NotImplemented + + def assemble_response(self, response): + raise NotImplemented + + class Request(object): def __init__( @@ -18,12 +44,14 @@ class Request(object): port, path, httpversion, - headers, - body, + headers=None, + body=None, timestamp_start=None, timestamp_end=None, ): - assert isinstance(headers, odict.ODictCaseless) or not headers + if not headers: + headers = odict.ODictCaseless() + assert isinstance(headers, odict.ODictCaseless) self.form_in = form_in self.method = method @@ -37,6 +65,7 @@ class Request(object): self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end + def __eq__(self, other): try: self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] @@ -80,14 +109,16 @@ class Response(object): self, httpversion, status_code, - msg, - headers, - body, + msg=None, + headers=None, + body=None, sslinfo=None, timestamp_start=None, timestamp_end=None, ): - assert isinstance(headers, odict.ODictCaseless) or not headers + if not headers: + headers = odict.ODictCaseless() + assert isinstance(headers, odict.ODictCaseless) self.httpversion = httpversion self.status_code = status_code @@ -98,6 +129,7 @@ class Response(object): self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end + def __eq__(self, other): try: self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] diff --git a/netlib/utils.py b/netlib/utils.py index bee412f9..86e33f33 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -129,3 +129,13 @@ class Data(object): if not os.path.exists(fullpath): raise ValueError("dataPath: %s does not exist." % fullpath) return fullpath + + +def hostport(scheme, host, port): + """ + Returns the host component, with a port specifcation if needed. + """ + if (port, scheme) in [(80, "http"), (443, "https")]: + return host + else: + return "%s:%s" % (host, port) diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 0216128f..5febc480 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -1,6 +1,6 @@ import OpenSSL -from netlib import tcp, odict +from netlib import tcp, odict, http from netlib.http import http2 from netlib.http.http2.frame import * from ... import tutils, tservers @@ -117,11 +117,11 @@ class TestClientStreamIds(): def test_client_stream_ids(self): assert self.protocol.current_stream_id is None - assert self.protocol.next_stream_id() == 1 + assert self.protocol._next_stream_id() == 1 assert self.protocol.current_stream_id == 1 - assert self.protocol.next_stream_id() == 3 + assert self.protocol._next_stream_id() == 3 assert self.protocol.current_stream_id == 3 - assert self.protocol.next_stream_id() == 5 + assert self.protocol._next_stream_id() == 5 assert self.protocol.current_stream_id == 5 @@ -131,11 +131,11 @@ class TestServerStreamIds(): def test_server_stream_ids(self): assert self.protocol.current_stream_id is None - assert self.protocol.next_stream_id() == 2 + assert self.protocol._next_stream_id() == 2 assert self.protocol.current_stream_id == 2 - assert self.protocol.next_stream_id() == 4 + assert self.protocol._next_stream_id() == 4 assert self.protocol.current_stream_id == 4 - assert self.protocol.next_stream_id() == 6 + assert self.protocol._next_stream_id() == 6 assert self.protocol.current_stream_id == 6 @@ -215,17 +215,36 @@ class TestCreateBody(): # TODO: add test for too large frames -class TestCreateRequest(): +class TestAssembleRequest(): c = tcp.TCPClient(("127.0.0.1", 0)) - def test_create_request_simple(self): - bytes = http2.HTTP2Protocol(self.c).create_request('GET', '/') + def test_assemble_request_simple(self): + bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request( + '', + 'GET', + '', + '', + '', + '/', + (2, 0), + None, + None, + )) assert len(bytes) == 1 assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') - def test_create_request_with_body(self): - bytes = http2.HTTP2Protocol(self.c).create_request( - 'GET', '/', [(b'foo', b'bar')], 'foobar') + def test_assemble_request_with_body(self): + bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request( + '', + 'GET', + '', + '', + '', + '/', + (2, 0), + odict.ODictCaseless([('foo', 'bar')]), + 'foobar', + )) assert len(bytes) == 2 assert bytes[0] ==\ '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') @@ -315,16 +334,24 @@ class TestCreateResponse(): c = tcp.TCPClient(("127.0.0.1", 0)) def test_create_response_simple(self): - bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200) + bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( + (2, 0), + 200, + )) assert len(bytes) == 1 assert bytes[0] ==\ '00000101050000000288'.decode('hex') def test_create_response_with_body(self): - bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response( - 200, 1, [(b'foo', b'bar')], 'foobar') + bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( + (2, 0), + 200, + '', + odict.ODictCaseless([('foo', 'bar')]), + 'foobar' + )) assert len(bytes) == 2 assert bytes[0] ==\ - '00000901040000000188408294e7838c767f'.decode('hex') + '00000901040000000288408294e7838c767f'.decode('hex') assert bytes[1] ==\ - '000006000100000001666f6f626172'.decode('hex') + '000006000100000002666f6f626172'.decode('hex') -- cgit v1.2.3 From 7b10817670b30550dd45af48491ed8cf3cacd5e6 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 30 Jul 2015 13:52:13 +0200 Subject: http2: improve protocol --- netlib/http/http2/protocol.py | 61 +++++++++++++++++++++++++++------------- netlib/odict.py | 7 +++-- test/http/http2/test_protocol.py | 11 +++++--- 3 files changed, 53 insertions(+), 26 deletions(-) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 618476e2..a1ca4a18 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -60,7 +60,9 @@ class HTTP2Protocol(semantics.ProtocolMixin): self.current_stream_id = None self.connection_preface_performed = False - def read_request(self, include_body=True, body_size_limit_=None, allow_empty_=False): + def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): + self.perform_connection_preface() + timestamp_start = time.time() if hasattr(self.tcp_handler.rfile, "reset_timestamps"): self.tcp_handler.rfile.reset_timestamps() @@ -73,15 +75,13 @@ class HTTP2Protocol(semantics.ProtocolMixin): timestamp_end = time.time() - port = '' # TODO: parse port number? - request = http.Request( - "", - headers.get_first(':method', ['']), - headers.get_first(':scheme', ['']), - headers.get_first(':host', ['']), - port, - headers.get_first(':path', ['']), + "relative", # TODO: use the correct value + headers.get_first(':method', 'GET'), + headers.get_first(':scheme', 'https'), + headers.get_first(':host', 'localhost'), + 443, # TODO: parse port number from host? + headers.get_first(':path', '/'), (2, 0), headers, body, @@ -92,7 +92,9 @@ class HTTP2Protocol(semantics.ProtocolMixin): return request - def read_response(self, request_method_='', body_size_limit_=None, include_body=True): + def read_response(self, request_method='', body_size_limit=None, include_body=True): + self.perform_connection_preface() + timestamp_start = time.time() if hasattr(self.tcp_handler.rfile, "reset_timestamps"): self.tcp_handler.rfile.reset_timestamps() @@ -110,7 +112,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): response = http.Response( (2, 0), - headers[':status'][0], + int(headers.get_first(':status')), "", headers, body, @@ -121,6 +123,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): return response + def assemble_request(self, request): assert isinstance(request, semantics.Request) @@ -128,12 +131,18 @@ class HTTP2Protocol(semantics.ProtocolMixin): if self.tcp_handler.address.port != 443: authority += ":%d" % self.tcp_handler.address.port - headers = [ - (b':method', bytes(request.method)), - (b':path', bytes(request.path)), - (b':scheme', b'https'), - (b':authority', authority), - ] + request.headers.items() + headers = request.headers.copy() + + if not ':authority' in headers.keys(): + headers.add(':authority', bytes(authority), prepend=True) + if not ':scheme' in headers.keys(): + headers.add(':scheme', bytes(request.scheme), prepend=True) + if not ':path' in headers.keys(): + headers.add(':path', bytes(request.path), prepend=True) + if not ':method' in headers.keys(): + headers.add(':method', bytes(request.method), prepend=True) + + headers = headers.items() if hasattr(request, 'stream_id'): stream_id = request.stream_id @@ -141,13 +150,18 @@ class HTTP2Protocol(semantics.ProtocolMixin): stream_id = self._next_stream_id() return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(request.body is None)), + self._create_headers(headers, stream_id, end_stream=(request.body is None or len(request.body) == 0)), self._create_body(request.body, stream_id))) def assemble_response(self, response): assert isinstance(response, semantics.Response) - headers = [(b':status', bytes(str(response.status_code)))] + response.headers.items() + headers = response.headers.copy() + + if not ':status' in headers.keys(): + headers.add(':status', bytes(str(response.status_code)), prepend=True) + + headers = headers.items() if hasattr(response, 'stream_id'): stream_id = response.stream_id @@ -155,10 +169,17 @@ class HTTP2Protocol(semantics.ProtocolMixin): stream_id = self._next_stream_id() return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(response.body is None)), + self._create_headers(headers, stream_id, end_stream=(response.body is None or len(response.body) == 0)), self._create_body(response.body, stream_id), )) + def perform_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + if self.is_server: + self.perform_server_connection_preface(force) + else: + self.perform_client_connection_preface(force) + def perform_server_connection_preface(self, force=False): if force or not self.connection_preface_performed: self.connection_preface_performed = True diff --git a/netlib/odict.py b/netlib/odict.py index f52acd50..d02de08d 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -96,8 +96,11 @@ class ODict(object): return True return False - def add(self, key, value): - self.lst.append([key, value]) + def add(self, key, value, prepend=False): + if prepend: + self.lst.insert(0, [key, value]) + else: + self.lst.append([key, value]) def get(self, k, d=None): if k in self: diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 5febc480..b2d414d1 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -222,7 +222,7 @@ class TestAssembleRequest(): bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request( '', 'GET', - '', + 'https', '', '', '/', @@ -237,7 +237,7 @@ class TestAssembleRequest(): bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request( '', 'GET', - '', + 'https', '', '', '/', @@ -269,11 +269,12 @@ class TestReadResponse(tservers.ServerTestBase): c.connect() c.convert_to_ssl() protocol = http2.HTTP2Protocol(c) + protocol.connection_preface_performed = True resp = protocol.read_response() assert resp.httpversion == (2, 0) - assert resp.status_code == "200" + assert resp.status_code == 200 assert resp.msg == "" assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] assert resp.body == b'foobar' @@ -294,12 +295,13 @@ class TestReadEmptyResponse(tservers.ServerTestBase): c.connect() c.convert_to_ssl() protocol = http2.HTTP2Protocol(c) + protocol.connection_preface_performed = True resp = protocol.read_response() assert resp.stream_id assert resp.httpversion == (2, 0) - assert resp.status_code == "200" + assert resp.status_code == 200 assert resp.msg == "" assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] assert resp.body == b'' @@ -322,6 +324,7 @@ class TestReadRequest(tservers.ServerTestBase): c.connect() c.convert_to_ssl() protocol = http2.HTTP2Protocol(c, is_server=True) + protocol.connection_preface_performed = True resp = protocol.read_request() -- cgit v1.2.3 From a837230320378d629ba9f25960b1dfd25c892ad9 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sat, 1 Aug 2015 10:39:14 +0200 Subject: move code from mitmproxy to netlib --- netlib/encoding.py | 82 +++++++++ netlib/http/exceptions.py | 13 ++ netlib/http/http1/protocol.py | 39 +---- netlib/http/semantics.py | 366 +++++++++++++++++++++++++++++++-------- netlib/tutils.py | 125 +++++++++++++ netlib/utils.py | 100 +++++++++++ test/http/http1/test_protocol.py | 10 -- test/http/test_exceptions.py | 6 + test/http/test_semantics.py | 295 ++++++++++++++++++++++++++----- test/test_utils.py | 77 +++++++- test/tutils.py | 68 -------- 11 files changed, 952 insertions(+), 229 deletions(-) create mode 100644 netlib/encoding.py create mode 100644 netlib/tutils.py create mode 100644 test/http/test_exceptions.py delete mode 100644 test/tutils.py diff --git a/netlib/encoding.py b/netlib/encoding.py new file mode 100644 index 00000000..f107eb5f --- /dev/null +++ b/netlib/encoding.py @@ -0,0 +1,82 @@ +""" + Utility functions for decoding response bodies. +""" +from __future__ import absolute_import +import cStringIO +import gzip +import zlib + +__ALL__ = ["ENCODINGS"] + +ENCODINGS = set(["identity", "gzip", "deflate"]) + + +def decode(e, content): + encoding_map = { + "identity": identity, + "gzip": decode_gzip, + "deflate": decode_deflate, + } + if e not in encoding_map: + return None + return encoding_map[e](content) + + +def encode(e, content): + encoding_map = { + "identity": identity, + "gzip": encode_gzip, + "deflate": encode_deflate, + } + if e not in encoding_map: + return None + return encoding_map[e](content) + + +def identity(content): + """ + Returns content unchanged. Identity is the default value of + Accept-Encoding headers. + """ + return content + + +def decode_gzip(content): + gfile = gzip.GzipFile(fileobj=cStringIO.StringIO(content)) + try: + return gfile.read() + except (IOError, EOFError): + return None + + +def encode_gzip(content): + s = cStringIO.StringIO() + gf = gzip.GzipFile(fileobj=s, mode='wb') + gf.write(content) + gf.close() + return s.getvalue() + + +def decode_deflate(content): + """ + Returns decompressed data for DEFLATE. Some servers may respond with + compressed data without a zlib header or checksum. An undocumented + feature of zlib permits the lenient decompression of data missing both + values. + + http://bugs.python.org/issue5784 + """ + try: + try: + return zlib.decompress(content) + except zlib.error: + return zlib.decompress(content, -15) + except zlib.error: + return None + + +def encode_deflate(content): + """ + Returns compressed content, always including zlib header and checksum. + """ + return zlib.compress(content) diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py index 8a2bbebc..45bd2dce 100644 --- a/netlib/http/exceptions.py +++ b/netlib/http/exceptions.py @@ -7,3 +7,16 @@ class HttpError(Exception): class HttpErrorConnClosed(HttpError): pass + + + +class HttpAuthenticationError(Exception): + def __init__(self, auth_headers=None): + super(HttpAuthenticationError, self).__init__( + "Proxy Authentication Required" + ) + self.headers = auth_headers + self.code = 407 + + def __repr__(self): + return "Proxy Authentication Required" diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index b098110a..a189bffc 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -375,7 +375,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): @classmethod def has_chunked_encoding(self, headers): return "chunked" in [ - i.lower() for i in http.get_header_tokens(headers, "transfer-encoding") + i.lower() for i in utils.get_header_tokens(headers, "transfer-encoding") ] @@ -482,9 +482,9 @@ class HTTP1Protocol(semantics.ProtocolMixin): port = int(port) except ValueError: return None - if not http.is_valid_port(port): + if not utils.is_valid_port(port): return None - if not http.is_valid_host(host): + if not utils.is_valid_host(host): return None return host, port, httpversion @@ -496,7 +496,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None method, url, httpversion = v - parts = http.parse_url(url) + parts = utils.parse_url(url) if not parts: return None scheme, host, port, path = parts @@ -528,7 +528,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): """ # At first, check if we have an explicit Connection header. if "connection" in headers: - toks = http.get_header_tokens(headers, "connection") + toks = utils.get_header_tokens(headers, "connection") if "close" in toks: return True elif "keep-alive" in toks: @@ -556,34 +556,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): @classmethod def _assemble_request_first_line(self, request): - if request.form_in == "relative": - request_line = '%s %s HTTP/%s.%s' % ( - request.method, - request.path, - request.httpversion[0], - request.httpversion[1], - ) - elif request.form_in == "authority": - request_line = '%s %s:%s HTTP/%s.%s' % ( - request.method, - request.host, - request.port, - request.httpversion[0], - request.httpversion[1], - ) - elif request.form_in == "absolute": - request_line = '%s %s://%s:%s%s HTTP/%s.%s' % ( - request.method, - request.scheme, - request.host, - request.port, - request.path, - request.httpversion[0], - request.httpversion[1], - ) - else: - raise http.HttpError(400, "Invalid request form") - return request_line + return request.legacy_first_line() def _assemble_request_headers(self, request): headers = request.headers.copy() diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 54bf83d2..e7ae2b5f 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -3,9 +3,15 @@ import binascii import collections import string import sys +import urllib import urlparse from .. import utils, odict +from . import cookies +from netlib import utils, encoding + +HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" +HDR_FORM_MULTIPART = "multipart/form-data" CONTENT_MISSING = 0 @@ -75,7 +81,240 @@ class Request(object): return False def __repr__(self): - return "Request(%s - %s, %s)" % (self.method, self.host, self.path) + # return "Request(%s - %s, %s)" % (self.method, self.host, self.path) + + return "".format( + self.legacy_first_line()[:-9] + ) + + def legacy_first_line(self): + if self.form_in == "relative": + return '%s %s HTTP/%s.%s' % ( + self.method, + self.path, + self.httpversion[0], + self.httpversion[1], + ) + elif self.form_in == "authority": + return '%s %s:%s HTTP/%s.%s' % ( + self.method, + self.host, + self.port, + self.httpversion[0], + self.httpversion[1], + ) + elif self.form_in == "absolute": + return '%s %s://%s:%s%s HTTP/%s.%s' % ( + self.method, + self.scheme, + self.host, + self.port, + self.path, + self.httpversion[0], + self.httpversion[1], + ) + else: + raise http.HttpError(400, "Invalid request form") + + def anticache(self): + """ + Modifies this request to remove headers that might produce a cached + response. That is, we remove ETags and If-Modified-Since headers. + """ + delheaders = [ + "if-modified-since", + "if-none-match", + ] + for i in delheaders: + del self.headers[i] + + def anticomp(self): + """ + Modifies this request to remove headers that will compress the + resource's data. + """ + self.headers["accept-encoding"] = ["identity"] + + def constrain_encoding(self): + """ + Limits the permissible Accept-Encoding values, based on what we can + decode appropriately. + """ + if self.headers["accept-encoding"]: + self.headers["accept-encoding"] = [ + ', '.join( + e for e in encoding.ENCODINGS if e in self.headers["accept-encoding"][0])] + + def update_host_header(self): + """ + Update the host header to reflect the current target. + """ + self.headers["Host"] = [self.host] + + def get_form(self): + """ + Retrieves the URL-encoded or multipart form data, returning an ODict object. + Returns an empty ODict if there is no data or the content-type + indicates non-form data. + """ + if self.body: + if self.headers.in_any("content-type", HDR_FORM_URLENCODED, True): + return self.get_form_urlencoded() + elif self.headers.in_any("content-type", HDR_FORM_MULTIPART, True): + return self.get_form_multipart() + return odict.ODict([]) + + def get_form_urlencoded(self): + """ + Retrieves the URL-encoded form data, returning an ODict object. + Returns an empty ODict if there is no data or the content-type + indicates non-form data. + """ + if self.body and self.headers.in_any( + "content-type", + HDR_FORM_URLENCODED, + True): + return odict.ODict(utils.urldecode(self.body)) + return odict.ODict([]) + + def get_form_multipart(self): + if self.body and self.headers.in_any( + "content-type", + HDR_FORM_MULTIPART, + True): + return odict.ODict( + utils.multipartdecode( + self.headers, + self.body)) + return odict.ODict([]) + + def set_form_urlencoded(self, odict): + """ + Sets the body to the URL-encoded form data, and adds the + appropriate content-type header. Note that this will destory the + existing body if there is one. + """ + # FIXME: If there's an existing content-type header indicating a + # url-encoded form, leave it alone. + self.headers["Content-Type"] = [HDR_FORM_URLENCODED] + self.body = utils.urlencode(odict.lst) + + def get_path_components(self): + """ + Returns the path components of the URL as a list of strings. + + Components are unquoted. + """ + _, _, path, _, _, _ = urlparse.urlparse(self.url) + return [urllib.unquote(i) for i in path.split("/") if i] + + def set_path_components(self, lst): + """ + Takes a list of strings, and sets the path component of the URL. + + Components are quoted. + """ + lst = [urllib.quote(i, safe="") for i in lst] + path = "/" + "/".join(lst) + scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.url) + self.url = urlparse.urlunparse( + [scheme, netloc, path, params, query, fragment] + ) + + def get_query(self): + """ + Gets the request query string. Returns an ODict object. + """ + _, _, _, _, query, _ = urlparse.urlparse(self.url) + if query: + return odict.ODict(utils.urldecode(query)) + return odict.ODict([]) + + def set_query(self, odict): + """ + Takes an ODict object, and sets the request query string. + """ + scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.url) + query = utils.urlencode(odict.lst) + self.url = urlparse.urlunparse( + [scheme, netloc, path, params, query, fragment] + ) + + def pretty_host(self, hostheader): + """ + Heuristic to get the host of the request. + + Note that pretty_host() does not always return the TCP destination + of the request, e.g. if an upstream proxy is in place + + If hostheader is set to True, the Host: header will be used as + additional (and preferred) data source. This is handy in + transparent mode, where only the IO of the destination is known, + but not the resolved name. This is disabled by default, as an + attacker may spoof the host header to confuse an analyst. + """ + host = None + if hostheader: + host = self.headers.get_first("host") + if not host: + host = self.host + if host: + try: + return host.encode("idna") + except ValueError: + return host + else: + return None + + def pretty_url(self, hostheader): + if self.form_out == "authority": # upstream proxy mode + return "%s:%s" % (self.pretty_host(hostheader), self.port) + return utils.unparse_url(self.scheme, + self.pretty_host(hostheader), + self.port, + self.path).encode('ascii') + + def get_cookies(self): + """ + Returns a possibly empty netlib.odict.ODict object. + """ + ret = odict.ODict() + for i in self.headers["cookie"]: + ret.extend(cookies.parse_cookie_header(i)) + return ret + + def set_cookies(self, odict): + """ + Takes an netlib.odict.ODict object. Over-writes any existing Cookie + headers. + """ + v = cookies.format_cookie_header(odict) + self.headers["Cookie"] = [v] + + @property + def url(self): + """ + Returns a URL string, constructed from the Request's URL components. + """ + return utils.unparse_url( + self.scheme, + self.host, + self.port, + self.path + ).encode('ascii') + + @url.setter + def url(self, url): + """ + Parses a URL specification, and updates the Request's information + accordingly. + + Returns False if the URL was invalid, True if the request succeeded. + """ + parts = utils.parse_url(url) + if not parts: + raise ValueError("Invalid URL: %s" % url) + self.scheme, self.host, self.port, self.path = parts @property def content(self): @@ -139,7 +378,56 @@ class Response(object): return False def __repr__(self): - return "Response(%s - %s)" % (self.status_code, self.msg) + # return "Response(%s - %s)" % (self.status_code, self.msg) + + if self.body: + size = utils.pretty_size(len(self.body)) + else: + size = "content missing" + return "".format( + status_code=self.status_code, + msg=self.msg, + contenttype=self.headers.get_first( + "content-type", "unknown content type" + ), + size=size + ) + + + def get_cookies(self): + """ + Get the contents of all Set-Cookie headers. + + Returns a possibly empty ODict, where keys are cookie name strings, + and values are [value, attr] lists. Value is a string, and attr is + an ODictCaseless containing cookie attributes. Within attrs, unary + attributes (e.g. HTTPOnly) are indicated by a Null value. + """ + ret = [] + for header in self.headers["set-cookie"]: + v = cookies.parse_set_cookie_header(header) + if v: + name, value, attrs = v + ret.append([name, [value, attrs]]) + return odict.ODict(ret) + + def set_cookies(self, odict): + """ + Set the Set-Cookie headers on this response, over-writing existing + headers. + + Accepts an ODict of the same format as that returned by get_cookies. + """ + values = [] + for i in odict.lst: + values.append( + cookies.format_set_cookie_header( + i[0], + i[1][0], + i[1][1] + ) + ) + self.headers["Set-Cookie"] = values @property def content(self): @@ -160,77 +448,3 @@ class Response(object): def code(self, code): # TODO: remove deprecated setter self.status_code = code - - - -def is_valid_port(port): - if not 0 <= port <= 65535: - return False - return True - - -def is_valid_host(host): - try: - host.decode("idna") - except ValueError: - return False - if "\0" in host: - return None - return True - - -def parse_url(url): - """ - Returns a (scheme, host, port, path) tuple, or None on error. - - Checks that: - port is an integer 0-65535 - host is a valid IDNA-encoded hostname with no null-bytes - path is valid ASCII - """ - try: - scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) - except ValueError: - return None - if not scheme: - return None - if '@' in netloc: - # FIXME: Consider what to do with the discarded credentials here Most - # probably we should extend the signature to return these as a separate - # value. - _, netloc = string.rsplit(netloc, '@', maxsplit=1) - if ':' in netloc: - host, port = string.rsplit(netloc, ':', maxsplit=1) - try: - port = int(port) - except ValueError: - return None - else: - host = netloc - if scheme == "https": - port = 443 - else: - port = 80 - path = urlparse.urlunparse(('', '', path, params, query, fragment)) - if not path.startswith("/"): - path = "/" + path - if not is_valid_host(host): - return None - if not utils.isascii(path): - return None - if not is_valid_port(port): - return None - return scheme, host, port, path - - -def get_header_tokens(headers, key): - """ - Retrieve all tokens for a header key. A number of different headers - follow a pattern where each header line can containe comma-separated - tokens, and headers can be set multiple times. - """ - toks = [] - for i in headers[key]: - for j in i.split(","): - toks.append(j.strip()) - return toks diff --git a/netlib/tutils.py b/netlib/tutils.py new file mode 100644 index 00000000..5018b9e8 --- /dev/null +++ b/netlib/tutils.py @@ -0,0 +1,125 @@ +import cStringIO +import tempfile +import os +import time +import shutil +from contextlib import contextmanager + +from netlib import tcp, utils, odict, http + + +def treader(bytes): + """ + Construct a tcp.Read object from bytes. + """ + fp = cStringIO.StringIO(bytes) + return tcp.Reader(fp) + + +@contextmanager +def tmpdir(*args, **kwargs): + orig_workdir = os.getcwd() + temp_workdir = tempfile.mkdtemp(*args, **kwargs) + os.chdir(temp_workdir) + + yield temp_workdir + + os.chdir(orig_workdir) + shutil.rmtree(temp_workdir) + + +def raises(exc, obj, *args, **kwargs): + """ + Assert that a callable raises a specified exception. + + :exc An exception class or a string. If a class, assert that an + exception of this type is raised. If a string, assert that the string + occurs in the string representation of the exception, based on a + case-insenstivie match. + + :obj A callable object. + + :args Arguments to be passsed to the callable. + + :kwargs Arguments to be passed to the callable. + """ + try: + ret = obj(*args, **kwargs) + except Exception as v: + if isinstance(exc, basestring): + if exc.lower() in str(v).lower(): + return + else: + raise AssertionError( + "Expected %s, but caught %s" % ( + repr(str(exc)), v + ) + ) + else: + if isinstance(v, exc): + return + else: + raise AssertionError( + "Expected %s, but caught %s %s" % ( + exc.__name__, v.__class__.__name__, str(v) + ) + ) + raise AssertionError("No exception raised. Return value: {}".format(ret)) + +test_data = utils.Data(__name__) + + + + +def treq(content="content", scheme="http", host="address", port=22): + """ + @return: libmproxy.protocol.http.HTTPRequest + """ + headers = odict.ODictCaseless() + headers["header"] = ["qvalue"] + req = http.Request( + "relative", + "GET", + scheme, + host, + port, + "/path", + (1, 1), + headers, + content, + None, + None, + ) + return req + + +def treq_absolute(content="content"): + """ + @return: libmproxy.protocol.http.HTTPRequest + """ + r = treq(content) + r.form_in = r.form_out = "absolute" + r.host = "address" + r.port = 22 + r.scheme = "http" + return r + + +def tresp(content="message"): + """ + @return: libmproxy.protocol.http.HTTPResponse + """ + + headers = odict.ODictCaseless() + headers["header_response"] = ["svalue"] + + resp = http.semantics.Response( + (1, 1), + 200, + "OK", + headers, + content, + time.time(), + time.time(), + ) + return resp diff --git a/netlib/utils.py b/netlib/utils.py index 86e33f33..39354605 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,5 +1,10 @@ from __future__ import (absolute_import, print_function, division) import os.path +import cgi +import urllib +import urlparse +import string + def isascii(s): try: @@ -131,6 +136,81 @@ class Data(object): return fullpath + + +def is_valid_port(port): + if not 0 <= port <= 65535: + return False + return True + + +def is_valid_host(host): + try: + host.decode("idna") + except ValueError: + return False + if "\0" in host: + return None + return True + + +def parse_url(url): + """ + Returns a (scheme, host, port, path) tuple, or None on error. + + Checks that: + port is an integer 0-65535 + host is a valid IDNA-encoded hostname with no null-bytes + path is valid ASCII + """ + try: + scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) + except ValueError: + return None + if not scheme: + return None + if '@' in netloc: + # FIXME: Consider what to do with the discarded credentials here Most + # probably we should extend the signature to return these as a separate + # value. + _, netloc = string.rsplit(netloc, '@', maxsplit=1) + if ':' in netloc: + host, port = string.rsplit(netloc, ':', maxsplit=1) + try: + port = int(port) + except ValueError: + return None + else: + host = netloc + if scheme == "https": + port = 443 + else: + port = 80 + path = urlparse.urlunparse(('', '', path, params, query, fragment)) + if not path.startswith("/"): + path = "/" + path + if not is_valid_host(host): + return None + if not isascii(path): + return None + if not is_valid_port(port): + return None + return scheme, host, port, path + + +def get_header_tokens(headers, key): + """ + Retrieve all tokens for a header key. A number of different headers + follow a pattern where each header line can containe comma-separated + tokens, and headers can be set multiple times. + """ + toks = [] + for i in headers[key]: + for j in i.split(","): + toks.append(j.strip()) + return toks + + def hostport(scheme, host, port): """ Returns the host component, with a port specifcation if needed. @@ -139,3 +219,23 @@ def hostport(scheme, host, port): return host else: return "%s:%s" % (host, port) + +def unparse_url(scheme, host, port, path=""): + """ + Returns a URL string, constructed from the specified compnents. + """ + return "%s://%s%s" % (scheme, hostport(scheme, host, port), path) + + +def urlencode(s): + """ + Takes a list of (key, value) tuples and returns a urlencoded string. + """ + s = [tuple(i) for i in s] + return urllib.urlencode(s, False) + +def urldecode(s): + """ + Takes a urlencoded string and returns a list of (key, value) tuples. + """ + return cgi.parse_qsl(s, keep_blank_values=True) diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index b196b7a3..05bad1af 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -75,16 +75,6 @@ def test_connection_close(): assert HTTP1Protocol.connection_close((1, 1), h) -def test_get_header_tokens(): - h = odict.ODictCaseless() - assert http.get_header_tokens(h, "foo") == [] - h["foo"] = ["bar"] - assert http.get_header_tokens(h, "foo") == ["bar"] - h["foo"] = ["bar, voing"] - assert http.get_header_tokens(h, "foo") == ["bar", "voing"] - h["foo"] = ["bar, voing", "oink"] - assert http.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] - def test_read_http_body_request(): h = odict.ODictCaseless() diff --git a/test/http/test_exceptions.py b/test/http/test_exceptions.py new file mode 100644 index 00000000..aa57f831 --- /dev/null +++ b/test/http/test_exceptions.py @@ -0,0 +1,6 @@ +from netlib.http.exceptions import * + +def test_HttpAuthenticationError(): + x = HttpAuthenticationError({"foo": "bar"}) + assert str(x) + assert "foo" in x.headers diff --git a/test/http/test_semantics.py b/test/http/test_semantics.py index c4605302..986afc39 100644 --- a/test/http/test_semantics.py +++ b/test/http/test_semantics.py @@ -1,54 +1,267 @@ import cStringIO import textwrap import binascii +from mock import MagicMock from netlib import http, odict, tcp from netlib.http import http1 +from netlib.http.semantics import CONTENT_MISSING from .. import tutils, tservers def test_httperror(): e = http.exceptions.HttpError(404, "Not found") assert str(e) +class TestRequest: + # def test_asterisk_form_in(self): + # f = tutils.tflow(req=None) + # protocol = mock_protocol("OPTIONS * HTTP/1.1") + # f.request = HTTPRequest.from_protocol(protocol) + # + # assert f.request.form_in == "relative" + # f.request.host = f.server_conn.address.host + # f.request.port = f.server_conn.address.port + # f.request.scheme = "http" + # assert protocol.assemble(f.request) == ( + # "OPTIONS * HTTP/1.1\r\n" + # "Host: address:22\r\n" + # "Content-Length: 0\r\n\r\n") + # + # def test_relative_form_in(self): + # protocol = mock_protocol("GET /foo\xff HTTP/1.1") + # tutils.raises("Bad HTTP request line", HTTPRequest.from_protocol, protocol) + # + # protocol = mock_protocol("GET /foo HTTP/1.1\r\nConnection: Upgrade\r\nUpgrade: h2c") + # r = HTTPRequest.from_protocol(protocol) + # assert r.headers["Upgrade"] == ["h2c"] + # + # def test_expect_header(self): + # protocol = mock_protocol( + # "GET / HTTP/1.1\r\nContent-Length: 3\r\nExpect: 100-continue\r\n\r\nfoobar") + # r = HTTPRequest.from_protocol(protocol) + # assert protocol.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" + # assert r.content == "foo" + # assert protocol.tcp_handler.rfile.read(3) == "bar" + # + # def test_authority_form_in(self): + # protocol = mock_protocol("CONNECT oops-no-port.com HTTP/1.1") + # tutils.raises("Bad HTTP request line", HTTPRequest.from_protocol, protocol) + # + # protocol = mock_protocol("CONNECT address:22 HTTP/1.1") + # r = HTTPRequest.from_protocol(protocol) + # r.scheme, r.host, r.port = "http", "address", 22 + # assert protocol.assemble(r) == ( + # "CONNECT address:22 HTTP/1.1\r\n" + # "Host: address:22\r\n" + # "Content-Length: 0\r\n\r\n") + # assert r.pretty_url(False) == "address:22" + # + # def test_absolute_form_in(self): + # protocol = mock_protocol("GET oops-no-protocol.com HTTP/1.1") + # tutils.raises("Bad HTTP request line", HTTPRequest.from_protocol, protocol) + # + # protocol = mock_protocol("GET http://address:22/ HTTP/1.1") + # r = HTTPRequest.from_protocol(protocol) + # assert protocol.assemble(r) == ( + # "GET http://address:22/ HTTP/1.1\r\n" + # "Host: address:22\r\n" + # "Content-Length: 0\r\n\r\n") + # + # def test_http_options_relative_form_in(self): + # """ + # Exercises fix for Issue #392. + # """ + # protocol = mock_protocol("OPTIONS /secret/resource HTTP/1.1") + # r = HTTPRequest.from_protocol(protocol) + # r.host = 'address' + # r.port = 80 + # r.scheme = "http" + # assert protocol.assemble(r) == ( + # "OPTIONS /secret/resource HTTP/1.1\r\n" + # "Host: address\r\n" + # "Content-Length: 0\r\n\r\n") + # + # def test_http_options_absolute_form_in(self): + # protocol = mock_protocol("OPTIONS http://address/secret/resource HTTP/1.1") + # r = HTTPRequest.from_protocol(protocol) + # r.host = 'address' + # r.port = 80 + # r.scheme = "http" + # assert protocol.assemble(r) == ( + # "OPTIONS http://address:80/secret/resource HTTP/1.1\r\n" + # "Host: address\r\n" + # "Content-Length: 0\r\n\r\n") -def test_parse_url(): - assert not http.parse_url("") - - u = "http://foo.com:8888/test" - s, h, po, pa = http.parse_url(u) - assert s == "http" - assert h == "foo.com" - assert po == 8888 - assert pa == "/test" - - s, h, po, pa = http.parse_url("http://foo/bar") - assert s == "http" - assert h == "foo" - assert po == 80 - assert pa == "/bar" - - s, h, po, pa = http.parse_url("http://user:pass@foo/bar") - assert s == "http" - assert h == "foo" - assert po == 80 - assert pa == "/bar" - - s, h, po, pa = http.parse_url("http://foo") - assert pa == "/" - - s, h, po, pa = http.parse_url("https://foo") - assert po == 443 - - assert not http.parse_url("https://foo:bar") - assert not http.parse_url("https://foo:") - - # Invalid IDNA - assert not http.parse_url("http://\xfafoo") - # Invalid PATH - assert not http.parse_url("http:/\xc6/localhost:56121") - # Null byte in host - assert not http.parse_url("http://foo\0") - # Port out of range - assert not http.parse_url("http://foo:999999") - # Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt - assert not http.parse_url('http://lo[calhost') + def test_set_url(self): + r = tutils.treq_absolute() + r.url = "https://otheraddress:42/ORLY" + assert r.scheme == "https" + assert r.host == "otheraddress" + assert r.port == 42 + assert r.path == "/ORLY" + + def test_repr(self): + r = tutils.treq() + assert repr(r) + + def test_pretty_host(self): + r = tutils.treq() + assert r.pretty_host(True) == "address" + assert r.pretty_host(False) == "address" + r.headers["host"] = ["other"] + assert r.pretty_host(True) == "other" + assert r.pretty_host(False) == "address" + r.host = None + assert r.pretty_host(True) == "other" + assert r.pretty_host(False) is None + del r.headers["host"] + assert r.pretty_host(True) is None + assert r.pretty_host(False) is None + + # Invalid IDNA + r.headers["host"] = [".disqus.com"] + assert r.pretty_host(True) == ".disqus.com" + + def test_get_form_for_urlencoded(self): + r = tutils.treq() + r.headers.add("content-type", "application/x-www-form-urlencoded") + r.get_form_urlencoded = MagicMock() + + r.get_form() + + assert r.get_form_urlencoded.called + + def test_get_form_for_multipart(self): + r = tutils.treq() + r.headers.add("content-type", "multipart/form-data") + r.get_form_multipart = MagicMock() + + r.get_form() + + assert r.get_form_multipart.called + + def test_get_cookies_none(self): + h = odict.ODictCaseless() + r = tutils.treq() + r.headers = h + assert len(r.get_cookies()) == 0 + + def test_get_cookies_single(self): + h = odict.ODictCaseless() + h["Cookie"] = ["cookiename=cookievalue"] + r = tutils.treq() + r.headers = h + result = r.get_cookies() + assert len(result) == 1 + assert result['cookiename'] == ['cookievalue'] + + def test_get_cookies_double(self): + h = odict.ODictCaseless() + h["Cookie"] = [ + "cookiename=cookievalue;othercookiename=othercookievalue" + ] + r = tutils.treq() + r.headers = h + result = r.get_cookies() + assert len(result) == 2 + assert result['cookiename'] == ['cookievalue'] + assert result['othercookiename'] == ['othercookievalue'] + + def test_get_cookies_withequalsign(self): + h = odict.ODictCaseless() + h["Cookie"] = [ + "cookiename=coo=kievalue;othercookiename=othercookievalue" + ] + r = tutils.treq() + r.headers = h + result = r.get_cookies() + assert len(result) == 2 + assert result['cookiename'] == ['coo=kievalue'] + assert result['othercookiename'] == ['othercookievalue'] + + def test_set_cookies(self): + h = odict.ODictCaseless() + h["Cookie"] = ["cookiename=cookievalue"] + r = tutils.treq() + r.headers = h + result = r.get_cookies() + result["cookiename"] = ["foo"] + r.set_cookies(result) + assert r.get_cookies()["cookiename"] == ["foo"] + + +class TestResponse(object): + def test_repr(self): + r = tutils.tresp() + assert "unknown content type" in repr(r) + r.headers["content-type"] = ["foo"] + assert "foo" in repr(r) + assert repr(tutils.tresp(content=CONTENT_MISSING)) + + def test_get_cookies_none(self): + h = odict.ODictCaseless() + resp = tutils.tresp() + resp.headers = h + assert not resp.get_cookies() + + def test_get_cookies_simple(self): + h = odict.ODictCaseless() + h["Set-Cookie"] = ["cookiename=cookievalue"] + resp = tutils.tresp() + resp.headers = h + result = resp.get_cookies() + assert len(result) == 1 + assert "cookiename" in result + assert result["cookiename"][0] == ["cookievalue", odict.ODict()] + + def test_get_cookies_with_parameters(self): + h = odict.ODictCaseless() + h["Set-Cookie"] = [ + "cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly"] + resp = tutils.tresp() + resp.headers = h + result = resp.get_cookies() + assert len(result) == 1 + assert "cookiename" in result + assert result["cookiename"][0][0] == "cookievalue" + attrs = result["cookiename"][0][1] + assert len(attrs) == 4 + assert attrs["domain"] == ["example.com"] + assert attrs["expires"] == ["Wed Oct 21 16:29:41 2015"] + assert attrs["path"] == ["/"] + assert attrs["httponly"] == [None] + + def test_get_cookies_no_value(self): + h = odict.ODictCaseless() + h["Set-Cookie"] = [ + "cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/" + ] + resp = tutils.tresp() + resp.headers = h + result = resp.get_cookies() + assert len(result) == 1 + assert "cookiename" in result + assert result["cookiename"][0][0] == "" + assert len(result["cookiename"][0][1]) == 2 + + def test_get_cookies_twocookies(self): + h = odict.ODictCaseless() + h["Set-Cookie"] = ["cookiename=cookievalue", "othercookie=othervalue"] + resp = tutils.tresp() + resp.headers = h + result = resp.get_cookies() + assert len(result) == 2 + assert "cookiename" in result + assert result["cookiename"][0] == ["cookievalue", odict.ODict()] + assert "othercookie" in result + assert result["othercookie"][0] == ["othervalue", odict.ODict()] + + def test_set_cookies(self): + resp = tutils.tresp() + v = resp.get_cookies() + v.add("foo", ["bar", odict.ODictCaseless()]) + resp.set_cookies(v) + + v = resp.get_cookies() + assert len(v) == 1 + assert v["foo"] == [["bar", odict.ODictCaseless()]] diff --git a/test/test_utils.py b/test/test_utils.py index 8e66bce4..0153030c 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,4 +1,6 @@ -from netlib import utils +import urlparse + +from netlib import utils, odict import tutils @@ -27,3 +29,76 @@ def test_pretty_size(): assert utils.pretty_size(1024) == "1kB" assert utils.pretty_size(1024 + (1024 / 2.0)) == "1.5kB" assert utils.pretty_size(1024 * 1024) == "1MB" + + + + +def test_parse_url(): + assert not utils.parse_url("") + + u = "http://foo.com:8888/test" + s, h, po, pa = utils.parse_url(u) + assert s == "http" + assert h == "foo.com" + assert po == 8888 + assert pa == "/test" + + s, h, po, pa = utils.parse_url("http://foo/bar") + assert s == "http" + assert h == "foo" + assert po == 80 + assert pa == "/bar" + + s, h, po, pa = utils.parse_url("http://user:pass@foo/bar") + assert s == "http" + assert h == "foo" + assert po == 80 + assert pa == "/bar" + + s, h, po, pa = utils.parse_url("http://foo") + assert pa == "/" + + s, h, po, pa = utils.parse_url("https://foo") + assert po == 443 + + assert not utils.parse_url("https://foo:bar") + assert not utils.parse_url("https://foo:") + + # Invalid IDNA + assert not utils.parse_url("http://\xfafoo") + # Invalid PATH + assert not utils.parse_url("http:/\xc6/localhost:56121") + # Null byte in host + assert not utils.parse_url("http://foo\0") + # Port out of range + assert not utils.parse_url("http://foo:999999") + # Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt + assert not utils.parse_url('http://lo[calhost') + + +def test_unparse_url(): + assert utils.unparse_url("http", "foo.com", 99, "") == "http://foo.com:99" + assert utils.unparse_url("http", "foo.com", 80, "") == "http://foo.com" + assert utils.unparse_url("https", "foo.com", 80, "") == "https://foo.com:80" + assert utils.unparse_url("https", "foo.com", 443, "") == "https://foo.com" + + +def test_urlencode(): + assert utils.urlencode([('foo', 'bar')]) + + + +def test_urldecode(): + s = "one=two&three=four" + assert len(utils.urldecode(s)) == 2 + + +def test_get_header_tokens(): + h = odict.ODictCaseless() + assert utils.get_header_tokens(h, "foo") == [] + h["foo"] = ["bar"] + assert utils.get_header_tokens(h, "foo") == ["bar"] + h["foo"] = ["bar, voing"] + assert utils.get_header_tokens(h, "foo") == ["bar", "voing"] + h["foo"] = ["bar, voing", "oink"] + assert utils.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] diff --git a/test/tutils.py b/test/tutils.py deleted file mode 100644 index 94139f6f..00000000 --- a/test/tutils.py +++ /dev/null @@ -1,68 +0,0 @@ -import cStringIO -import tempfile -import os -import shutil -from contextlib import contextmanager - -from netlib import tcp, utils - - -def treader(bytes): - """ - Construct a tcp.Read object from bytes. - """ - fp = cStringIO.StringIO(bytes) - return tcp.Reader(fp) - - -@contextmanager -def tmpdir(*args, **kwargs): - orig_workdir = os.getcwd() - temp_workdir = tempfile.mkdtemp(*args, **kwargs) - os.chdir(temp_workdir) - - yield temp_workdir - - os.chdir(orig_workdir) - shutil.rmtree(temp_workdir) - - -def raises(exc, obj, *args, **kwargs): - """ - Assert that a callable raises a specified exception. - - :exc An exception class or a string. If a class, assert that an - exception of this type is raised. If a string, assert that the string - occurs in the string representation of the exception, based on a - case-insenstivie match. - - :obj A callable object. - - :args Arguments to be passsed to the callable. - - :kwargs Arguments to be passed to the callable. - """ - try: - ret = obj(*args, **kwargs) - except Exception as v: - if isinstance(exc, basestring): - if exc.lower() in str(v).lower(): - return - else: - raise AssertionError( - "Expected %s, but caught %s" % ( - repr(str(exc)), v - ) - ) - else: - if isinstance(v, exc): - return - else: - raise AssertionError( - "Expected %s, but caught %s %s" % ( - exc.__name__, v.__class__.__name__, str(v) - ) - ) - raise AssertionError("No exception raised. Return value: {}".format(ret)) - -test_data = utils.Data(__name__) -- cgit v1.2.3 From 0be84fd6b96c170db6020b5aed1e962d64ffedda Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sat, 1 Aug 2015 14:49:15 +0200 Subject: fix tutils imports --- netlib/utils.py | 2 +- test/http/http1/test_protocol.py | 4 ++-- test/http/http2/test_frames.py | 4 ++-- test/http/http2/test_protocol.py | 4 ++-- test/http/test_authentication.py | 3 +-- test/http/test_semantics.py | 4 ++-- test/test_certutils.py | 3 +-- test/test_odict.py | 3 +-- test/test_socks.py | 3 +-- test/test_tcp.py | 4 ++-- test/test_utils.py | 3 +-- test/tservers.py | 3 +-- test/websockets/test_websockets.py | 4 ++-- 13 files changed, 19 insertions(+), 25 deletions(-) diff --git a/netlib/utils.py b/netlib/utils.py index 39354605..35ea0ec7 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -130,7 +130,7 @@ class Data(object): This function will raise ValueError if the path does not exist. """ - fullpath = os.path.join(self.dirname, path) + fullpath = os.path.join(self.dirname, '../test/', path) if not os.path.exists(fullpath): raise ValueError("dataPath: %s does not exist." % fullpath) return fullpath diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index 05bad1af..e3c3ff43 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -2,9 +2,9 @@ import cStringIO import textwrap import binascii -from netlib import http, odict, tcp +from netlib import http, odict, tcp, tutils from netlib.http.http1 import HTTP1Protocol -from ... import tutils, tservers +from ... import tservers def mock_protocol(data='', chunked=False): diff --git a/test/http/http2/test_frames.py b/test/http/http2/test_frames.py index ee2edc39..077f5bc2 100644 --- a/test/http/http2/test_frames.py +++ b/test/http/http2/test_frames.py @@ -1,7 +1,7 @@ import cStringIO -from test import tutils from nose.tools import assert_equal -from netlib import tcp + +from netlib import tcp, tutils from netlib.http.http2.frame import * diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index b2d414d1..8a27bbb1 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -1,9 +1,9 @@ import OpenSSL -from netlib import tcp, odict, http +from netlib import tcp, odict, http, tutils from netlib.http import http2 from netlib.http.http2.frame import * -from ... import tutils, tservers +from ... import tservers class EchoHandler(tcp.BaseHandler): diff --git a/test/http/test_authentication.py b/test/http/test_authentication.py index 8f231643..5261e029 100644 --- a/test/http/test_authentication.py +++ b/test/http/test_authentication.py @@ -1,8 +1,7 @@ import binascii -from netlib import odict, http +from netlib import odict, http, tutils from netlib.http import authentication -from .. import tutils def test_parse_http_basic_auth(): diff --git a/test/http/test_semantics.py b/test/http/test_semantics.py index 986afc39..d58a44d2 100644 --- a/test/http/test_semantics.py +++ b/test/http/test_semantics.py @@ -3,10 +3,10 @@ import textwrap import binascii from mock import MagicMock -from netlib import http, odict, tcp +from netlib import http, odict, tcp, tutils from netlib.http import http1 from netlib.http.semantics import CONTENT_MISSING -from .. import tutils, tservers +from .. import tservers def test_httperror(): e = http.exceptions.HttpError(404, "Not found") diff --git a/test/test_certutils.py b/test/test_certutils.py index 50df36ae..b44879f6 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -1,6 +1,5 @@ import os -from netlib import certutils -import tutils +from netlib import certutils, tutils # class TestDNTree: # def test_simple(self): diff --git a/test/test_odict.py b/test/test_odict.py index d66ae59b..be3d862d 100644 --- a/test/test_odict.py +++ b/test/test_odict.py @@ -1,5 +1,4 @@ -from netlib import odict -import tutils +from netlib import odict, tutils class TestODict: diff --git a/test/test_socks.py b/test/test_socks.py index 1b6c2a32..36fc5b3d 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -1,8 +1,7 @@ from cStringIO import StringIO import socket from nose.plugins.skip import SkipTest -from netlib import socks, tcp -import tutils +from netlib import socks, tcp, tutils def test_client_greeting(): diff --git a/test/test_tcp.py b/test/test_tcp.py index 289ed72f..2a5deb2b 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -10,8 +10,8 @@ import mock from OpenSSL import SSL import OpenSSL -from netlib import tcp, certutils -from . import tutils, tservers +from netlib import tcp, certutils, tutils +from . import tservers class EchoHandler(tcp.BaseHandler): diff --git a/test/test_utils.py b/test/test_utils.py index 0153030c..5e681eb6 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,7 +1,6 @@ import urlparse -from netlib import utils, odict -import tutils +from netlib import utils, odict, tutils def test_bidi(): diff --git a/test/tservers.py b/test/tservers.py index 5e99c0e2..3f3ea8b4 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -3,8 +3,7 @@ import threading import Queue import cStringIO import OpenSSL -from netlib import tcp, certutils -from . import tutils +from netlib import tcp, certutils, tutils class ServerThread(threading.Thread): diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index fb7ba39a..28dbb833 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -2,10 +2,10 @@ import os from nose.tools import raises -from netlib import tcp, http, websockets +from netlib import tcp, http, websockets, tutils from netlib.http.exceptions import * from netlib.http.http1 import HTTP1Protocol -from .. import tutils, tservers +from .. import tservers class WebSocketsEchoHandler(tcp.BaseHandler): -- cgit v1.2.3 From 1c12e7c2b8bc04a2b01e21ac58771bc958a8ac8a Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sat, 1 Aug 2015 14:53:13 +0200 Subject: move encoding tests from mitmproxy to netlib --- test/test_encoding.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 test/test_encoding.py diff --git a/test/test_encoding.py b/test/test_encoding.py new file mode 100644 index 00000000..faf718ae --- /dev/null +++ b/test/test_encoding.py @@ -0,0 +1,32 @@ +from netlib import encoding + +def test_identity(): + assert "string" == encoding.decode("identity", "string") + assert "string" == encoding.encode("identity", "string") + assert not encoding.encode("nonexistent", "string") + assert None == encoding.decode("nonexistent encoding", "string") + + +def test_gzip(): + assert "string" == encoding.decode( + "gzip", + encoding.encode( + "gzip", + "string")) + assert None == encoding.decode("gzip", "bogus") + + +def test_deflate(): + assert "string" == encoding.decode( + "deflate", + encoding.encode( + "deflate", + "string")) + assert "string" == encoding.decode( + "deflate", + encoding.encode( + "deflate", + "string")[ + 2:- + 4]) + assert None == encoding.decode("deflate", "bogus") -- cgit v1.2.3 From 6a678d86e16ccab7d16a74c79a6a0b928007d532 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 2 Aug 2015 11:27:01 +0200 Subject: fix mitmproxy tests --- netlib/http/exceptions.py | 6 ++++-- netlib/http/http1/protocol.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py index 45bd2dce..7cd26c12 100644 --- a/netlib/http/exceptions.py +++ b/netlib/http/exceptions.py @@ -1,5 +1,6 @@ -class HttpError(Exception): +from netlib import odict +class HttpError(Exception): def __init__(self, code, message): super(HttpError, self).__init__(message) self.code = code @@ -9,12 +10,13 @@ class HttpErrorConnClosed(HttpError): pass - class HttpAuthenticationError(Exception): def __init__(self, auth_headers=None): super(HttpAuthenticationError, self).__init__( "Proxy Authentication Required" ) + if isinstance(auth_headers, dict): + auth_headers = odict.ODictCaseless(auth_headers.items()) self.headers = auth_headers self.code = 407 diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index a189bffc..2e85a762 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -302,7 +302,8 @@ class HTTP1Protocol(semantics.ProtocolMixin): bytes_left = expected_size while bytes_left: chunk_size = min(bytes_left, max_chunk_size) - yield "", self.tcp_handler.rfile.read(chunk_size), "" + content = self.tcp_handler.rfile.read(chunk_size) + yield "", content, "" bytes_left -= chunk_size else: bytes_left = limit or -1 -- cgit v1.2.3 From c2832ef72bd4eed485a1c8d4bcb732da69896444 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 3 Aug 2015 18:06:31 +0200 Subject: fix mitmproxy/mitmproxy#705 --- netlib/tcp.py | 6 +++++- netlib/version_check.py | 25 ++++++++++++------------- test/test_version_check.py | 24 +++++++++++++++--------- 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 77c2a531..c355cfdd 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -11,7 +11,11 @@ import certifi import OpenSSL from OpenSSL import SSL -from . import certutils +from . import certutils, version_check + +# This is a rather hackish way to make sure that +# the latest version of pyOpenSSL is actually installed. +version_check.check_pyopenssl_version() EINTR = 4 diff --git a/netlib/version_check.py b/netlib/version_check.py index 5465c901..aae4e8c7 100644 --- a/netlib/version_check.py +++ b/netlib/version_check.py @@ -1,23 +1,19 @@ -from __future__ import print_function, absolute_import +""" +Having installed a wrong version of pyOpenSSL or netlib is unfortunately a +very common source of error. Check before every start that both versions +are somewhat okay. +""" +from __future__ import division, absolute_import, print_function, unicode_literals import sys import inspect import os.path - import OpenSSL from . import version PYOPENSSL_MIN_VERSION = (0, 15) -def version_check( - mitmproxy_version, - pyopenssl_min_version=PYOPENSSL_MIN_VERSION, - fp=sys.stderr): - """ - Having installed a wrong version of pyOpenSSL or netlib is unfortunately a - very common source of error. Check before every start that both versions - are somewhat okay. - """ +def check_mitmproxy_version(mitmproxy_version, fp=sys.stderr): # We don't introduce backward-incompatible changes in patch versions. Only # consider major and minor version. if version.IVERSION[:2] != mitmproxy_version[:2]: @@ -29,12 +25,15 @@ def version_check( file=fp ) sys.exit(1) + + +def check_pyopenssl_version(min_version=PYOPENSSL_MIN_VERSION, fp=sys.stderr): v = tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) - if v < pyopenssl_min_version: + if v < min_version: print( "You are using an outdated version of pyOpenSSL:" " mitmproxy requires pyOpenSSL %s or greater." % - str(pyopenssl_min_version), + str(min_version), file=fp ) # Some users apparently have multiple versions of pyOpenSSL installed. diff --git a/test/test_version_check.py b/test/test_version_check.py index bf6ad1f5..a16969d2 100644 --- a/test/test_version_check.py +++ b/test/test_version_check.py @@ -4,19 +4,25 @@ from netlib import version_check, version @mock.patch("sys.exit") -def test_version_check(sexit): +def test_check_mitmproxy_version(sexit): fp = cStringIO.StringIO() - version_check.version_check(version.IVERSION, fp=fp) + version_check.check_mitmproxy_version(version.IVERSION, fp=fp) + assert not fp.getvalue() assert not sexit.called b = (version.IVERSION[0] - 1, version.IVERSION[1]) - version_check.version_check(b, fp=fp) + version_check.check_mitmproxy_version(b, fp=fp) + assert fp.getvalue() assert sexit.called - sexit.reset_mock() - version_check.version_check( - version.IVERSION, - pyopenssl_min_version=(9999,), - fp=fp - ) + +@mock.patch("sys.exit") +def test_check_pyopenssl_version(sexit): + fp = cStringIO.StringIO() + version_check.check_pyopenssl_version(fp=fp) + assert not fp.getvalue() + assert not sexit.called + + version_check.check_pyopenssl_version((9999,), fp=fp) + assert fp.getvalue() assert sexit.called -- cgit v1.2.3 From 690b8b4f4e00d60b373b5a1481930f21bbc5054a Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 5 Aug 2015 21:32:53 +0200 Subject: add move tests and code from mitmproxy --- netlib/http/http1/protocol.py | 14 -- netlib/http/semantics.py | 43 ++-- netlib/odict.py | 3 +- netlib/tutils.py | 4 +- netlib/utils.py | 56 ++++++ test/http/http1/test_protocol.py | 265 ++++++++++++++++--------- test/http/http2/test_protocol.py | 247 ++++++++++++++++------- test/http/test_exceptions.py | 29 ++- test/http/test_semantics.py | 389 +++++++++++++++++++++++++++---------- test/test_utils.py | 31 +++ test/websockets/test_websockets.py | 7 +- 11 files changed, 778 insertions(+), 310 deletions(-) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 2e85a762..31e9cc85 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -359,20 +359,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return -1 - @classmethod - def request_preamble(self, method, resource, http_major="1", http_minor="1"): - return '%s %s HTTP/%s.%s' % ( - method, resource, http_major, http_minor - ) - - - @classmethod - def response_preamble(self, code, message=None, http_major="1", http_minor="1"): - if message is None: - message = status_codes.RESPONSES.get(code) - return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) - - @classmethod def has_chunked_encoding(self, headers): return "chunked" in [ diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index e7ae2b5f..974fe6e6 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -7,7 +7,7 @@ import urllib import urlparse from .. import utils, odict -from . import cookies +from . import cookies, exceptions from netlib import utils, encoding HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" @@ -18,10 +18,10 @@ CONTENT_MISSING = 0 class ProtocolMixin(object): - def read_request(self): + def read_request(self, *args, **kwargs): # pragma: no cover raise NotImplemented - def read_response(self): + def read_response(self, *args, **kwargs): # pragma: no cover raise NotImplemented def assemble(self, message): @@ -32,14 +32,23 @@ class ProtocolMixin(object): else: raise ValueError("HTTP message not supported.") - def assemble_request(self, request): + def assemble_request(self, request): # pragma: no cover raise NotImplemented - def assemble_response(self, response): + def assemble_response(self, response): # pragma: no cover raise NotImplemented class Request(object): + # This list is adopted legacy code. + # We probably don't need to strip off keep-alive. + _headers_to_strip_off = [ + 'Proxy-Connection', + 'Keep-Alive', + 'Connection', + 'Transfer-Encoding', + 'Upgrade', + ] def __init__( self, @@ -71,7 +80,6 @@ class Request(object): self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end - def __eq__(self, other): try: self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] @@ -114,7 +122,7 @@ class Request(object): self.httpversion[1], ) else: - raise http.HttpError(400, "Invalid request form") + raise exceptions.HttpError(400, "Invalid request form") def anticache(self): """ @@ -143,7 +151,7 @@ class Request(object): if self.headers["accept-encoding"]: self.headers["accept-encoding"] = [ ', '.join( - e for e in encoding.ENCODINGS if e in self.headers["accept-encoding"][0])] + e for e in encoding.ENCODINGS if e in self.headers.get_first("accept-encoding"))] def update_host_header(self): """ @@ -317,12 +325,12 @@ class Request(object): self.scheme, self.host, self.port, self.path = parts @property - def content(self): + def content(self): # pragma: no cover # TODO: remove deprecated getter return self.body @content.setter - def content(self, content): + def content(self, content): # pragma: no cover # TODO: remove deprecated setter self.body = content @@ -343,6 +351,11 @@ class EmptyRequest(Request): class Response(object): + _headers_to_strip_off = [ + 'Proxy-Connection', + 'Alternate-Protocol', + 'Alt-Svc', + ] def __init__( self, @@ -368,7 +381,6 @@ class Response(object): self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end - def __eq__(self, other): try: self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] @@ -393,7 +405,6 @@ class Response(object): size=size ) - def get_cookies(self): """ Get the contents of all Set-Cookie headers. @@ -430,21 +441,21 @@ class Response(object): self.headers["Set-Cookie"] = values @property - def content(self): + def content(self): # pragma: no cover # TODO: remove deprecated getter return self.body @content.setter - def content(self, content): + def content(self, content): # pragma: no cover # TODO: remove deprecated setter self.body = content @property - def code(self): + def code(self): # pragma: no cover # TODO: remove deprecated getter return self.status_code @code.setter - def code(self, code): + def code(self, code): # pragma: no cover # TODO: remove deprecated setter self.status_code = code diff --git a/netlib/odict.py b/netlib/odict.py index d02de08d..11d5d52a 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -91,8 +91,9 @@ class ODict(object): self.lst = self._filter_lst(k, self.lst) def __contains__(self, k): + k = self._kconv(k) for i in self.lst: - if self._kconv(i[0]) == self._kconv(k): + if self._kconv(i[0]) == k: return True return False diff --git a/netlib/tutils.py b/netlib/tutils.py index 5018b9e8..3c471d0d 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -119,7 +119,7 @@ def tresp(content="message"): "OK", headers, content, - time.time(), - time.time(), + timestamp_start=time.time(), + timestamp_end=time.time(), ) return resp diff --git a/netlib/utils.py b/netlib/utils.py index 35ea0ec7..2dfcafc6 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -4,6 +4,7 @@ import cgi import urllib import urlparse import string +import re def isascii(s): @@ -239,3 +240,58 @@ def urldecode(s): Takes a urlencoded string and returns a list of (key, value) tuples. """ return cgi.parse_qsl(s, keep_blank_values=True) + + +def parse_content_type(c): + """ + A simple parser for content-type values. Returns a (type, subtype, + parameters) tuple, where type and subtype are strings, and parameters + is a dict. If the string could not be parsed, return None. + + E.g. the following string: + + text/html; charset=UTF-8 + + Returns: + + ("text", "html", {"charset": "UTF-8"}) + """ + parts = c.split(";", 1) + ts = parts[0].split("/", 1) + if len(ts) != 2: + return None + d = {} + if len(parts) == 2: + for i in parts[1].split(";"): + clause = i.split("=", 1) + if len(clause) == 2: + d[clause[0].strip()] = clause[1].strip() + return ts[0].lower(), ts[1].lower(), d + + +def multipartdecode(hdrs, content): + """ + Takes a multipart boundary encoded string and returns list of (key, value) tuples. + """ + v = hdrs.get_first("content-type") + if v: + v = parse_content_type(v) + if not v: + return [] + boundary = v[2].get("boundary") + if not boundary: + return [] + + rx = re.compile(r'\bname="([^"]+)"') + r = [] + + for i in content.split("--" + boundary): + parts = i.splitlines() + if len(parts) > 1 and parts[0][0:2] != "--": + match = rx.search(parts[1]) + if match: + key = match.group(1) + value = "".join(parts[3 + parts[2:].index(""):]) + r.append((key, value)) + return r + return [] diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index e3c3ff43..ff70b87d 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -3,16 +3,40 @@ import textwrap import binascii from netlib import http, odict, tcp, tutils +from netlib.http import semantics from netlib.http.http1 import HTTP1Protocol from ... import tservers -def mock_protocol(data='', chunked=False): +class NoContentLengthHTTPHandler(tcp.BaseHandler): + def handle(self): + self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n") + self.wfile.flush() + + +def mock_protocol(data=''): rfile = cStringIO.StringIO(data) wfile = cStringIO.StringIO() return HTTP1Protocol(rfile=rfile, wfile=wfile) +def match_http_string(data): + return textwrap.dedent(data).strip().replace('\n', '\r\n') + + +def test_stripped_chunked_encoding_no_content(): + """ + https://github.com/mitmproxy/mitmproxy/issues/186 + """ + + r = tutils.treq(content="") + r.headers["Transfer-Encoding"] = ["chunked"] + assert "Content-Length" in mock_protocol()._assemble_request_headers(r) + + r = tutils.tresp(content="") + r.headers["Transfer-Encoding"] = ["chunked"] + assert "Content-Length" in mock_protocol()._assemble_response_headers(r) + def test_has_chunked_encoding(): h = odict.ODictCaseless() @@ -75,7 +99,6 @@ def test_connection_close(): assert HTTP1Protocol.connection_close((1, 1), h) - def test_read_http_body_request(): h = odict.ODictCaseless() data = "testing" @@ -85,7 +108,7 @@ def test_read_http_body_request(): def test_read_http_body_response(): h = odict.ODictCaseless() data = "testing" - assert mock_protocol(data, chunked=True).read_http_body(h, None, "GET", 200, False) == "testing" + assert mock_protocol(data).read_http_body(h, None, "GET", 200, False) == "testing" def test_read_http_body(): @@ -129,13 +152,13 @@ def test_read_http_body(): # test no content length: limit > actual content h = odict.ODictCaseless() data = "testing" - assert len(mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False)) == 7 + assert len(mock_protocol(data).read_http_body(h, 100, "GET", 200, False)) == 7 # test no content length: limit < actual content data = "testing" tutils.raises( http.HttpError, - mock_protocol(data, chunked=True).read_http_body, + mock_protocol(data).read_http_body, h, 4, "GET", 200, False ) @@ -143,7 +166,7 @@ def test_read_http_body(): h = odict.ODictCaseless() h["transfer-encoding"] = ["chunked"] data = "5\r\naaaaa\r\n0\r\n\r\n" - assert mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False) == "aaaaa" + assert mock_protocol(data).read_http_body(h, 100, "GET", 200, False) == "aaaaa" def test_expected_http_body_size(): @@ -167,6 +190,13 @@ def test_expected_http_body_size(): assert HTTP1Protocol.expected_http_body_size(h, True, "GET", None) == 0 +def test_get_request_line(): + data = "\nfoo" + p = mock_protocol(data) + assert p._get_request_line() == "foo" + assert not p._get_request_line() + + def test_parse_http_protocol(): assert HTTP1Protocol._parse_http_protocol("HTTP/1.1") == (1, 1) assert HTTP1Protocol._parse_http_protocol("HTTP/0.0") == (0, 0) @@ -269,96 +299,7 @@ class TestReadHeaders: assert self._read(data) is None -class NoContentLengthHTTPHandler(tcp.BaseHandler): - - def handle(self): - self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n") - self.wfile.flush() - - -class TestReadResponseNoContentLength(tservers.ServerTestBase): - handler = NoContentLengthHTTPHandler - - def test_no_content_length(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - resp = HTTP1Protocol(c).read_response("GET", None) - assert resp.body == "bar\r\n\r\n" - - -def test_read_response(): - def tst(data, method, body_size_limit, include_body=True): - data = textwrap.dedent(data) - return mock_protocol(data).read_response( - method, body_size_limit, include_body=include_body - ) - - tutils.raises("server disconnect", tst, "", "GET", None) - tutils.raises("invalid server response", tst, "foo", "GET", None) - data = """ - HTTP/1.1 200 OK - """ - assert tst(data, "GET", None) == http.Response( - (1, 1), 200, 'OK', odict.ODictCaseless(), '' - ) - data = """ - HTTP/1.1 200 - """ - assert tst(data, "GET", None) == http.Response( - (1, 1), 200, '', odict.ODictCaseless(), '' - ) - data = """ - HTTP/x 200 OK - """ - tutils.raises("invalid http version", tst, data, "GET", None) - data = """ - HTTP/1.1 xx OK - """ - tutils.raises("invalid server response", tst, data, "GET", None) - - data = """ - HTTP/1.1 100 CONTINUE - - HTTP/1.1 200 OK - """ - assert tst(data, "GET", None) == http.Response( - (1, 1), 100, 'CONTINUE', odict.ODictCaseless(), '' - ) - - data = """ - HTTP/1.1 200 OK - Content-Length: 3 - - foo - """ - assert tst(data, "GET", None).body == 'foo' - assert tst(data, "HEAD", None).body == '' - - data = """ - HTTP/1.1 200 OK - \tContent-Length: 3 - - foo - """ - tutils.raises("invalid headers", tst, data, "GET", None) - - data = """ - HTTP/1.1 200 OK - Content-Length: 3 - - foo - """ - assert tst(data, "GET", None, include_body=False).body is None - - -def test_get_request_line(): - data = "\nfoo" - p = mock_protocol(data) - assert p._get_request_line() == "foo" - assert not p._get_request_line() - - -class TestReadRequest(): +class TestReadRequest(object): def tst(self, data, **kwargs): return mock_protocol(data).read_request(**kwargs) @@ -385,6 +326,10 @@ class TestReadRequest(): "\r\n" ) + def test_empty(self): + v = self.tst("", allow_empty=True) + assert isinstance(v, semantics.EmptyRequest) + def test_asterisk_form_in(self): v = self.tst("OPTIONS * HTTP/1.1") assert v.form_in == "relative" @@ -427,3 +372,131 @@ class TestReadRequest(): assert p.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" assert v.body == "foo" assert p.tcp_handler.rfile.read(3) == "bar" + + +class TestReadResponse(object): + def tst(self, data, method, body_size_limit, include_body=True): + data = textwrap.dedent(data) + return mock_protocol(data).read_response( + method, body_size_limit, include_body=include_body + ) + + def test_errors(self): + tutils.raises("server disconnect", self.tst, "", "GET", None) + tutils.raises("invalid server response", self.tst, "foo", "GET", None) + + def test_simple(self): + data = """ + HTTP/1.1 200 + """ + assert self.tst(data, "GET", None) == http.Response( + (1, 1), 200, '', odict.ODictCaseless(), '' + ) + + def test_simple_message(self): + data = """ + HTTP/1.1 200 OK + """ + assert self.tst(data, "GET", None) == http.Response( + (1, 1), 200, 'OK', odict.ODictCaseless(), '' + ) + + def test_invalid_http_version(self): + data = """ + HTTP/x 200 OK + """ + tutils.raises("invalid http version", self.tst, data, "GET", None) + + def test_invalid_status_code(self): + data = """ + HTTP/1.1 xx OK + """ + tutils.raises("invalid server response", self.tst, data, "GET", None) + + def test_valid_with_continue(self): + data = """ + HTTP/1.1 100 CONTINUE + + HTTP/1.1 200 OK + """ + assert self.tst(data, "GET", None) == http.Response( + (1, 1), 100, 'CONTINUE', odict.ODictCaseless(), '' + ) + + def test_simple_body(self): + data = """ + HTTP/1.1 200 OK + Content-Length: 3 + + foo + """ + assert self.tst(data, "GET", None).body == 'foo' + assert self.tst(data, "HEAD", None).body == '' + + def test_invalid_headers(self): + data = """ + HTTP/1.1 200 OK + \tContent-Length: 3 + + foo + """ + tutils.raises("invalid headers", self.tst, data, "GET", None) + + def test_without_body(self): + data = """ + HTTP/1.1 200 OK + Content-Length: 3 + + foo + """ + assert self.tst(data, "GET", None, include_body=False).body is None + + +class TestReadResponseNoContentLength(tservers.ServerTestBase): + handler = NoContentLengthHTTPHandler + + def test_no_content_length(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + resp = HTTP1Protocol(c).read_response("GET", None) + assert resp.body == "bar\r\n\r\n" + + +class TestAssembleRequest(object): + def test_simple(self): + req = tutils.treq() + b = HTTP1Protocol().assemble_request(req) + assert b == match_http_string(""" + GET /path HTTP/1.1 + header: qvalue + Host: address:22 + Content-Length: 7 + + content""") + + def test_body_missing(self): + req = tutils.treq(content=semantics.CONTENT_MISSING) + tutils.raises(http.HttpError, HTTP1Protocol().assemble_request, req) + + def test_not_a_request(self): + tutils.raises(AssertionError, HTTP1Protocol().assemble_request, 'foo') + + +class TestAssembleResponse(object): + def test_simple(self): + resp = tutils.tresp() + b = HTTP1Protocol().assemble_response(resp) + print(b) + assert b == match_http_string(""" + HTTP/1.1 200 OK + header_response: svalue + Content-Length: 7 + + message""") + + def test_body_missing(self): + resp = tutils.tresp(content=semantics.CONTENT_MISSING) + tutils.raises(http.HttpError, HTTP1Protocol().assemble_response, resp) + + def test_not_a_request(self): + tutils.raises(AssertionError, HTTP1Protocol().assemble_response, 'foo') diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 8a27bbb1..3044179f 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -1,10 +1,25 @@ import OpenSSL +import mock from netlib import tcp, odict, http, tutils from netlib.http import http2 +from netlib.http.http2 import HTTP2Protocol from netlib.http.http2.frame import * from ... import tservers +class TestTCPHandlerWrapper: + def test_wrapped(self): + h = http2.TCPHandler(rfile='foo', wfile='bar') + p = HTTP2Protocol(h) + assert p.tcp_handler.rfile == 'foo' + assert p.tcp_handler.wfile == 'bar' + + def test_direct(self): + p = HTTP2Protocol(rfile='foo', wfile='bar') + assert isinstance(p.tcp_handler, http2.TCPHandler) + assert p.tcp_handler.rfile == 'foo' + assert p.tcp_handler.wfile == 'bar' + class EchoHandler(tcp.BaseHandler): sni = None @@ -16,10 +31,40 @@ class EchoHandler(tcp.BaseHandler): self.wfile.flush() +class TestProtocol: + @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface") + @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface") + def test_perform_connection_preface(self, mock_client_method, mock_server_method): + protocol = HTTP2Protocol(is_server=False) + protocol.connection_preface_performed = True + + protocol.perform_connection_preface() + assert not mock_client_method.called + assert not mock_server_method.called + + protocol.perform_connection_preface(force=True) + assert mock_client_method.called + assert not mock_server_method.called + + @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface") + @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface") + def test_perform_connection_preface_server(self, mock_client_method, mock_server_method): + protocol = HTTP2Protocol(is_server=True) + protocol.connection_preface_performed = True + + protocol.perform_connection_preface() + assert not mock_client_method.called + assert not mock_server_method.called + + protocol.perform_connection_preface(force=True) + assert not mock_client_method.called + assert mock_server_method.called + + class TestCheckALPNMatch(tservers.ServerTestBase): handler = EchoHandler ssl = dict( - alpn_select=http2.HTTP2Protocol.ALPN_PROTO_H2, + alpn_select=HTTP2Protocol.ALPN_PROTO_H2, ) if OpenSSL._util.lib.Cryptography_HAS_ALPN: @@ -27,8 +72,8 @@ class TestCheckALPNMatch(tservers.ServerTestBase): def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) - protocol = http2.HTTP2Protocol(c) + c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2]) + protocol = HTTP2Protocol(c) assert protocol.check_alpn() @@ -43,8 +88,8 @@ class TestCheckALPNMismatch(tservers.ServerTestBase): def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) - protocol = http2.HTTP2Protocol(c) + c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2]) + protocol = HTTP2Protocol(c) tutils.raises(NotImplementedError, protocol.check_alpn) @@ -76,8 +121,13 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase): def test_perform_server_connection_preface(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) + + assert not protocol.connection_preface_performed protocol.perform_server_connection_preface() + assert protocol.connection_preface_performed + + tutils.raises(tcp.NetLibIncomplete, protocol.perform_server_connection_preface, force=True) class TestPerformClientConnectionPreface(tservers.ServerTestBase): @@ -107,13 +157,16 @@ class TestPerformClientConnectionPreface(tservers.ServerTestBase): def test_perform_client_connection_preface(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) + + assert not protocol.connection_preface_performed protocol.perform_client_connection_preface() + assert protocol.connection_preface_performed class TestClientStreamIds(): c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) def test_client_stream_ids(self): assert self.protocol.current_stream_id is None @@ -127,7 +180,7 @@ class TestClientStreamIds(): class TestServerStreamIds(): c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c, is_server=True) + protocol = HTTP2Protocol(c, is_server=True) def test_server_stream_ids(self): assert self.protocol.current_stream_id is None @@ -154,7 +207,7 @@ class TestApplySettings(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) protocol._apply_settings({ SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 'foo', @@ -182,13 +235,13 @@ class TestCreateHeaders(): (b':scheme', b'https'), (b'foo', b'bar')] - bytes = http2.HTTP2Protocol(self.c)._create_headers( + bytes = HTTP2Protocol(self.c)._create_headers( headers, 1, end_stream=True) assert b''.join(bytes) ==\ '000014010500000001824488355217caf3a69a3f87408294e7838c767f'\ .decode('hex') - bytes = http2.HTTP2Protocol(self.c)._create_headers( + bytes = HTTP2Protocol(self.c)._create_headers( headers, 1, end_stream=False) assert b''.join(bytes) ==\ '000014010400000001824488355217caf3a69a3f87408294e7838c767f'\ @@ -199,7 +252,7 @@ class TestCreateHeaders(): class TestCreateBody(): c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) def test_create_body_empty(self): bytes = self.protocol._create_body(b'', 1) @@ -215,41 +268,30 @@ class TestCreateBody(): # TODO: add test for too large frames -class TestAssembleRequest(): - c = tcp.TCPClient(("127.0.0.1", 0)) +class TestReadRequest(tservers.ServerTestBase): + class handler(tcp.BaseHandler): - def test_assemble_request_simple(self): - bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request( - '', - 'GET', - 'https', - '', - '', - '/', - (2, 0), - None, - None, - )) - assert len(bytes) == 1 - assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') + def handle(self): + self.wfile.write( + b'000003010400000001828487'.decode('hex')) + self.wfile.write( + b'000006000100000001666f6f626172'.decode('hex')) + self.wfile.flush() - def test_assemble_request_with_body(self): - bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request( - '', - 'GET', - 'https', - '', - '', - '/', - (2, 0), - odict.ODictCaseless([('foo', 'bar')]), - 'foobar', - )) - assert len(bytes) == 2 - assert bytes[0] ==\ - '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') - assert bytes[1] ==\ - '000006000100000001666f6f626172'.decode('hex') + ssl = True + + def test_read_request(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = HTTP2Protocol(c, is_server=True) + protocol.connection_preface_performed = True + + resp = protocol.read_request() + + assert resp.stream_id + assert resp.headers.lst == [[u':method', u'GET'], [u':path', u'/'], [u':scheme', u'https']] + assert resp.body == b'foobar' class TestReadResponse(tservers.ServerTestBase): @@ -268,7 +310,7 @@ class TestReadResponse(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) protocol.connection_preface_performed = True resp = protocol.read_response() @@ -278,6 +320,23 @@ class TestReadResponse(tservers.ServerTestBase): assert resp.msg == "" assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] assert resp.body == b'foobar' + assert resp.timestamp_end + + def test_read_response_no_body(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = HTTP2Protocol(c) + protocol.connection_preface_performed = True + + resp = protocol.read_response(include_body=False) + + assert resp.httpversion == (2, 0) + assert resp.status_code == 200 + assert resp.msg == "" + assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] + assert resp.body == b'foobar' # TODO: this should be true: assert resp.body == http.CONTENT_MISSING + assert not resp.timestamp_end class TestReadEmptyResponse(tservers.ServerTestBase): @@ -294,7 +353,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) protocol.connection_preface_performed = True resp = protocol.read_response() @@ -307,37 +366,66 @@ class TestReadEmptyResponse(tservers.ServerTestBase): assert resp.body == b'' -class TestReadRequest(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - self.wfile.write( - b'000003010400000001828487'.decode('hex')) - self.wfile.write( - b'000006000100000001666f6f626172'.decode('hex')) - self.wfile.flush() - - ssl = True +class TestAssembleRequest(object): + c = tcp.TCPClient(("127.0.0.1", 0)) - def test_read_request(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c, is_server=True) - protocol.connection_preface_performed = True + def test_request_simple(self): + bytes = HTTP2Protocol(self.c).assemble_request(http.Request( + '', + 'GET', + 'https', + '', + '', + '/', + (2, 0), + None, + None, + )) + assert len(bytes) == 1 + assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') - resp = protocol.read_request() + def test_request_with_stream_id(self): + req = http.Request( + '', + 'GET', + 'https', + '', + '', + '/', + (2, 0), + None, + None, + ) + req.stream_id = 0x42 + bytes = HTTP2Protocol(self.c).assemble_request(req) + assert len(bytes) == 1 + print(bytes[0].encode('hex')) + assert bytes[0] == '00000d0105000000428284874188089d5c0b8170dc07'.decode('hex') - assert resp.stream_id - assert resp.headers.lst == [[u':method', u'GET'], [u':path', u'/'], [u':scheme', u'https']] - assert resp.body == b'foobar' + def test_request_with_body(self): + bytes = HTTP2Protocol(self.c).assemble_request(http.Request( + '', + 'GET', + 'https', + '', + '', + '/', + (2, 0), + odict.ODictCaseless([('foo', 'bar')]), + 'foobar', + )) + assert len(bytes) == 2 + assert bytes[0] ==\ + '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') + assert bytes[1] ==\ + '000006000100000001666f6f626172'.decode('hex') -class TestCreateResponse(): +class TestAssembleResponse(object): c = tcp.TCPClient(("127.0.0.1", 0)) - def test_create_response_simple(self): - bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( + def test_simple(self): + bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( (2, 0), 200, )) @@ -345,8 +433,19 @@ class TestCreateResponse(): assert bytes[0] ==\ '00000101050000000288'.decode('hex') - def test_create_response_with_body(self): - bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( + def test_with_stream_id(self): + resp = http.Response( + (2, 0), + 200, + ) + resp.stream_id = 0x42 + bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(resp) + assert len(bytes) == 1 + assert bytes[0] ==\ + '00000101050000004288'.decode('hex') + + def test_with_body(self): + bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( (2, 0), 200, '', diff --git a/test/http/test_exceptions.py b/test/http/test_exceptions.py index aa57f831..0131c7ef 100644 --- a/test/http/test_exceptions.py +++ b/test/http/test_exceptions.py @@ -1,6 +1,27 @@ from netlib.http.exceptions import * +from netlib import odict -def test_HttpAuthenticationError(): - x = HttpAuthenticationError({"foo": "bar"}) - assert str(x) - assert "foo" in x.headers +class TestHttpError: + def test_simple(self): + e = HttpError(404, "Not found") + assert str(e) + +class TestHttpAuthenticationError: + def test_init(self): + headers = odict.ODictCaseless([("foo", "bar")]) + x = HttpAuthenticationError(headers) + assert str(x) + assert isinstance(x.headers, odict.ODictCaseless) + assert x.code == 407 + assert x.headers == headers + print(x.headers.keys()) + assert "foo" in x.headers.keys() + + def test_header_conversion(self): + headers = {"foo": "bar"} + x = HttpAuthenticationError(headers) + assert isinstance(x.headers, odict.ODictCaseless) + assert x.headers.lst == headers.items() + + def test_repr(self): + assert repr(HttpAuthenticationError()) == "Proxy Authentication Required" diff --git a/test/http/test_semantics.py b/test/http/test_semantics.py index d58a44d2..59364eae 100644 --- a/test/http/test_semantics.py +++ b/test/http/test_semantics.py @@ -1,18 +1,277 @@ import cStringIO import textwrap import binascii +import mock from mock import MagicMock -from netlib import http, odict, tcp, tutils -from netlib.http import http1 +from netlib import http, odict, tcp, tutils, utils +from netlib.http import semantics from netlib.http.semantics import CONTENT_MISSING from .. import tservers -def test_httperror(): - e = http.exceptions.HttpError(404, "Not found") - assert str(e) +class TestProtocolMixin(object): + @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response") + @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request") + def test_assemble_request(self, mock_request_method, mock_response_method): + p = semantics.ProtocolMixin() + p.assemble(tutils.treq()) + assert mock_request_method.called + assert not mock_response_method.called + + @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response") + @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request") + def test_assemble_response(self, mock_request_method, mock_response_method): + p = semantics.ProtocolMixin() + p.assemble(tutils.tresp()) + assert not mock_request_method.called + assert mock_response_method.called + + def test_assemble_foo(self): + p = semantics.ProtocolMixin() + tutils.raises(ValueError, p.assemble, 'foo') + +class TestRequest(object): + def test_repr(self): + r = tutils.treq() + assert repr(r) + + def test_headers_odict(self): + tutils.raises(AssertionError, semantics.Request, + 'form_in', + 'method', + 'scheme', + 'host', + 'port', + 'path', + (1, 1), + 'foobar', + ) + + req = semantics.Request( + 'form_in', + 'method', + 'scheme', + 'host', + 'port', + 'path', + (1, 1), + ) + assert isinstance(req.headers, odict.ODictCaseless) + + def test_equal(self): + a = tutils.treq() + b = tutils.treq() + assert a == b + + assert not a == 'foo' + assert not b == 'foo' + assert not 'foo' == a + assert not 'foo' == b + + def test_legacy_first_line(self): + req = tutils.treq() + + req.form_in = 'relative' + assert req.legacy_first_line() == "GET /path HTTP/1.1" + + req.form_in = 'authority' + assert req.legacy_first_line() == "GET address:22 HTTP/1.1" + + req.form_in = 'absolute' + assert req.legacy_first_line() == "GET http://address:22/path HTTP/1.1" + + req.form_in = 'foobar' + tutils.raises(http.HttpError, req.legacy_first_line) + + def test_anticache(self): + req = tutils.treq() + req.headers.add("If-Modified-Since", "foo") + req.headers.add("If-None-Match", "bar") + req.anticache() + assert "If-Modified-Since" not in req.headers + assert "If-None-Match" not in req.headers + + def test_anticomp(self): + req = tutils.treq() + req.headers.add("Accept-Encoding", "foobar") + req.anticomp() + assert req.headers["Accept-Encoding"] == ["identity"] + + def test_constrain_encoding(self): + req = tutils.treq() + req.headers.add("Accept-Encoding", "identity, gzip, foo") + req.constrain_encoding() + assert "foo" not in req.headers.get_first("Accept-Encoding") + + def test_update_host(self): + req = tutils.treq() + req.headers.add("Host", "") + req.host = "foobar" + req.update_host_header() + assert req.headers.get_first("Host") == "foobar" + + def test_get_form(self): + req = tutils.treq() + assert req.get_form() == odict.ODict() + + @mock.patch("netlib.http.semantics.Request.get_form_multipart") + @mock.patch("netlib.http.semantics.Request.get_form_urlencoded") + def test_get_form_with_url_encoded(self, mock_method_urlencoded, mock_method_multipart): + req = tutils.treq() + assert req.get_form() == odict.ODict() + + req = tutils.treq() + req.body = "foobar" + req.headers["Content-Type"] = [semantics.HDR_FORM_URLENCODED] + req.get_form() + assert req.get_form_urlencoded.called + assert not req.get_form_multipart.called + + @mock.patch("netlib.http.semantics.Request.get_form_multipart") + @mock.patch("netlib.http.semantics.Request.get_form_urlencoded") + def test_get_form_with_multipart(self, mock_method_urlencoded, mock_method_multipart): + req = tutils.treq() + req.body = "foobar" + req.headers["Content-Type"] = [semantics.HDR_FORM_MULTIPART] + req.get_form() + assert not req.get_form_urlencoded.called + assert req.get_form_multipart.called + + def test_get_form_urlencoded(self): + req = tutils.treq("foobar") + assert req.get_form_urlencoded() == odict.ODict() + + req.headers["Content-Type"] = [semantics.HDR_FORM_URLENCODED] + assert req.get_form_urlencoded() == odict.ODict(utils.urldecode(req.body)) + + def test_get_form_multipart(self): + req = tutils.treq("foobar") + assert req.get_form_multipart() == odict.ODict() + + req.headers["Content-Type"] = [semantics.HDR_FORM_MULTIPART] + assert req.get_form_multipart() == odict.ODict( + utils.multipartdecode( + req.headers, + req.body)) + + def test_set_form_urlencoded(self): + req = tutils.treq() + req.set_form_urlencoded(odict.ODict([('foo', 'bar'), ('rab', 'oof')])) + assert req.headers.get_first("Content-Type") == semantics.HDR_FORM_URLENCODED + assert req.body + + def test_get_path_components(self): + req = tutils.treq() + assert req.get_path_components() + # TODO: add meaningful assertions + + def test_set_path_components(self): + req = tutils.treq() + req.set_path_components(["foo", "bar"]) + # TODO: add meaningful assertions + + def test_get_query(self): + req = tutils.treq() + assert req.get_query().lst == [] + + req.url = "http://localhost:80/foo?bar=42" + assert req.get_query().lst == [("bar", "42")] + + def test_set_query(self): + req = tutils.treq() + req.set_query(odict.ODict([])) + + def test_pretty_host(self): + r = tutils.treq() + assert r.pretty_host(True) == "address" + assert r.pretty_host(False) == "address" + r.headers["host"] = ["other"] + assert r.pretty_host(True) == "other" + assert r.pretty_host(False) == "address" + r.host = None + assert r.pretty_host(True) == "other" + assert r.pretty_host(False) is None + del r.headers["host"] + assert r.pretty_host(True) is None + assert r.pretty_host(False) is None + + # Invalid IDNA + r.headers["host"] = [".disqus.com"] + assert r.pretty_host(True) == ".disqus.com" + + def test_pretty_url(self): + req = tutils.treq() + req.form_out = "authority" + assert req.pretty_url(True) == "address:22" + assert req.pretty_url(False) == "address:22" + + req.form_out = "relative" + assert req.pretty_url(True) == "http://address:22/path" + assert req.pretty_url(False) == "http://address:22/path" + + def test_get_cookies_none(self): + h = odict.ODictCaseless() + r = tutils.treq() + r.headers = h + assert len(r.get_cookies()) == 0 + + def test_get_cookies_single(self): + h = odict.ODictCaseless() + h["Cookie"] = ["cookiename=cookievalue"] + r = tutils.treq() + r.headers = h + result = r.get_cookies() + assert len(result) == 1 + assert result['cookiename'] == ['cookievalue'] + + def test_get_cookies_double(self): + h = odict.ODictCaseless() + h["Cookie"] = [ + "cookiename=cookievalue;othercookiename=othercookievalue" + ] + r = tutils.treq() + r.headers = h + result = r.get_cookies() + assert len(result) == 2 + assert result['cookiename'] == ['cookievalue'] + assert result['othercookiename'] == ['othercookievalue'] + + def test_get_cookies_withequalsign(self): + h = odict.ODictCaseless() + h["Cookie"] = [ + "cookiename=coo=kievalue;othercookiename=othercookievalue" + ] + r = tutils.treq() + r.headers = h + result = r.get_cookies() + assert len(result) == 2 + assert result['cookiename'] == ['coo=kievalue'] + assert result['othercookiename'] == ['othercookievalue'] + + def test_set_cookies(self): + h = odict.ODictCaseless() + h["Cookie"] = ["cookiename=cookievalue"] + r = tutils.treq() + r.headers = h + result = r.get_cookies() + result["cookiename"] = ["foo"] + r.set_cookies(result) + assert r.get_cookies()["cookiename"] == ["foo"] + + def test_set_url(self): + r = tutils.treq_absolute() + r.url = "https://otheraddress:42/ORLY" + assert r.scheme == "https" + assert r.host == "otheraddress" + assert r.port == 42 + assert r.path == "/ORLY" + + try: + r.url = "//localhost:80/foo@bar" + assert False + except: + assert True -class TestRequest: # def test_asterisk_form_in(self): # f = tutils.tflow(req=None) # protocol = mock_protocol("OPTIONS * HTTP/1.1") @@ -92,105 +351,35 @@ class TestRequest: # "Host: address\r\n" # "Content-Length: 0\r\n\r\n") - def test_set_url(self): - r = tutils.treq_absolute() - r.url = "https://otheraddress:42/ORLY" - assert r.scheme == "https" - assert r.host == "otheraddress" - assert r.port == 42 - assert r.path == "/ORLY" - - def test_repr(self): - r = tutils.treq() - assert repr(r) - - def test_pretty_host(self): - r = tutils.treq() - assert r.pretty_host(True) == "address" - assert r.pretty_host(False) == "address" - r.headers["host"] = ["other"] - assert r.pretty_host(True) == "other" - assert r.pretty_host(False) == "address" - r.host = None - assert r.pretty_host(True) == "other" - assert r.pretty_host(False) is None - del r.headers["host"] - assert r.pretty_host(True) is None - assert r.pretty_host(False) is None - - # Invalid IDNA - r.headers["host"] = [".disqus.com"] - assert r.pretty_host(True) == ".disqus.com" - - def test_get_form_for_urlencoded(self): - r = tutils.treq() - r.headers.add("content-type", "application/x-www-form-urlencoded") - r.get_form_urlencoded = MagicMock() - - r.get_form() - - assert r.get_form_urlencoded.called - - def test_get_form_for_multipart(self): - r = tutils.treq() - r.headers.add("content-type", "multipart/form-data") - r.get_form_multipart = MagicMock() - - r.get_form() +class TestEmptyRequest(object): + def test_init(self): + req = semantics.EmptyRequest() + assert req - assert r.get_form_multipart.called - - def test_get_cookies_none(self): - h = odict.ODictCaseless() - r = tutils.treq() - r.headers = h - assert len(r.get_cookies()) == 0 - - def test_get_cookies_single(self): - h = odict.ODictCaseless() - h["Cookie"] = ["cookiename=cookievalue"] - r = tutils.treq() - r.headers = h - result = r.get_cookies() - assert len(result) == 1 - assert result['cookiename'] == ['cookievalue'] - - def test_get_cookies_double(self): - h = odict.ODictCaseless() - h["Cookie"] = [ - "cookiename=cookievalue;othercookiename=othercookievalue" - ] - r = tutils.treq() - r.headers = h - result = r.get_cookies() - assert len(result) == 2 - assert result['cookiename'] == ['cookievalue'] - assert result['othercookiename'] == ['othercookievalue'] +class TestResponse(object): + def test_headers_odict(self): + tutils.raises(AssertionError, semantics.Response, + (1, 1), + 200, + headers='foobar', + ) - def test_get_cookies_withequalsign(self): - h = odict.ODictCaseless() - h["Cookie"] = [ - "cookiename=coo=kievalue;othercookiename=othercookievalue" - ] - r = tutils.treq() - r.headers = h - result = r.get_cookies() - assert len(result) == 2 - assert result['cookiename'] == ['coo=kievalue'] - assert result['othercookiename'] == ['othercookievalue'] + resp = semantics.Response( + (1, 1), + 200, + ) + assert isinstance(resp.headers, odict.ODictCaseless) - def test_set_cookies(self): - h = odict.ODictCaseless() - h["Cookie"] = ["cookiename=cookievalue"] - r = tutils.treq() - r.headers = h - result = r.get_cookies() - result["cookiename"] = ["foo"] - r.set_cookies(result) - assert r.get_cookies()["cookiename"] == ["foo"] + def test_equal(self): + a = tutils.tresp() + b = tutils.tresp() + assert a == b + assert not a == 'foo' + assert not b == 'foo' + assert not 'foo' == a + assert not 'foo' == b -class TestResponse(object): def test_repr(self): r = tutils.tresp() assert "unknown content type" in repr(r) diff --git a/test/test_utils.py b/test/test_utils.py index 5e681eb6..aafa1571 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -101,3 +101,34 @@ def test_get_header_tokens(): assert utils.get_header_tokens(h, "foo") == ["bar", "voing"] h["foo"] = ["bar, voing", "oink"] assert utils.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] + + + + + +def test_multipartdecode(): + boundary = 'somefancyboundary' + headers = odict.ODict( + [('content-type', ('multipart/form-data; boundary=%s' % boundary))]) + content = "--{0}\n" \ + "Content-Disposition: form-data; name=\"field1\"\n\n" \ + "value1\n" \ + "--{0}\n" \ + "Content-Disposition: form-data; name=\"field2\"\n\n" \ + "value2\n" \ + "--{0}--".format(boundary) + + form = utils.multipartdecode(headers, content) + + assert len(form) == 2 + assert form[0] == ('field1', 'value1') + assert form[1] == ('field2', 'value2') + + +def test_parse_content_type(): + p = utils.parse_content_type + assert p("text/html") == ("text", "html", {}) + assert p("text") is None + + v = p("text/html; charset=UTF-8") + assert v == ('text', 'html', {'charset': 'UTF-8'}) diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 28dbb833..9fa98172 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -3,6 +3,7 @@ import os from nose.tools import raises from netlib import tcp, http, websockets, tutils +from netlib.http import status_codes from netlib.http.exceptions import * from netlib.http.http1 import HTTP1Protocol from .. import tservers @@ -38,7 +39,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): req = http1_protocol.read_request() key = self.protocol.check_client_handshake(req.headers) - preamble = http1_protocol.response_preamble(101) + preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) self.wfile.write(preamble + "\r\n") headers = self.protocol.server_handshake_headers(key) self.wfile.write(headers.format() + "\r\n") @@ -62,7 +63,7 @@ class WebSocketsClient(tcp.TCPClient): http1_protocol = HTTP1Protocol(self) - preamble = http1_protocol.request_preamble("GET", "/") + preamble = 'GET / HTTP/1.1' self.wfile.write(preamble + "\r\n") headers = self.protocol.client_handshake_headers() self.client_nonce = headers.get_first("sec-websocket-key") @@ -162,7 +163,7 @@ class BadHandshakeHandler(WebSocketsEchoHandler): client_hs = http1_protocol.read_request() self.protocol.check_client_handshake(client_hs.headers) - preamble = http1_protocol.response_preamble(101) + preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) self.wfile.write(preamble + "\r\n") headers = self.protocol.server_handshake_headers("malformed key") self.wfile.write(headers.format() + "\r\n") -- cgit v1.2.3 From 476badf45cd085d69b6162cd48983e3cd22cefcc Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 10 Aug 2015 20:36:47 +0200 Subject: cleanup imports --- netlib/http/authentication.py | 2 -- netlib/http/http1/protocol.py | 4 ---- netlib/http/semantics.py | 4 ---- netlib/websockets/frame.py | 5 ++--- netlib/websockets/protocol.py | 5 ++--- test/http/http1/test_protocol.py | 1 - test/http/test_semantics.py | 10 ++++------ test/test_utils.py | 1 - test/tservers.py | 3 ++- test/websockets/test_websockets.py | 4 +++- 10 files changed, 13 insertions(+), 26 deletions(-) diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index 9a227010..29b9eb3c 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -2,7 +2,6 @@ from __future__ import (absolute_import, print_function, division) from argparse import Action, ArgumentTypeError import binascii -from .. import http def parse_http_basic_auth(s): words = s.split() @@ -37,7 +36,6 @@ class NullProxyAuth(object): """ Clean up authentication headers, so they're not passed upstream. """ - pass def authenticate(self, headers_): """ diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 31e9cc85..c797e930 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -1,14 +1,10 @@ from __future__ import (absolute_import, print_function, division) -import binascii -import collections import string import sys -import urlparse import time from netlib import odict, utils, tcp, http from netlib.http import semantics -from .. import status_codes from ..exceptions import * class TCPHandler(object): diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 974fe6e6..15add957 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -1,8 +1,4 @@ from __future__ import (absolute_import, print_function, division) -import binascii -import collections -import string -import sys import urllib import urlparse diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index 49d8ee10..ad4ad0ee 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -1,12 +1,11 @@ from __future__ import absolute_import -import base64 -import hashlib import os import struct import io from .protocol import Masker -from netlib import utils, odict, tcp +from netlib import tcp +from netlib import utils DEFAULT = object() diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index 29b4db3d..8169309a 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -2,10 +2,9 @@ from __future__ import absolute_import import base64 import hashlib import os -import struct -import io -from netlib import utils, odict, tcp +from netlib import odict +from netlib import utils # Colleciton of utility functions that implement small portions of the RFC6455 # WebSockets Protocol Useful for building WebSocket clients and servers. diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index ff70b87d..af77c55f 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -1,6 +1,5 @@ import cStringIO import textwrap -import binascii from netlib import http, odict, tcp, tutils from netlib.http import semantics diff --git a/test/http/test_semantics.py b/test/http/test_semantics.py index 59364eae..7ef69dcf 100644 --- a/test/http/test_semantics.py +++ b/test/http/test_semantics.py @@ -1,13 +1,11 @@ -import cStringIO -import textwrap -import binascii import mock -from mock import MagicMock -from netlib import http, odict, tcp, tutils, utils +from netlib import http +from netlib import odict +from netlib import tutils +from netlib import utils from netlib.http import semantics from netlib.http.semantics import CONTENT_MISSING -from .. import tservers class TestProtocolMixin(object): @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response") diff --git a/test/test_utils.py b/test/test_utils.py index aafa1571..27fc5cc5 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,4 +1,3 @@ -import urlparse from netlib import utils, odict, tutils diff --git a/test/tservers.py b/test/tservers.py index 3f3ea8b4..682a9144 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -3,7 +3,8 @@ import threading import Queue import cStringIO import OpenSSL -from netlib import tcp, certutils, tutils +from netlib import tcp +from netlib import tutils class ServerThread(threading.Thread): diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 9fa98172..752f2c3e 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -2,7 +2,9 @@ import os from nose.tools import raises -from netlib import tcp, http, websockets, tutils +from netlib import tcp +from netlib import tutils +from netlib import websockets from netlib.http import status_codes from netlib.http.exceptions import * from netlib.http.http1 import HTTP1Protocol -- cgit v1.2.3 From ff27d65f08d00c312a162965c5b1db711aa8f6ed Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 10 Aug 2015 20:44:36 +0200 Subject: cleanup whitespace --- netlib/http/exceptions.py | 3 +++ netlib/http/http1/protocol.py | 50 ++++++++++++++++++++++++------------------- netlib/http/http2/frame.py | 2 +- netlib/http/http2/protocol.py | 17 +++++++++++---- netlib/http/semantics.py | 10 ++++----- netlib/tutils.py | 2 -- netlib/utils.py | 5 +++-- netlib/websockets/frame.py | 1 + netlib/websockets/protocol.py | 2 ++ test/test_encoding.py | 1 + test/test_socks.py | 6 +++++- test/test_utils.py | 6 ------ 12 files changed, 62 insertions(+), 43 deletions(-) diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py index 7cd26c12..987a7908 100644 --- a/netlib/http/exceptions.py +++ b/netlib/http/exceptions.py @@ -1,6 +1,8 @@ from netlib import odict + class HttpError(Exception): + def __init__(self, code, message): super(HttpError, self).__init__(message) self.code = code @@ -11,6 +13,7 @@ class HttpErrorConnClosed(HttpError): class HttpAuthenticationError(Exception): + def __init__(self, auth_headers=None): super(HttpAuthenticationError, self).__init__( "Proxy Authentication Required" diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index c797e930..8eeb7744 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -7,18 +7,25 @@ from netlib import odict, utils, tcp, http from netlib.http import semantics from ..exceptions import * + class TCPHandler(object): + def __init__(self, rfile, wfile=None): self.rfile = rfile self.wfile = wfile + class HTTP1Protocol(semantics.ProtocolMixin): def __init__(self, tcp_handler=None, rfile=None, wfile=None): self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) - - def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): + def read_request( + self, + include_body=True, + body_size_limit=None, + allow_empty=False, + ): """ Parse an HTTP request from a file stream @@ -125,8 +132,12 @@ class HTTP1Protocol(semantics.ProtocolMixin): timestamp_end, ) - - def read_response(self, request_method, body_size_limit, include_body=True): + def read_response( + self, + request_method, + body_size_limit, + include_body=True, + ): """ Returns an http.Response @@ -171,7 +182,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): # read separately body = None - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): # more accurate timestamp_start timestamp_start = self.tcp_handler.rfile.first_byte_timestamp @@ -191,7 +201,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): timestamp_end=timestamp_end, ) - def assemble_request(self, request): assert isinstance(request, semantics.Request) @@ -204,7 +213,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): headers = self._assemble_request_headers(request) return "%s\r\n%s\r\n%s" % (first_line, headers, request.body) - def assemble_response(self, response): assert isinstance(response, semantics.Response) @@ -217,7 +225,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): headers = self._assemble_response_headers(response) return "%s\r\n%s\r\n%s" % (first_line, headers, response.body) - def read_headers(self): """ Read a set of headers. @@ -262,7 +269,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): response_code, is_request, max_chunk_size=None - ): + ): """ Read an HTTP message body: headers: An ODictCaseless object @@ -317,9 +324,14 @@ class HTTP1Protocol(semantics.ProtocolMixin): "HTTP Body too large. Limit is %s," % limit ) - @classmethod - def expected_http_body_size(self, headers, is_request, request_method, response_code): + def expected_http_body_size( + self, + headers, + is_request, + request_method, + response_code, + ): """ Returns the expected body length: - a positive integer, if the size is known in advance @@ -372,7 +384,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): line = self.tcp_handler.rfile.readline() return line - def _read_chunked(self, limit, is_request): """ Read a chunked HTTP body. @@ -409,7 +420,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): if length == 0: return - @classmethod def _parse_http_protocol(self, line): """ @@ -429,7 +439,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None return major, minor - @classmethod def _parse_init(self, line): try: @@ -443,7 +452,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None return method, url, httpversion - @classmethod def _parse_init_connect(self, line): """ @@ -471,7 +479,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None return host, port, httpversion - @classmethod def _parse_init_proxy(self, line): v = self._parse_init(line) @@ -485,7 +492,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): scheme, host, port, path = parts return method, scheme, host, port, path, httpversion - @classmethod def _parse_init_http(self, line): """ @@ -501,7 +507,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None return method, url, httpversion - @classmethod def connection_close(self, httpversion, headers): """ @@ -521,7 +526,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): # be persistent return httpversion != (1, 1) - @classmethod def parse_response_line(self, line): parts = line.strip().split(" ", 2) @@ -536,7 +540,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None return (proto, code, msg) - @classmethod def _assemble_request_first_line(self, request): return request.legacy_first_line() @@ -557,7 +560,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return headers.format() - def _assemble_response_first_line(self, response): return 'HTTP/%s.%s %s %s' % ( response.httpversion[0], @@ -566,7 +568,11 @@ class HTTP1Protocol(semantics.ProtocolMixin): response.msg, ) - def _assemble_response_headers(self, response, preserve_transfer_encoding=False): + def _assemble_response_headers( + self, + response, + preserve_transfer_encoding=False, + ): headers = response.headers.copy() for k in response._headers_to_strip_off: del headers[k] diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py index f7e60471..aa1fbae4 100644 --- a/netlib/http/http2/frame.py +++ b/netlib/http/http2/frame.py @@ -117,7 +117,7 @@ class Frame(object): return "\n".join([ "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( - direction, self.__class__.__name__, self.length, self.flags, self.stream_id), + direction, self.__class__.__name__, self.length, self.flags, self.stream_id), self.payload_human_readable(), "===============================================================", ]) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index a1ca4a18..896b728b 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -9,6 +9,7 @@ from . import frame class TCPHandler(object): + def __init__(self, rfile, wfile=None): self.rfile = rfile self.wfile = wfile @@ -39,7 +40,6 @@ class HTTP2Protocol(semantics.ProtocolMixin): ALPN_PROTO_H2 = 'h2' - def __init__( self, tcp_handler=None, @@ -60,7 +60,12 @@ class HTTP2Protocol(semantics.ProtocolMixin): self.current_stream_id = None self.connection_preface_performed = False - def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): + def read_request( + self, + include_body=True, + body_size_limit=None, + allow_empty=False, + ): self.perform_connection_preface() timestamp_start = time.time() @@ -92,7 +97,12 @@ class HTTP2Protocol(semantics.ProtocolMixin): return request - def read_response(self, request_method='', body_size_limit=None, include_body=True): + def read_response( + self, + request_method='', + body_size_limit=None, + include_body=True, + ): self.perform_connection_preface() timestamp_start = time.time() @@ -123,7 +133,6 @@ class HTTP2Protocol(semantics.ProtocolMixin): return response - def assemble_request(self, request): assert isinstance(request, semantics.Request) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 15add957..d9dbb559 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -332,6 +332,7 @@ class Request(object): class EmptyRequest(Request): + def __init__(self): super(EmptyRequest, self).__init__( form_in="", @@ -343,7 +344,7 @@ class EmptyRequest(Request): httpversion=(0, 0), headers=odict.ODictCaseless(), body="", - ) + ) class Response(object): @@ -396,10 +397,9 @@ class Response(object): status_code=self.status_code, msg=self.msg, contenttype=self.headers.get_first( - "content-type", "unknown content type" - ), - size=size - ) + "content-type", + "unknown content type"), + size=size) def get_cookies(self): """ diff --git a/netlib/tutils.py b/netlib/tutils.py index 3c471d0d..7434c108 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -69,8 +69,6 @@ def raises(exc, obj, *args, **kwargs): test_data = utils.Data(__name__) - - def treq(content="content", scheme="http", host="address", port=22): """ @return: libmproxy.protocol.http.HTTPRequest diff --git a/netlib/utils.py b/netlib/utils.py index 2dfcafc6..31dcd622 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -119,6 +119,7 @@ def pretty_size(size): class Data(object): + def __init__(self, name): m = __import__(name) dirname, _ = os.path.split(m.__file__) @@ -137,8 +138,6 @@ class Data(object): return fullpath - - def is_valid_port(port): if not 0 <= port <= 65535: return False @@ -221,6 +220,7 @@ def hostport(scheme, host, port): else: return "%s:%s" % (host, port) + def unparse_url(scheme, host, port, path=""): """ Returns a URL string, constructed from the specified compnents. @@ -235,6 +235,7 @@ def urlencode(s): s = [tuple(i) for i in s] return urllib.urlencode(s, False) + def urldecode(s): """ Takes a urlencoded string and returns a list of (key, value) tuples. diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index ad4ad0ee..1c4a03b2 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -21,6 +21,7 @@ OPCODE = utils.BiDi( PONG=0x0a ) + class FrameHeader(object): def __init__( diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index 8169309a..6ce32eac 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -25,6 +25,7 @@ HEADER_WEBSOCKET_KEY = 'sec-websocket-key' HEADER_WEBSOCKET_ACCEPT = 'sec-websocket-accept' HEADER_WEBSOCKET_VERSION = 'sec-websocket-version' + class Masker(object): """ @@ -52,6 +53,7 @@ class Masker(object): self.offset += len(ret) return ret + class WebsocketsProtocol(object): def __init__(self): diff --git a/test/test_encoding.py b/test/test_encoding.py index faf718ae..612aea89 100644 --- a/test/test_encoding.py +++ b/test/test_encoding.py @@ -1,5 +1,6 @@ from netlib import encoding + def test_identity(): assert "string" == encoding.decode("identity", "string") assert "string" == encoding.encode("identity", "string") diff --git a/test/test_socks.py b/test/test_socks.py index 36fc5b3d..3d109f42 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -44,7 +44,11 @@ def test_client_greeting_assert_socks5(): assert False raw = tutils.treader("XX") - tutils.raises(socks.SocksError, socks.ClientGreeting.from_file, raw, fail_early=True) + tutils.raises( + socks.SocksError, + socks.ClientGreeting.from_file, + raw, + fail_early=True) def test_server_greeting(): diff --git a/test/test_utils.py b/test/test_utils.py index 27fc5cc5..89ce0f17 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -29,8 +29,6 @@ def test_pretty_size(): assert utils.pretty_size(1024 * 1024) == "1MB" - - def test_parse_url(): assert not utils.parse_url("") @@ -85,7 +83,6 @@ def test_urlencode(): assert utils.urlencode([('foo', 'bar')]) - def test_urldecode(): s = "one=two&three=four" assert len(utils.urldecode(s)) == 2 @@ -102,9 +99,6 @@ def test_get_header_tokens(): assert utils.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] - - - def test_multipartdecode(): boundary = 'somefancyboundary' headers = odict.ODict( -- cgit v1.2.3 From 6a30ad2ad236fa20d086e271ff962ebc907da027 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 10 Aug 2015 20:50:05 +0200 Subject: fix minor style offences --- netlib/http/http2/protocol.py | 10 +++++----- netlib/http/semantics.py | 12 ++++++------ test/test_utils.py | 1 - 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 896b728b..c2ad5edd 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -142,13 +142,13 @@ class HTTP2Protocol(semantics.ProtocolMixin): headers = request.headers.copy() - if not ':authority' in headers.keys(): + if ':authority' not in headers.keys(): headers.add(':authority', bytes(authority), prepend=True) - if not ':scheme' in headers.keys(): + if ':scheme' not in headers.keys(): headers.add(':scheme', bytes(request.scheme), prepend=True) - if not ':path' in headers.keys(): + if ':path' not in headers.keys(): headers.add(':path', bytes(request.path), prepend=True) - if not ':method' in headers.keys(): + if ':method' not in headers.keys(): headers.add(':method', bytes(request.method), prepend=True) headers = headers.items() @@ -167,7 +167,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): headers = response.headers.copy() - if not ':status' in headers.keys(): + if ':status' not in headers.keys(): headers.add(':status', bytes(str(response.status_code)), prepend=True) headers = headers.items() diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index d9dbb559..76213cd1 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -15,10 +15,10 @@ CONTENT_MISSING = 0 class ProtocolMixin(object): def read_request(self, *args, **kwargs): # pragma: no cover - raise NotImplemented + raise NotImplementedError def read_response(self, *args, **kwargs): # pragma: no cover - raise NotImplemented + raise NotImplementedError def assemble(self, message): if isinstance(message, Request): @@ -28,11 +28,11 @@ class ProtocolMixin(object): else: raise ValueError("HTTP message not supported.") - def assemble_request(self, request): # pragma: no cover - raise NotImplemented + def assemble_request(self, *args, **kwargs): # pragma: no cover + raise NotImplementedError - def assemble_response(self, response): # pragma: no cover - raise NotImplemented + def assemble_response(self, *args, **kwargs): # pragma: no cover + raise NotImplementedError class Request(object): diff --git a/test/test_utils.py b/test/test_utils.py index 89ce0f17..fc7174d6 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,4 +1,3 @@ - from netlib import utils, odict, tutils -- cgit v1.2.3 From b7e6e1c9b2c57270ee0c49af9235a2b119600056 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sat, 15 Aug 2015 17:49:59 +0200 Subject: add HTTP/1.1 ALPN version string --- netlib/http/http1/protocol.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 8eeb7744..dc33a8af 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -17,6 +17,8 @@ class TCPHandler(object): class HTTP1Protocol(semantics.ProtocolMixin): + ALPN_PROTO_HTTP1 = 'http/1.1' + def __init__(self, tcp_handler=None, rfile=None, wfile=None): self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) -- cgit v1.2.3 From 85cede47aa8f9ffd770ad2830084e53b04b4e77e Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 16 Aug 2015 11:41:34 +0200 Subject: allow direct ALPN callback method --- netlib/tcp.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index c355cfdd..b3171a1c 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -403,6 +403,7 @@ class _Connection(object): cipher_list=None, alpn_protos=None, alpn_select=None, + alpn_select_callback=None, ): """ Creates an SSL Context. @@ -457,7 +458,7 @@ class _Connection(object): if alpn_protos is not None: # advertise application layer protocols context.set_alpn_protos(alpn_protos) - elif alpn_select is not None: + elif alpn_select is not None and alpn_select_callback is None: # select application layer protocol def alpn_select_callback(conn_, options): if alpn_select in options: @@ -465,6 +466,10 @@ class _Connection(object): else: # pragma no cover return options[0] context.set_alpn_select_callback(alpn_select_callback) + elif alpn_select_callback is not None and alpn_select is None: + context.set_alpn_select_callback(alpn_select_callback) + elif alpn_select_callback is not None and alpn_select is not None: + raise NetLibError("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).") return context -- cgit v1.2.3 From 99e89a1efc9871e8956460d1e40cf8282f14babd Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 16 Aug 2015 21:47:26 +1200 Subject: Remove stray prints from test suite --- test/http/http1/test_protocol.py | 1 - test/http/http2/test_protocol.py | 1 - test/http/test_exceptions.py | 1 - 3 files changed, 3 deletions(-) diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index af77c55f..6704647f 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -485,7 +485,6 @@ class TestAssembleResponse(object): def test_simple(self): resp = tutils.tresp() b = HTTP1Protocol().assemble_response(resp) - print(b) assert b == match_http_string(""" HTTP/1.1 200 OK header_response: svalue diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 3044179f..0431de34 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -399,7 +399,6 @@ class TestAssembleRequest(object): req.stream_id = 0x42 bytes = HTTP2Protocol(self.c).assemble_request(req) assert len(bytes) == 1 - print(bytes[0].encode('hex')) assert bytes[0] == '00000d0105000000428284874188089d5c0b8170dc07'.decode('hex') def test_request_with_body(self): diff --git a/test/http/test_exceptions.py b/test/http/test_exceptions.py index 0131c7ef..d7c438f7 100644 --- a/test/http/test_exceptions.py +++ b/test/http/test_exceptions.py @@ -14,7 +14,6 @@ class TestHttpAuthenticationError: assert isinstance(x.headers, odict.ODictCaseless) assert x.code == 407 assert x.headers == headers - print(x.headers.keys()) assert "foo" in x.headers.keys() def test_header_conversion(self): -- cgit v1.2.3 From 3d306671251723a781b6e69c826bb94117f86188 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 17 Aug 2015 10:21:30 +1200 Subject: Bump netlib version - 0.13.1 is already out --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/version.py b/netlib/version.py index de42ace1..044fde2c 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 13, 1) +IVERSION = (0, 13, 2) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From c92dc1b8682ed15b68890f18c65b3f31122e9fa4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 15 Aug 2015 20:30:22 +0200 Subject: re-add form_out --- netlib/http/semantics.py | 12 ++++++++---- netlib/tcp.py | 2 ++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 76213cd1..5b7fb80f 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -59,6 +59,7 @@ class Request(object): body=None, timestamp_start=None, timestamp_end=None, + form_out=None ): if not headers: headers = odict.ODictCaseless() @@ -75,6 +76,7 @@ class Request(object): self.body = body self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end + self.form_out = form_out or form_in def __eq__(self, other): try: @@ -91,15 +93,17 @@ class Request(object): self.legacy_first_line()[:-9] ) - def legacy_first_line(self): - if self.form_in == "relative": + def legacy_first_line(self, form=None): + if form is None: + form = self.form_out + if form == "relative": return '%s %s HTTP/%s.%s' % ( self.method, self.path, self.httpversion[0], self.httpversion[1], ) - elif self.form_in == "authority": + elif form == "authority": return '%s %s:%s HTTP/%s.%s' % ( self.method, self.host, @@ -107,7 +111,7 @@ class Request(object): self.httpversion[0], self.httpversion[1], ) - elif self.form_in == "absolute": + elif form == "absolute": return '%s %s://%s:%s%s HTTP/%s.%s' % ( self.method, self.scheme, diff --git a/netlib/tcp.py b/netlib/tcp.py index b3171a1c..22cd0965 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -310,6 +310,8 @@ class Address(object): return str(self.address) def __eq__(self, other): + if not other: + return False other = Address.wrap(other) return (self.address, self.family) == (other.address, other.family) -- cgit v1.2.3 From 62416daa4a3776563556fb45ef9bd749fb44c334 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 25 Jul 2015 13:31:04 +0200 Subject: add Reader.peek() --- netlib/tcp.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/netlib/tcp.py b/netlib/tcp.py index 22cd0965..b05e84f5 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -265,6 +265,24 @@ class Reader(_FileLike): ) return result + def peek(self, length): + """ + Tries to peek into the underlying file object. + + Returns: + Up to the next N bytes if peeking is successful. + None, otherwise. + + Raises: + NetLibSSLError if there was an error with pyOpenSSL. + """ + if isinstance(self.o, SSL.Connection) or isinstance(self.o, socket._fileobject): + try: + return self.o._sock.recv(length, socket.MSG_PEEK) + except SSL.Error as e: + raise NetLibSSLError(str(e)) + + class Address(object): -- cgit v1.2.3 From 231656859fcf82cb1252d1aad8dbc0f77dfb8bba Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 16 Aug 2015 23:33:11 +0200 Subject: TCPClient: more sophisticated address handling --- netlib/http/semantics.py | 3 ++- netlib/tcp.py | 34 +++++++++++++++++++++++----------- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 5b7fb80f..836af550 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -397,7 +397,8 @@ class Response(object): size = utils.pretty_size(len(self.body)) else: size = "content missing" - return "".format( + # TODO: Remove "(unknown content type, content missing)" edge-case + return "".format( status_code=self.status_code, msg=self.msg, contenttype=self.headers.get_first( diff --git a/netlib/tcp.py b/netlib/tcp.py index b05e84f5..289618a7 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -283,7 +283,6 @@ class Reader(_FileLike): raise NetLibSSLError(str(e)) - class Address(object): """ @@ -498,6 +497,29 @@ class TCPClient(_Connection): rbufsize = -1 wbufsize = -1 + def __init__(self, address, source_address=None): + self.connection, self.rfile, self.wfile = None, None, None + self.address = address + self.source_address = Address.wrap( + source_address) if source_address else None + self.cert = None + self.ssl_established = False + self.ssl_verification_error = None + self.sni = None + + @property + def address(self): + return self.__address + + @address.setter + def address(self, address): + if self.connection: + raise RuntimeError("Cannot change server address after establishing connection") + if address: + self.__address = Address.wrap(address) + else: + self.__address = None + def close(self): # Make sure to close the real socket, not the SSL proxy. # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, @@ -507,16 +529,6 @@ class TCPClient(_Connection): else: close_socket(self.connection) - def __init__(self, address, source_address=None): - self.address = Address.wrap(address) - self.source_address = Address.wrap( - source_address) if source_address else None - self.connection, self.rfile, self.wfile = None, None, None - self.cert = None - self.ssl_established = False - self.ssl_verification_error = None - self.sni = None - def create_ssl_context(self, cert=None, alpn_protos=None, **sslctx_kwargs): context = self._create_ssl_context( alpn_protos=alpn_protos, -- cgit v1.2.3 From 67e6e5c7d8e26fda90d6c74e80c2432745ce5921 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Tue, 18 Aug 2015 21:13:46 +0200 Subject: temporarily disable pypy on travis --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 9fd4fbd9..9c781467 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,7 +4,7 @@ sudo: false python: - "2.7" - - pypy + # - pypy # disable until TravisCI ships a PyPy version which works with the latest CFFI matrix: include: -- cgit v1.2.3 From c903efcf5b34fa775d3e64623e54bd8e75b740cb Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Tue, 18 Aug 2015 21:17:11 +0200 Subject: temporarily disable pypy with new openssl on travis --- .travis.yml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/.travis.yml b/.travis.yml index 9c781467..3b49224e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -18,16 +18,16 @@ matrix: - debian-sid packages: - libssl-dev - - python: pypy - env: OPENSSL=1.0.2 - addons: - apt: - sources: - # Debian sid currently holds OpenSSL 1.0.2 - # change this with future releases! - - debian-sid - packages: - - libssl-dev + # - python: pypy + # env: OPENSSL=1.0.2 + # addons: + # apt: + # sources: + # # Debian sid currently holds OpenSSL 1.0.2 + # # change this with future releases! + # - debian-sid + # packages: + # - libssl-dev install: - "pip install --src . -r requirements.txt" -- cgit v1.2.3 From 12efa61e3af1b0ede4a803320b6f2a14b034aa5d Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Tue, 18 Aug 2015 21:22:27 +0200 Subject: fix request-target tests --- test/http/test_semantics.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/test/http/test_semantics.py b/test/http/test_semantics.py index 7ef69dcf..2a799044 100644 --- a/test/http/test_semantics.py +++ b/test/http/test_semantics.py @@ -69,17 +69,10 @@ class TestRequest(object): def test_legacy_first_line(self): req = tutils.treq() - req.form_in = 'relative' - assert req.legacy_first_line() == "GET /path HTTP/1.1" - - req.form_in = 'authority' - assert req.legacy_first_line() == "GET address:22 HTTP/1.1" - - req.form_in = 'absolute' - assert req.legacy_first_line() == "GET http://address:22/path HTTP/1.1" - - req.form_in = 'foobar' - tutils.raises(http.HttpError, req.legacy_first_line) + assert req.legacy_first_line('relative') == "GET /path HTTP/1.1" + assert req.legacy_first_line('authority') == "GET address:22 HTTP/1.1" + assert req.legacy_first_line('absolute') == "GET http://address:22/path HTTP/1.1" + tutils.raises(http.HttpError, req.legacy_first_line, 'foobar') def test_anticache(self): req = tutils.treq() -- cgit v1.2.3 From 0d384ac2a91898d4c8623290ae0fb3a60a35e514 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 17 Aug 2015 22:55:33 +0200 Subject: http2: add support for too large data frames --- netlib/http/http2/protocol.py | 19 +++++++++++-------- test/http/http2/test_protocol.py | 17 +++++++++++------ 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index c2ad5edd..cc8daba8 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -297,19 +297,22 @@ class HTTP2Protocol(semantics.ProtocolMixin): if body is None or len(body) == 0: return b'' - # TODO: implement max frame size checks and sending in chunks - # TODO: implement flow-control window - - frm = frame.DataFrame( + chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + chunks = range(0, len(body), chunk_size) + frms = [frame.DataFrame( state=self, - flags=frame.Frame.FLAG_END_STREAM, + flags=frame.Frame.FLAG_NO_FLAGS, stream_id=stream_id, - payload=body) + payload=body[i:i+chunk_size]) for i in chunks] + frms[-1].flags = frame.Frame.FLAG_END_STREAM + + # TODO: implement flow-control window if self.dump_frames: # pragma no cover - print(frm.human_readable(">>")) + for frm in frms: + print(frm.human_readable(">>")) - return [frm.to_bytes()] + return [frm.to_bytes() for frm in frms] def _receive_transmission(self, include_body=True): body_expected = True diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 0431de34..7f3fd2bd 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -252,20 +252,25 @@ class TestCreateHeaders(): class TestCreateBody(): c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = HTTP2Protocol(c) def test_create_body_empty(self): - bytes = self.protocol._create_body(b'', 1) + protocol = HTTP2Protocol(self.c) + bytes = protocol._create_body(b'', 1) assert b''.join(bytes) == ''.decode('hex') def test_create_body_single_frame(self): - bytes = self.protocol._create_body('foobar', 1) + protocol = HTTP2Protocol(self.c) + bytes = protocol._create_body('foobar', 1) assert b''.join(bytes) == '000006000100000001666f6f626172'.decode('hex') def test_create_body_multiple_frames(self): - pass - # bytes = self.protocol._create_body('foobar' * 3000, 1) - # TODO: add test for too large frames + protocol = HTTP2Protocol(self.c) + protocol.http2_settings[SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] = 5 + bytes = protocol._create_body('foobarmehm42', 1) + assert len(bytes) == 3 + assert bytes[0] == '000005000000000001666f6f6261'.decode('hex') + assert bytes[1] == '000005000000000001726d65686d'.decode('hex') + assert bytes[2] == '0000020001000000013432'.decode('hex') class TestReadRequest(tservers.ServerTestBase): -- cgit v1.2.3 From 07a1356e2f155d5b9e3a5f97bf90515ed9f1011f Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Tue, 18 Aug 2015 09:49:56 +0200 Subject: http2: add support for too large header frames --- netlib/http/http2/protocol.py | 29 +++++++++++++++++++---------- test/http/http2/test_protocol.py | 16 +++++++++++++++- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index cc8daba8..c27b4e9e 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -274,24 +274,33 @@ class HTTP2Protocol(semantics.ProtocolMixin): # to be more strict use: self._read_settings_ack(hide) def _create_headers(self, headers, stream_id, end_stream=True): - # TODO: implement max frame size checks and sending in chunks - - flags = frame.Frame.FLAG_END_HEADERS - if end_stream: - flags |= frame.Frame.FLAG_END_STREAM + def frame_cls(chunks): + for i in chunks: + if i == 0: + yield frame.HeadersFrame, i + else: + yield frame.ContinuationFrame, i header_block_fragment = self.encoder.encode(headers) - frm = frame.HeadersFrame( + chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + chunks = range(0, len(header_block_fragment), chunk_size) + frms = [frm_cls( state=self, - flags=flags, + flags=frame.Frame.FLAG_NO_FLAGS, stream_id=stream_id, - header_block_fragment=header_block_fragment) + header_block_fragment=header_block_fragment[i:i+chunk_size]) for frm_cls, i in frame_cls(chunks)] + + last_flags = frame.Frame.FLAG_END_HEADERS + if end_stream: + last_flags |= frame.Frame.FLAG_END_STREAM + frms[-1].flags = last_flags if self.dump_frames: # pragma no cover - print(frm.human_readable(">>")) + for frm in frms: + print(frm.human_readable(">>")) - return [frm.to_bytes()] + return [frm.to_bytes() for frm in frms] def _create_body(self, body, stream_id): if body is None or len(body) == 0: diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 7f3fd2bd..8c38bebd 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -247,7 +247,21 @@ class TestCreateHeaders(): '000014010400000001824488355217caf3a69a3f87408294e7838c767f'\ .decode('hex') - # TODO: add test for too large header_block_fragments + def test_create_headers_multiple_frames(self): + headers = [ + (b':method', b'GET'), + (b':path', b'/'), + (b':scheme', b'https'), + (b'foo', b'bar'), + (b'server', b'version')] + + protocol = HTTP2Protocol(self.c) + protocol.http2_settings[SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] = 8 + bytes = protocol._create_headers(headers, 1, end_stream=True) + assert len(bytes) == 3 + assert bytes[0] == '000008010000000001828487408294e783'.decode('hex') + assert bytes[1] == '0000080900000000018c767f7685ee5b10'.decode('hex') + assert bytes[2] == '00000209050000000163d5'.decode('hex') class TestCreateBody(): -- cgit v1.2.3 From 9686a77dcb640ace74f923c1f0f7f7307f79edfe Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 16 Aug 2015 20:02:18 +0200 Subject: http2: implement request target --- netlib/http/cookies.py | 3 +- netlib/http/http2/protocol.py | 39 ++++++++++++++++--- test/http/http2/test_protocol.py | 81 +++++++++++++++++++++++++++++++++++++--- 3 files changed, 110 insertions(+), 13 deletions(-) diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py index b77e3503..78b03a83 100644 --- a/netlib/http/cookies.py +++ b/netlib/http/cookies.py @@ -23,8 +23,7 @@ variants. Serialization follows RFC6265. http://tools.ietf.org/html/rfc2965 """ -# TODO -# - Disallow LHS-only Cookie values +# TODO: Disallow LHS-only Cookie values def _read_until(s, start, term): diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index c27b4e9e..eacbd2d8 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -80,13 +80,39 @@ class HTTP2Protocol(semantics.ProtocolMixin): timestamp_end = time.time() + authority = headers.get_first(':authority', '') + method = headers.get_first(':method', 'GET') + scheme = headers.get_first(':scheme', 'https') + path = headers.get_first(':path', '/') + host = None + port = None + + if path == '*' or path.startswith("/"): + form_in = "relative" + elif method == 'CONNECT': + form_in = "authority" + if ":" in authority: + host, port = authority.split(":", 1) + else: + host = authority + else: + form_in = "absolute" + # FIXME: verify if path or :host contains what we need + scheme, host, port, _ = utils.parse_url(path) + + if host is None: + host = 'localhost' + if port is None: + port = 80 if scheme == 'http' else 443 + port = int(port) + request = http.Request( - "relative", # TODO: use the correct value - headers.get_first(':method', 'GET'), - headers.get_first(':scheme', 'https'), - headers.get_first(':host', 'localhost'), - 443, # TODO: parse port number from host? - headers.get_first(':path', '/'), + form_in, + method, + scheme, + host, + port, + path, (2, 0), headers, body, @@ -324,6 +350,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): return [frm.to_bytes() for frm in frms] def _receive_transmission(self, include_body=True): + # TODO: include_body is not respected body_expected = True stream_id = 0 diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 8c38bebd..fc0fe487 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -289,7 +289,6 @@ class TestCreateBody(): class TestReadRequest(tservers.ServerTestBase): class handler(tcp.BaseHandler): - def handle(self): self.wfile.write( b'000003010400000001828487'.decode('hex')) @@ -306,11 +305,83 @@ class TestReadRequest(tservers.ServerTestBase): protocol = HTTP2Protocol(c, is_server=True) protocol.connection_preface_performed = True - resp = protocol.read_request() + req = protocol.read_request() - assert resp.stream_id - assert resp.headers.lst == [[u':method', u'GET'], [u':path', u'/'], [u':scheme', u'https']] - assert resp.body == b'foobar' + assert req.stream_id + assert req.headers.lst == [[u':method', u'GET'], [u':path', u'/'], [u':scheme', u'https']] + assert req.body == b'foobar' + + +class TestReadRequestRelative(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + self.wfile.write( + b'00000c0105000000014287d5af7e4d5a777f4481f9'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_asterisk_form_in(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = HTTP2Protocol(c, is_server=True) + protocol.connection_preface_performed = True + + "OPTIONS *" + req = protocol.read_request() + + assert req.form_in == "relative" + assert req.method == "OPTIONS" + assert req.path == "*" + + +class TestReadRequestAbsolute(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + self.wfile.write( + b'00001901050000000182448d9d29aee30c0e492c2a1170426366871c92585422e085'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_absolute_form_in(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = HTTP2Protocol(c, is_server=True) + protocol.connection_preface_performed = True + + req = protocol.read_request() + + assert req.form_in == "absolute" + assert req.scheme == "http" + assert req.host == "address" + assert req.port == 22 + + +class TestReadRequestConnect(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + self.wfile.write( + b'00001b0105000000014287bdab4e9c17b7ff44871c92585422e08541871c92585422e085'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_connect(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = HTTP2Protocol(c, is_server=True) + protocol.connection_preface_performed = True + + req = protocol.read_request() + + assert req.form_in == "authority" + assert req.method == "CONNECT" + assert req.host == "address" + assert req.port == 22 class TestReadResponse(tservers.ServerTestBase): -- cgit v1.2.3 From 6810fba54ef9c885215d5ff02534b93bb6868b2e Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 19 Aug 2015 16:05:42 +0200 Subject: add ssl peek polyfill --- netlib/tcp.py | 20 ++++++++++++++++++-- setup.py | 2 +- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 289618a7..c6638177 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -271,16 +271,32 @@ class Reader(_FileLike): Returns: Up to the next N bytes if peeking is successful. - None, otherwise. Raises: + NetLibError if there was an error with the socket NetLibSSLError if there was an error with pyOpenSSL. + NotImplementedError if the underlying file object is not a (pyOpenSSL) socket """ - if isinstance(self.o, SSL.Connection) or isinstance(self.o, socket._fileobject): + if isinstance(self.o, socket._fileobject): try: return self.o._sock.recv(length, socket.MSG_PEEK) + except socket.error as e: + raise NetLibError(str(e)) + elif isinstance(self.o, SSL.Connection): + try: + if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15): + return self.o.recv(length, socket.MSG_PEEK) + else: + # Polyfill for pyOpenSSL <= 0.15.1 + # Taken from https://github.com/pyca/pyopenssl/commit/1d95dea7fea03c7c0df345a5ea30c12d8a0378d2 + buf = SSL._ffi.new("char[]", length) + result = SSL._lib.SSL_peek(self.o._ssl, buf, length) + self.o._raise_ssl_error(self.o._ssl, result) + return SSL._ffi.buffer(buf, result)[:] except SSL.Error as e: raise NetLibSSLError(str(e)) + else: + raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") class Address(object): diff --git a/setup.py b/setup.py index d51977ee..a4da6e69 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ setup( install_requires=[ "pyasn1>=0.1.7", "pyOpenSSL>=0.15.1", - "cryptography>=0.9", + "cryptography>=1.0", "passlib>=1.6.2", "hpack>=1.0.1", "certifi" -- cgit v1.2.3 From 9920de1e153d4a85bbc4fa1dfd8fe5db45d56ab3 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 19 Aug 2015 16:06:33 +0200 Subject: tcp._Connection: clean up code, fix inheritance --- netlib/tcp.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index c6638177..a0e2ab5e 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -399,6 +399,22 @@ def close_socket(sock): class _Connection(object): + rbufsize = -1 + wbufsize = -1 + + def __init__(self, connection): + if connection: + self.connection = connection + self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) + else: + self.connection = None + self.rfile = None + self.wfile = None + + self.ssl_established = False + self.finished = False + def get_current_cipher(self): if not self.ssl_established: return None @@ -510,16 +526,13 @@ class _Connection(object): class TCPClient(_Connection): - rbufsize = -1 - wbufsize = -1 def __init__(self, address, source_address=None): - self.connection, self.rfile, self.wfile = None, None, None + super(TCPClient, self).__init__(None) self.address = address self.source_address = Address.wrap( source_address) if source_address else None self.cert = None - self.ssl_established = False self.ssl_verification_error = None self.sni = None @@ -627,20 +640,12 @@ class BaseHandler(_Connection): """ The instantiator is expected to call the handle() and finish() methods. - """ - rbufsize = -1 - wbufsize = -1 def __init__(self, connection, address, server): - self.connection = connection + super(BaseHandler, self).__init__(connection) self.address = Address.wrap(address) self.server = server - self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) - self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) - - self.finished = False - self.ssl_established = False self.clientcert = None def create_ssl_context(self, -- cgit v1.2.3 From 1025c15242b1f9324bf17ceb53224c84e026b3dc Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 20 Aug 2015 09:54:45 +0200 Subject: fix typo --- netlib/http/http2/frame.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py index aa1fbae4..ad00a59a 100644 --- a/netlib/http/http2/frame.py +++ b/netlib/http/http2/frame.py @@ -569,7 +569,7 @@ class WindowUpdateFrame(Frame): def payload_bytes(self): if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: raise ValueError( - 'Window Szie Increment MUST be greater than 0 and less than 2^31.') + 'Window Size Increment MUST be greater than 0 and less than 2^31.') return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) -- cgit v1.2.3 From e20d4e5c027ad7000f0d997ffb327817ef0dd557 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 19 Aug 2015 21:09:15 +0200 Subject: http2: add callback to handle unexpected frames --- netlib/http/http2/protocol.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index eacbd2d8..aa52bd71 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -49,12 +49,14 @@ class HTTP2Protocol(semantics.ProtocolMixin): dump_frames=False, encoder=None, decoder=None, + unhandled_frame_cb=None, ): self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) self.is_server = is_server self.dump_frames = dump_frames self.encoder = encoder or Encoder() self.decoder = decoder or Decoder() + self.unhandled_frame_cb = unhandled_frame_cb self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() self.current_stream_id = None @@ -258,11 +260,17 @@ class HTTP2Protocol(semantics.ProtocolMixin): "HTTP2Protocol can not handle unknown ALP: %s" % alp) return True + def _handle_unexpected_frame(self, frm): + if self.unhandled_frame_cb is not None: + self.unhandled_frame_cb(frm) + def _receive_settings(self, hide=False): while True: frm = self.read_frame(hide) if isinstance(frm, frame.SettingsFrame): break + else: + self._handle_unexpected_frame(frm) def _read_settings_ack(self, hide=False): # pragma no cover while True: @@ -271,6 +279,8 @@ class HTTP2Protocol(semantics.ProtocolMixin): assert frm.flags & frame.Frame.FLAG_ACK assert len(frm.settings) == 0 break + else: + self._handle_unexpected_frame(frm) def _next_stream_id(self): if self.current_stream_id is None: @@ -367,6 +377,8 @@ class HTTP2Protocol(semantics.ProtocolMixin): body_expected = False if frm.flags & frame.Frame.FLAG_END_HEADERS: break + else: + self._handle_unexpected_frame(frm) while body_expected: frm = self.read_frame() @@ -374,6 +386,9 @@ class HTTP2Protocol(semantics.ProtocolMixin): body += frm.payload if frm.flags & frame.Frame.FLAG_END_STREAM: break + else: + self._handle_unexpected_frame(frm) + # TODO: implement window update & flow headers = odict.ODictCaseless() -- cgit v1.2.3 From eb343055185fabc892a590c6220b125283036b4e Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 20 Aug 2015 10:21:22 +0200 Subject: http2: fix frame length field --- netlib/http/http2/frame.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py index ad00a59a..24e6510a 100644 --- a/netlib/http/http2/frame.py +++ b/netlib/http/http2/frame.py @@ -98,7 +98,7 @@ class Frame(object): self._check_frame_size(self.length, self.state) - b = struct.pack('!HB', self.length & 0xFFFF00, self.length & 0x0000FF) + b = struct.pack('!HB', (self.length & 0xFFFF00) >> 8, self.length & 0x0000FF) b += struct.pack('!B', self.TYPE) b += struct.pack('!B', self.flags) b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) -- cgit v1.2.3 From 94b7beae2a818ac873fb63991ab5237de1c104dd Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 20 Aug 2015 10:21:38 +0200 Subject: http2: implement basic flow control updates --- netlib/http/http2/protocol.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index aa52bd71..bf0b364f 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -251,6 +251,9 @@ class HTTP2Protocol(semantics.ProtocolMixin): if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: self._apply_settings(frm.settings, hide) + if isinstance(frm, frame.DataFrame) and frm.length > 0: + self._update_flow_control_window(frm.stream_id, frm.length) + return frm def check_alpn(self): @@ -309,6 +312,12 @@ class HTTP2Protocol(semantics.ProtocolMixin): # be liberal in what we expect from the other end # to be more strict use: self._read_settings_ack(hide) + def _update_flow_control_window(self, stream_id, increment): + frm = frame.WindowUpdateFrame(stream_id=0, window_size_increment=increment) + self.send_frame(frm) + frm = frame.WindowUpdateFrame(stream_id=stream_id, window_size_increment=increment) + self.send_frame(frm) + def _create_headers(self, headers, stream_id, end_stream=True): def frame_cls(chunks): for i in chunks: @@ -351,8 +360,6 @@ class HTTP2Protocol(semantics.ProtocolMixin): payload=body[i:i+chunk_size]) for i in chunks] frms[-1].flags = frame.Frame.FLAG_END_STREAM - # TODO: implement flow-control window - if self.dump_frames: # pragma no cover for frm in frms: print(frm.human_readable(">>")) @@ -369,8 +376,10 @@ class HTTP2Protocol(semantics.ProtocolMixin): while True: frm = self.read_frame() - if isinstance(frm, frame.HeadersFrame)\ - or isinstance(frm, frame.ContinuationFrame): + if ( + (isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame)) and + (stream_id == 0 or frm.stream_id == stream_id) + ): stream_id = frm.stream_id header_block_fragment += frm.header_block_fragment if frm.flags & frame.Frame.FLAG_END_STREAM: @@ -382,15 +391,13 @@ class HTTP2Protocol(semantics.ProtocolMixin): while body_expected: frm = self.read_frame() - if isinstance(frm, frame.DataFrame): + if isinstance(frm, frame.DataFrame) and frm.stream_id == stream_id: body += frm.payload if frm.flags & frame.Frame.FLAG_END_STREAM: break else: self._handle_unexpected_frame(frm) - # TODO: implement window update & flow - headers = odict.ODictCaseless() for header, value in self.decoder.decode(header_block_fragment): headers.add(header, value) -- cgit v1.2.3 From 16f697f68a7f94375bd1435f5eec6e00911b7019 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 20 Aug 2015 10:26:43 +0200 Subject: http2: disable features we do not support yet --- netlib/http/http2/protocol.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index bf0b364f..cf46a130 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -225,7 +225,11 @@ class HTTP2Protocol(semantics.ProtocolMixin): magic = self.tcp_handler.rfile.safe_read(magic_length) assert magic == self.CLIENT_CONNECTION_PREFACE - self.send_frame(frame.SettingsFrame(state=self), hide=True) + frm = frame.SettingsFrame(state=self, settings={ + frame.SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 0, + frame.SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 1, + }) + self.send_frame(frm, hide=True) self._receive_settings(hide=True) def perform_client_connection_preface(self, force=False): -- cgit v1.2.3 From 53f2582313ce5e8d1c875bea8b3f1a270db35b5b Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 20 Aug 2015 20:36:51 +0200 Subject: http2: fix unhandled settings frame --- netlib/http/http2/protocol.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index cf46a130..66ce19c8 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -239,7 +239,8 @@ class HTTP2Protocol(semantics.ProtocolMixin): self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) self.send_frame(frame.SettingsFrame(state=self), hide=True) - self._receive_settings(hide=True) + self._receive_settings(hide=True) # server announces own settings + self._receive_settings(hide=True) # server acks my settings def send_frame(self, frm, hide=False): raw_bytes = frm.to_bytes() @@ -279,16 +280,6 @@ class HTTP2Protocol(semantics.ProtocolMixin): else: self._handle_unexpected_frame(frm) - def _read_settings_ack(self, hide=False): # pragma no cover - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - assert frm.flags & frame.Frame.FLAG_ACK - assert len(frm.settings) == 0 - break - else: - self._handle_unexpected_frame(frm) - def _next_stream_id(self): if self.current_stream_id is None: if self.is_server: @@ -313,9 +304,6 @@ class HTTP2Protocol(semantics.ProtocolMixin): flags=frame.Frame.FLAG_ACK) self.send_frame(frm, hide) - # be liberal in what we expect from the other end - # to be more strict use: self._read_settings_ack(hide) - def _update_flow_control_window(self, stream_id, increment): frm = frame.WindowUpdateFrame(stream_id=0, window_size_increment=increment) self.send_frame(frm) -- cgit v1.2.3 From 00ed982ea0c802f980732f846955264935d65689 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 20 Aug 2015 20:44:58 +0200 Subject: cleanup --- test/http/http2/test_protocol.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index fc0fe487..fb1f460e 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -328,7 +328,6 @@ class TestReadRequestRelative(tservers.ServerTestBase): protocol = HTTP2Protocol(c, is_server=True) protocol.connection_preface_performed = True - "OPTIONS *" req = protocol.read_request() assert req.form_in == "relative" -- cgit v1.2.3 From 6fc2ff94694d70426663209e2ded977d9e0ecd3c Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 21 Aug 2015 09:18:14 +0200 Subject: http2: fix tests --- test/http/http2/test_protocol.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index fb1f460e..726d8e2e 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -127,7 +127,7 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase): protocol.perform_server_connection_preface() assert protocol.connection_preface_performed - tutils.raises(tcp.NetLibIncomplete, protocol.perform_server_connection_preface, force=True) + tutils.raises(tcp.NetLibDisconnect, protocol.perform_server_connection_preface, force=True) class TestPerformClientConnectionPreface(tservers.ServerTestBase): @@ -194,12 +194,12 @@ class TestServerStreamIds(): class TestApplySettings(tservers.ServerTestBase): class handler(tcp.BaseHandler): - def handle(self): # check settings acknowledgement assert self.rfile.read(9) == '000000040100000000'.decode('hex') self.wfile.write("OK") self.wfile.flush() + self.rfile.safe_read(9) # just to keep the connection alive a bit longer ssl = True @@ -295,6 +295,7 @@ class TestReadRequest(tservers.ServerTestBase): self.wfile.write( b'000006000100000001666f6f626172'.decode('hex')) self.wfile.flush() + self.rfile.safe_read(9) # just to keep the connection alive a bit longer ssl = True @@ -385,13 +386,13 @@ class TestReadRequestConnect(tservers.ServerTestBase): class TestReadResponse(tservers.ServerTestBase): class handler(tcp.BaseHandler): - def handle(self): self.wfile.write( b'00000801040000000188628594e78c767f'.decode('hex')) self.wfile.write( b'000006000100000001666f6f626172'.decode('hex')) self.wfile.flush() + self.rfile.safe_read(9) # just to keep the connection alive a bit longer ssl = True -- cgit v1.2.3 From cd9701050f58f90c757a34f7e4e6b5711700d649 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 21 Aug 2015 10:03:57 +0200 Subject: read_response depends on request for stream_id --- netlib/http/http1/protocol.py | 4 ++-- netlib/http/http2/protocol.py | 18 +++++++++++------- netlib/http/semantics.py | 34 +++++++++++++++++++++++---------- test/http/http1/test_protocol.py | 5 +++-- test/http/http2/test_protocol.py | 39 +++++++++++++++----------------------- test/websockets/test_websockets.py | 6 ++---- 6 files changed, 57 insertions(+), 49 deletions(-) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index dc33a8af..107a48d1 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -136,7 +136,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): def read_response( self, - request_method, + request, body_size_limit, include_body=True, ): @@ -175,7 +175,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): body = self.read_http_body( headers, body_size_limit, - request_method, + request.method, code, False ) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 66ce19c8..e032c2a0 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -74,7 +74,9 @@ class HTTP2Protocol(semantics.ProtocolMixin): if hasattr(self.tcp_handler.rfile, "reset_timestamps"): self.tcp_handler.rfile.reset_timestamps() - stream_id, headers, body = self._receive_transmission(include_body) + stream_id, headers, body = self._receive_transmission( + include_body=include_body, + ) if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): # more accurate timestamp_start @@ -127,7 +129,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): def read_response( self, - request_method='', + request='', body_size_limit=None, include_body=True, ): @@ -137,7 +139,10 @@ class HTTP2Protocol(semantics.ProtocolMixin): if hasattr(self.tcp_handler.rfile, "reset_timestamps"): self.tcp_handler.rfile.reset_timestamps() - stream_id, headers, body = self._receive_transmission(include_body) + stream_id, headers, body = self._receive_transmission( + stream_id=request.stream_id, + include_body=include_body, + ) if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): # more accurate timestamp_start @@ -145,7 +150,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): if include_body: timestamp_end = time.time() - else: + else: # pragma: no cover timestamp_end = None response = http.Response( @@ -358,11 +363,10 @@ class HTTP2Protocol(semantics.ProtocolMixin): return [frm.to_bytes() for frm in frms] - def _receive_transmission(self, include_body=True): + def _receive_transmission(self, stream_id=None, include_body=True): # TODO: include_body is not respected body_expected = True - stream_id = 0 header_block_fragment = b'' body = b'' @@ -370,7 +374,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): frm = self.read_frame() if ( (isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame)) and - (stream_id == 0 or frm.stream_id == stream_id) + (stream_id is None or frm.stream_id == stream_id) ): stream_id = frm.stream_id header_block_fragment += frm.header_block_fragment diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 836af550..e388a344 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -337,18 +337,32 @@ class Request(object): class EmptyRequest(Request): - def __init__(self): + def __init__( + self, + form_in="", + method="", + scheme="", + host="", + port="", + path="", + httpversion=None, + headers=None, + body="", + stream_id=None + ): super(EmptyRequest, self).__init__( - form_in="", - method="", - scheme="", - host="", - port="", - path="", - httpversion=(0, 0), - headers=odict.ODictCaseless(), - body="", + form_in=form_in, + method=method, + scheme=scheme, + host=host, + port=port, + path=path, + httpversion=(httpversion or (0, 0)), + headers=(headers or odict.ODictCaseless()), + body=body, ) + if stream_id: + self.stream_id = stream_id class Response(object): diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index 6704647f..31bf7dab 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -376,8 +376,9 @@ class TestReadRequest(object): class TestReadResponse(object): def tst(self, data, method, body_size_limit, include_body=True): data = textwrap.dedent(data) + request = http.EmptyRequest(method=method) return mock_protocol(data).read_response( - method, body_size_limit, include_body=include_body + request, body_size_limit, include_body=include_body ) def test_errors(self): @@ -457,7 +458,7 @@ class TestReadResponseNoContentLength(tservers.ServerTestBase): def test_no_content_length(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - resp = HTTP1Protocol(c).read_response("GET", None) + resp = HTTP1Protocol(c).read_response(http.EmptyRequest(method="GET"), None) assert resp.body == "bar\r\n\r\n" diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 726d8e2e..92fa109c 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -365,6 +365,8 @@ class TestReadRequestConnect(tservers.ServerTestBase): def handle(self): self.wfile.write( b'00001b0105000000014287bdab4e9c17b7ff44871c92585422e08541871c92585422e085'.decode('hex')) + self.wfile.write( + b'00001d0105000000014287bdab4e9c17b7ff44882f91d35d055c87a741882f91d35d055c87a7'.decode('hex')) self.wfile.flush() ssl = True @@ -377,20 +379,25 @@ class TestReadRequestConnect(tservers.ServerTestBase): protocol.connection_preface_performed = True req = protocol.read_request() - assert req.form_in == "authority" assert req.method == "CONNECT" assert req.host == "address" assert req.port == 22 + req = protocol.read_request() + assert req.form_in == "authority" + assert req.method == "CONNECT" + assert req.host == "example.com" + assert req.port == 443 + class TestReadResponse(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): self.wfile.write( - b'00000801040000000188628594e78c767f'.decode('hex')) + b'00000801040000002a88628594e78c767f'.decode('hex')) self.wfile.write( - b'000006000100000001666f6f626172'.decode('hex')) + b'00000600010000002a666f6f626172'.decode('hex')) self.wfile.flush() self.rfile.safe_read(9) # just to keep the connection alive a bit longer @@ -403,8 +410,9 @@ class TestReadResponse(tservers.ServerTestBase): protocol = HTTP2Protocol(c) protocol.connection_preface_performed = True - resp = protocol.read_response() + resp = protocol.read_response(http.EmptyRequest(stream_id=42)) + assert resp.stream_id == 42 assert resp.httpversion == (2, 0) assert resp.status_code == 200 assert resp.msg == "" @@ -412,29 +420,12 @@ class TestReadResponse(tservers.ServerTestBase): assert resp.body == b'foobar' assert resp.timestamp_end - def test_read_response_no_body(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = HTTP2Protocol(c) - protocol.connection_preface_performed = True - - resp = protocol.read_response(include_body=False) - - assert resp.httpversion == (2, 0) - assert resp.status_code == 200 - assert resp.msg == "" - assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] - assert resp.body == b'foobar' # TODO: this should be true: assert resp.body == http.CONTENT_MISSING - assert not resp.timestamp_end - class TestReadEmptyResponse(tservers.ServerTestBase): class handler(tcp.BaseHandler): - def handle(self): self.wfile.write( - b'00000801050000000188628594e78c767f'.decode('hex')) + b'00000801050000002a88628594e78c767f'.decode('hex')) self.wfile.flush() ssl = True @@ -446,9 +437,9 @@ class TestReadEmptyResponse(tservers.ServerTestBase): protocol = HTTP2Protocol(c) protocol.connection_preface_performed = True - resp = protocol.read_response() + resp = protocol.read_response(http.EmptyRequest(stream_id=42)) - assert resp.stream_id + assert resp.stream_id == 42 assert resp.httpversion == (2, 0) assert resp.status_code == 200 assert resp.msg == "" diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 752f2c3e..5f27c128 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -2,9 +2,7 @@ import os from nose.tools import raises -from netlib import tcp -from netlib import tutils -from netlib import websockets +from netlib import tcp, tutils, websockets, http from netlib.http import status_codes from netlib.http.exceptions import * from netlib.http.http1 import HTTP1Protocol @@ -72,7 +70,7 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.write(headers.format() + "\r\n") self.wfile.flush() - resp = http1_protocol.read_response("get", None) + resp = http1_protocol.read_response(http.EmptyRequest(method="GET"), None) server_nonce = self.protocol.check_server_handshake(resp.headers) if not server_nonce == self.protocol.create_server_nonce( -- cgit v1.2.3 From 622665952ca072a6276917c252758bbe19091a0d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 24 Aug 2015 16:52:32 +0200 Subject: minor stylistic fixes --- netlib/http/http2/protocol.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index e032c2a0..1d6e0168 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -123,6 +123,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): timestamp_start, timestamp_end, ) + # FIXME: We should not do this. request.stream_id = stream_id return request @@ -150,7 +151,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): if include_body: timestamp_end = time.time() - else: # pragma: no cover + else: timestamp_end = None response = http.Response( @@ -274,7 +275,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): return True def _handle_unexpected_frame(self, frm): - if self.unhandled_frame_cb is not None: + if self.unhandled_frame_cb: self.unhandled_frame_cb(frm) def _receive_settings(self, hide=False): @@ -364,7 +365,9 @@ class HTTP2Protocol(semantics.ProtocolMixin): return [frm.to_bytes() for frm in frms] def _receive_transmission(self, stream_id=None, include_body=True): - # TODO: include_body is not respected + if not include_body: + raise NotImplementedError() + body_expected = True header_block_fragment = b'' -- cgit v1.2.3 From 21858995aee48c67430c9b6f3965d897b27cd734 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 24 Aug 2015 18:16:34 +0200 Subject: request -> request_method --- netlib/http/http1/protocol.py | 6 +++--- netlib/http/http2/protocol.py | 11 +++++++++-- netlib/http/semantics.py | 9 +++------ test/http/http1/test_protocol.py | 5 ++--- test/http/http2/test_protocol.py | 5 ++--- test/websockets/test_websockets.py | 2 +- 6 files changed, 20 insertions(+), 18 deletions(-) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 107a48d1..6b4489fb 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -136,8 +136,8 @@ class HTTP1Protocol(semantics.ProtocolMixin): def read_response( self, - request, - body_size_limit, + request_method, + body_size_limit=None, include_body=True, ): """ @@ -175,7 +175,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): body = self.read_http_body( headers, body_size_limit, - request.method, + request_method, code, False ) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 1d6e0168..b6a147d4 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -68,6 +68,9 @@ class HTTP2Protocol(semantics.ProtocolMixin): body_size_limit=None, allow_empty=False, ): + if body_size_limit is not None: + raise NotImplementedError() + self.perform_connection_preface() timestamp_start = time.time() @@ -130,10 +133,14 @@ class HTTP2Protocol(semantics.ProtocolMixin): def read_response( self, - request='', + request_method='', body_size_limit=None, include_body=True, + stream_id=None, ): + if body_size_limit is not None: + raise NotImplementedError() + self.perform_connection_preface() timestamp_start = time.time() @@ -141,7 +148,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): self.tcp_handler.rfile.reset_timestamps() stream_id, headers, body = self._receive_transmission( - stream_id=request.stream_id, + stream_id=stream_id, include_body=include_body, ) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index e388a344..2b960483 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -345,10 +345,9 @@ class EmptyRequest(Request): host="", port="", path="", - httpversion=None, + httpversion=(0, 0), headers=None, - body="", - stream_id=None + body="" ): super(EmptyRequest, self).__init__( form_in=form_in, @@ -357,12 +356,10 @@ class EmptyRequest(Request): host=host, port=port, path=path, - httpversion=(httpversion or (0, 0)), + httpversion=httpversion, headers=(headers or odict.ODictCaseless()), body=body, ) - if stream_id: - self.stream_id = stream_id class Response(object): diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index 31bf7dab..6704647f 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -376,9 +376,8 @@ class TestReadRequest(object): class TestReadResponse(object): def tst(self, data, method, body_size_limit, include_body=True): data = textwrap.dedent(data) - request = http.EmptyRequest(method=method) return mock_protocol(data).read_response( - request, body_size_limit, include_body=include_body + method, body_size_limit, include_body=include_body ) def test_errors(self): @@ -458,7 +457,7 @@ class TestReadResponseNoContentLength(tservers.ServerTestBase): def test_no_content_length(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - resp = HTTP1Protocol(c).read_response(http.EmptyRequest(method="GET"), None) + resp = HTTP1Protocol(c).read_response("GET", None) assert resp.body == "bar\r\n\r\n" diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 92fa109c..8810894f 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -410,9 +410,8 @@ class TestReadResponse(tservers.ServerTestBase): protocol = HTTP2Protocol(c) protocol.connection_preface_performed = True - resp = protocol.read_response(http.EmptyRequest(stream_id=42)) + resp = protocol.read_response(stream_id=42) - assert resp.stream_id == 42 assert resp.httpversion == (2, 0) assert resp.status_code == 200 assert resp.msg == "" @@ -437,7 +436,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): protocol = HTTP2Protocol(c) protocol.connection_preface_performed = True - resp = protocol.read_response(http.EmptyRequest(stream_id=42)) + resp = protocol.read_response(stream_id=42) assert resp.stream_id == 42 assert resp.httpversion == (2, 0) diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 5f27c128..be87b20a 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -70,7 +70,7 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.write(headers.format() + "\r\n") self.wfile.flush() - resp = http1_protocol.read_response(http.EmptyRequest(method="GET"), None) + resp = http1_protocol.read_response("GET", None) server_nonce = self.protocol.check_server_handshake(resp.headers) if not server_nonce == self.protocol.create_server_nonce( -- cgit v1.2.3 From de0ced73f8e14aec8f94ea93c0ba0165026e09fc Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 25 Aug 2015 18:33:55 +0200 Subject: fix error messages --- netlib/tcp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index a0e2ab5e..3a094d9a 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -281,7 +281,7 @@ class Reader(_FileLike): try: return self.o._sock.recv(length, socket.MSG_PEEK) except socket.error as e: - raise NetLibError(str(e)) + raise NetLibError(repr(e)) elif isinstance(self.o, SSL.Connection): try: if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15): @@ -294,7 +294,7 @@ class Reader(_FileLike): self.o._raise_ssl_error(self.o._ssl, result) return SSL._ffi.buffer(buf, result)[:] except SSL.Error as e: - raise NetLibSSLError(str(e)) + raise NetLibSSLError(repr(e)) else: raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") -- cgit v1.2.3 From 3e3b59aa71a596fcddd14e72612067923a0d9b21 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 26 Aug 2015 20:58:00 +0200 Subject: http2: fix priority stream dependency check --- netlib/http/http2/frame.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py index 24e6510a..b36b3adf 100644 --- a/netlib/http/http2/frame.py +++ b/netlib/http/http2/frame.py @@ -290,9 +290,6 @@ class PriorityFrame(Frame): raise ValueError( 'PRIORITY frames MUST be associated with a stream.') - if self.stream_dependency == 0x0: - raise ValueError('stream dependency is invalid.') - return struct.pack( '!LB', (int( -- cgit v1.2.3 From daf512ce936268b7b1095ca2991b69157f79c122 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 26 Aug 2015 21:04:13 +0200 Subject: http2: fix tests --- test/http/http2/test_frames.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/test/http/http2/test_frames.py b/test/http/http2/test_frames.py index 077f5bc2..5d5cb0ba 100644 --- a/test/http/http2/test_frames.py +++ b/test/http/http2/test_frames.py @@ -246,9 +246,9 @@ def test_priority_frame_to_bytes(): flags=(Frame.FLAG_NO_FLAGS), stream_id=0x1234567, exclusive=True, - stream_dependency=0x7654321, + stream_dependency=0x0, weight=42) - assert_equal(f.to_bytes().encode('hex'), '000005020001234567876543212a') + assert_equal(f.to_bytes().encode('hex'), '000005020001234567800000002a') f = PriorityFrame( length=5, @@ -266,13 +266,6 @@ def test_priority_frame_to_bytes(): stream_dependency=0x1234567) tutils.raises(ValueError, f.to_bytes) - f = PriorityFrame( - length=5, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - stream_dependency=0x0) - tutils.raises(ValueError, f.to_bytes) - def test_priority_frame_from_bytes(): f = Frame.from_file(hex_to_file('000005020001234567876543212a')) -- cgit v1.2.3 From 982d8000c420937da532d1c584e3ca7a86c5f3e8 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 28 Aug 2015 17:35:48 +0200 Subject: wip --- netlib/http/__init__.py | 1 - netlib/http/http2/protocol.py | 4 +--- netlib/tcp.py | 18 +----------------- netlib/utils.py | 2 +- 4 files changed, 3 insertions(+), 22 deletions(-) diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index b01afc6d..9b4b0e6b 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,3 +1,2 @@ -from . import * from exceptions import * from semantics import * diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index b6a147d4..b297e0b8 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -34,9 +34,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): HTTP_1_1_REQUIRED=0xd ) - # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - CLIENT_CONNECTION_PREFACE =\ - '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') + CLIENT_CONNECTION_PREFACE = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" ALPN_PROTO_H2 = 'h2' diff --git a/netlib/tcp.py b/netlib/tcp.py index 3a094d9a..9dfa8d22 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -23,28 +23,12 @@ EINTR = 4 # To enable all SSL methods use: SSLv23 # then add options to disable certain methods # https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 - -# Use ONLY for parsing of CLI arguments! -# All code internals should use OpenSSL constants directly! -SSL_VERSIONS = { - 'TLSv1.2': SSL.TLSv1_2_METHOD, - 'TLSv1.1': SSL.TLSv1_1_METHOD, - 'TLSv1': SSL.TLSv1_METHOD, - 'SSLv3': SSL.SSLv3_METHOD, - 'SSLv2': SSL.SSLv2_METHOD, - 'SSLv23': SSL.SSLv23_METHOD, -} - -SSL_DEFAULT_VERSION = 'SSLv23' - -SSL_DEFAULT_METHOD = SSL_VERSIONS[SSL_DEFAULT_VERSION] - +SSL_DEFAULT_METHOD = SSL.SSLv23_METHOD SSL_DEFAULT_OPTIONS = ( SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_CIPHER_SERVER_PREFERENCE ) - if hasattr(SSL, "OP_NO_COMPRESSION"): SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION diff --git a/netlib/utils.py b/netlib/utils.py index 31dcd622..d6190673 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -182,7 +182,7 @@ def parse_url(url): return None else: host = netloc - if scheme == "https": + if scheme.endswith("https"): port = 443 else: port = 80 -- cgit v1.2.3 From 1265945f55604f32d99c3dd7c1efd13b3f2ecd9b Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 29 Aug 2015 12:30:35 +0200 Subject: move sslversion mapping to netlib --- netlib/tcp.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/netlib/tcp.py b/netlib/tcp.py index 9dfa8d22..0d83816b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -32,6 +32,23 @@ SSL_DEFAULT_OPTIONS = ( if hasattr(SSL, "OP_NO_COMPRESSION"): SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION +""" +Map a reasonable SSL version specification into the format OpenSSL expects. +Don't ask... +https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 +""" +sslversion_choices = { + "all": (SSL.SSLv23_METHOD, 0), + # SSLv23_METHOD + NO_SSLv2 + NO_SSLv3 == TLS 1.0+ + # TLSv1_METHOD would be TLS 1.0 only + "secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)), + "SSLv2": (SSL.SSLv2_METHOD, 0), + "SSLv3": (SSL.SSLv3_METHOD, 0), + "TLSv1": (SSL.TLSv1_METHOD, 0), + "TLSv1_1": (SSL.TLSv1_1_METHOD, 0), + "TLSv1_2": (SSL.TLSv1_2_METHOD, 0), +} + class NetLibError(Exception): pass -- cgit v1.2.3 From 4a8fd79e334661c1a11cd1cd28d62e6999b384d9 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 29 Aug 2015 20:54:54 +0200 Subject: don't yield prefix and suffix --- netlib/http/http1/protocol.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 6b4489fb..50975818 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -258,9 +258,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): def read_http_body(self, *args, **kwargs): - return "".join( - content for _, content, _ in self.read_http_body_chunked(*args, **kwargs) - ) + return "".join(self.read_http_body_chunked(*args, **kwargs)) def read_http_body_chunked( @@ -308,7 +306,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): while bytes_left: chunk_size = min(bytes_left, max_chunk_size) content = self.tcp_handler.rfile.read(chunk_size) - yield "", content, "" + yield content bytes_left -= chunk_size else: bytes_left = limit or -1 @@ -317,7 +315,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): content = self.tcp_handler.rfile.read(chunk_size) if not content: return - yield "", content, "" + yield content bytes_left -= chunk_size not_done = self.tcp_handler.rfile.read(1) if not_done: @@ -418,7 +416,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): suffix = self.tcp_handler.rfile.readline(5) if suffix != '\r\n': raise HttpError(code, "Malformed chunked body") - yield line, chunk, '\r\n' + yield chunk if length == 0: return -- cgit v1.2.3 From 29b355c469dd816656908e6dceeb703a1e7e7cd5 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 29 Aug 2015 20:57:51 +0200 Subject: update .travis.yml --- .travis.yml | 43 +++++++++++++++++++++---------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/.travis.yml b/.travis.yml index 3b49224e..fd2fba3d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,13 +1,10 @@ -language: python - sudo: false - -python: - - "2.7" - # - pypy # disable until TravisCI ships a PyPy version which works with the latest CFFI +language: python matrix: + fast_finish: true include: + - python: 2.7 - python: 2.7 env: OPENSSL=1.0.2 addons: @@ -18,16 +15,21 @@ matrix: - debian-sid packages: - libssl-dev - # - python: pypy - # env: OPENSSL=1.0.2 - # addons: - # apt: - # sources: - # # Debian sid currently holds OpenSSL 1.0.2 - # # change this with future releases! - # - debian-sid - # packages: - # - libssl-dev + - python: pypy + - python: pypy + env: OPENSSL=1.0.2 + addons: + apt: + sources: + # Debian sid currently holds OpenSSL 1.0.2 + # change this with future releases! + - debian-sid + packages: + - libssl-dev + allow_failures: + # We allow pypy to fail until Travis fixes their infrastructure to a pypy + # with a recent enought CFFI library to run cryptography 1.0+. + - python: pypy install: - "pip install --src . -r requirements.txt" @@ -50,7 +52,7 @@ notifications: slack: rooms: - mitmproxy:YaDGC9Gt9TEM7o8zkC2OLNsu - on_success: :change + on_success: change on_failure: always # exclude cryptography from cache @@ -58,14 +60,11 @@ notifications: # which needs to be compiled specifically to each version before_cache: - pip uninstall -y cryptography - - rm -rf /home/travis/virtualenv/python2.7.9/lib/python2.7/site-packages/cryptography/ - - rm -rf /home/travis/virtualenv/pypy-2.5.0/site-packages/cryptography/ - - rm /home/travis/virtualenv/python2.7.9/lib/python2.7/site-packages/pip/_vendor/requests/packages/urllib3/contrib/pyopenssl.py - - rm /home/travis/virtualenv/pypy-2.5.0/site-packages/pip/_vendor/requests/packages/urllib3/contrib/pyopenssl.py cache: directories: + - $HOME/.cache/pip - /home/travis/virtualenv/python2.7.9/lib/python2.7/site-packages - /home/travis/virtualenv/python2.7.9/bin - /home/travis/virtualenv/pypy-2.5.0/site-packages - - /home/travis/virtualenv/pypy-2.5.0/bin + - /home/travis/virtualenv/pypy-2.5.0/bin \ No newline at end of file -- cgit v1.2.3 From 97d224752473911f3efdba01b173dc9481e40b50 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 1 Sep 2015 18:58:18 +0200 Subject: update .env --- .env | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.env b/.env index 7f847e29..69ac3f05 100644 --- a/.env +++ b/.env @@ -1,5 +1,6 @@ -DIR=`dirname $0` -if [ -z "$VIRTUAL_ENV" ] && [ -f $DIR/../venv.mitmproxy/bin/activate ]; then +DIR="$( dirname "${BASH_SOURCE[0]}" )" +ACTIVATE_DIR="$(if [ -f "$DIR/../venv.mitmproxy/bin/activate" ]; then echo 'bin'; else echo 'Scripts'; fi;)" +if [ -z "$VIRTUAL_ENV" ] && [ -f "$DIR/../venv.mitmproxy/$ACTIVATE_DIR/activate" ]; then echo "Activating mitmproxy virtualenv..." - source $DIR/../venv.mitmproxy/bin/activate + source "$DIR/../venv.mitmproxy/$ACTIVATE_DIR/activate" fi -- cgit v1.2.3 From 53abf5f4d7c1e6f0712c6473904e5c1a58db0bb9 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 3 Sep 2015 21:22:40 +0200 Subject: http2: handle Ping in protocol --- netlib/http/http2/protocol.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index b297e0b8..2fbe7705 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -261,16 +261,21 @@ class HTTP2Protocol(semantics.ProtocolMixin): print(frm.human_readable(">>")) def read_frame(self, hide=False): - frm = frame.Frame.from_file(self.tcp_handler.rfile, self) - if not hide and self.dump_frames: # pragma no cover - print(frm.human_readable("<<")) - if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: - self._apply_settings(frm.settings, hide) - - if isinstance(frm, frame.DataFrame) and frm.length > 0: - self._update_flow_control_window(frm.stream_id, frm.length) - - return frm + while True: + frm = frame.Frame.from_file(self.tcp_handler.rfile, self) + if not hide and self.dump_frames: # pragma no cover + print(frm.human_readable("<<")) + + if isinstance(frm, frame.PingFrame): + raw_bytes = frame.PingFrame(flags=frame.Frame.FLAG_ACK, payload=frm.payload).to_bytes() + self.tcp_handler.wfile.write(raw_bytes) + self.tcp_handler.wfile.flush() + continue + if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: + self._apply_settings(frm.settings, hide) + if isinstance(frm, frame.DataFrame) and frm.length > 0: + self._update_flow_control_window(frm.stream_id, frm.length) + return frm def check_alpn(self): alp = self.tcp_handler.get_alpn_proto_negotiated() -- cgit v1.2.3 From 3ebe5a5147db20036d0762b92898f313b4d2f8d8 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 3 Sep 2015 21:22:55 +0200 Subject: http2: do net let Settings frames escape --- netlib/http/http2/protocol.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 2fbe7705..4328ebdd 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -285,6 +285,8 @@ class HTTP2Protocol(semantics.ProtocolMixin): return True def _handle_unexpected_frame(self, frm): + if isinstance(frm, frame.SettingsFrame): + return if self.unhandled_frame_cb: self.unhandled_frame_cb(frm) -- cgit v1.2.3 From 5f97701958a283fca7188623c3cb4a313456b82c Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 5 Sep 2015 13:26:36 +0200 Subject: add new headers class --- netlib/http/semantics.py | 130 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 129 insertions(+), 1 deletion(-) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 2b960483..162cdbf5 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -1,4 +1,5 @@ from __future__ import (absolute_import, print_function, division) +import UserDict import urllib import urlparse @@ -12,8 +13,135 @@ HDR_FORM_MULTIPART = "multipart/form-data" CONTENT_MISSING = 0 -class ProtocolMixin(object): +class Headers(UserDict.DictMixin): + """ + Header class which allows both convenient access to individual headers as well as + direct access to the underlying raw data. Provides a full dictionary interface. + + Example: + + .. code-block:: python + + # Create header from a list of (header_name, header_value) tuples + >>> h = Headers([ + ["Host","example.com"], + ["Accept","text/html"], + ["accept","application/xml"] + ]) + + # Headers mostly behave like a normal dict. + >>> h["Host"] + "example.com" + + # HTTP Headers are case insensitive + >>> h["host"] + "example.com" + + # Multiple headers are folded into a single header as per RFC7230 + >>> h["Accept"] + "text/html, application/xml" + + # Setting a header removes all existing headers with the same name. + >>> h["Accept"] = "application/text" + >>> h["Accept"] + "application/text" + + # str(h) returns a HTTP1 header block. + >>> print(h) + Host: example.com + Accept: application/text + + # For full control, the raw header lines can be accessed + >>> h.lines + + # Headers can also be crated from keyword arguments + >>> h = Headers(host="example.com", content_type="application/xml") + + Caveats: + For use with the "Set-Cookie" header, see :py:meth:`get_all`. + """ + + def __init__(self, lines=None, **headers): + """ + For convenience, underscores in header names will be transformed to dashes. + This behaviour does not extend to other methods. + + If ``**headers`` contains multiple keys that have equal ``.lower()``s, + the behavior is undefined. + """ + self.lines = lines or [] + + # content_type -> content-type + headers = {k.replace("_", "-"): v for k, v in headers.iteritems()} + self.update(headers) + + def __str__(self): + return "\r\n".join(": ".join(line) for line in self.lines) + + def __getitem__(self, key): + values = self.get_all(key) + if not values: + raise KeyError(key) + else: + return ", ".join(values) + def __setitem__(self, key, value): + idx = self._index(key) + + # To please the human eye, we insert at the same position the first existing header occured. + if idx is not None: + del self[key] + self.lines.insert(idx, [key, value]) + else: + self.lines.append([key, value]) + + def __delitem__(self, key): + key = key.lower() + self.lines = [ + line for line in self.lines + if key != line[0].lower() + ] + + def _index(self, key): + key = key.lower() + for i, line in enumerate(self): + if line[0].lower() == key: + return i + return None + + def keys(self): + return list(set(line[0] for line in self.lines)) + + def __eq__(self, other): + return self.lines == other.lines + + def __ne__(self, other): + return not self.__eq__(other) + + def get_all(self, key, default=None): + """ + Like :py:meth:`get`, but does not fold multiple headers into a single one. + This is useful for Set-Cookie headers, which do not support folding. + + See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 + """ + key = key.lower() + values = [line[1] for line in self.lines if line[0].lower() == key] + return values or default + + def set_all(self, key, values): + """ + Explicitly set multiple headers for the given key. + See: :py:meth:`get_all` + """ + if key in self: + del self[key] + self.lines.extend( + [key, value] for value in values + ) + + +class ProtocolMixin(object): def read_request(self, *args, **kwargs): # pragma: no cover raise NotImplementedError -- cgit v1.2.3 From 3718e59308745e4582f4e8061b4ff6113d9dfc74 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 5 Sep 2015 15:27:48 +0200 Subject: finalize Headers, add tests --- netlib/http/semantics.py | 109 ++++++++++++++++++++------------- test/http/test_semantics.py | 145 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 213 insertions(+), 41 deletions(-) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 162cdbf5..2fadf2c4 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -51,8 +51,8 @@ class Headers(UserDict.DictMixin): Host: example.com Accept: application/text - # For full control, the raw header lines can be accessed - >>> h.lines + # For full control, the raw header fields can be accessed + >>> h.fields # Headers can also be crated from keyword arguments >>> h = Headers(host="example.com", content_type="application/xml") @@ -61,85 +61,112 @@ class Headers(UserDict.DictMixin): For use with the "Set-Cookie" header, see :py:meth:`get_all`. """ - def __init__(self, lines=None, **headers): + def __init__(self, fields=None, **headers): """ - For convenience, underscores in header names will be transformed to dashes. - This behaviour does not extend to other methods. - - If ``**headers`` contains multiple keys that have equal ``.lower()``s, - the behavior is undefined. + Args: + fields: (optional) list of ``(name, value)`` header tuples, e.g. ``[("Host","example.com")]`` + **headers: Additional headers to set. Will overwrite existing values from `fields`. + For convenience, underscores in header names will be transformed to dashes - + this behaviour does not extend to other methods. + If ``**headers`` contains multiple keys that have equal ``.lower()`` s, + the behavior is undefined. """ - self.lines = lines or [] + self.fields = fields or [] # content_type -> content-type - headers = {k.replace("_", "-"): v for k, v in headers.iteritems()} + headers = { + name.replace("_", "-"): value + for name, value in headers.iteritems() + } self.update(headers) def __str__(self): - return "\r\n".join(": ".join(line) for line in self.lines) + return "\r\n".join(": ".join(field) for field in self.fields) - def __getitem__(self, key): - values = self.get_all(key) + def __getitem__(self, name): + values = self.get_all(name) if not values: - raise KeyError(key) + raise KeyError(name) else: return ", ".join(values) - def __setitem__(self, key, value): - idx = self._index(key) + def __setitem__(self, name, value): + idx = self._index(name) # To please the human eye, we insert at the same position the first existing header occured. if idx is not None: - del self[key] - self.lines.insert(idx, [key, value]) + del self[name] + self.fields.insert(idx, [name, value]) else: - self.lines.append([key, value]) - - def __delitem__(self, key): - key = key.lower() - self.lines = [ - line for line in self.lines - if key != line[0].lower() - ] - - def _index(self, key): - key = key.lower() - for i, line in enumerate(self): - if line[0].lower() == key: + self.fields.append([name, value]) + + def __delitem__(self, name): + if name not in self: + raise KeyError(name) + name = name.lower() + self.fields = [ + field for field in self.fields + if name != field[0].lower() + ] + + def _index(self, name): + name = name.lower() + for i, field in enumerate(self.fields): + if field[0].lower() == name: return i return None def keys(self): - return list(set(line[0] for line in self.lines)) + seen = set() + names = [] + for name, _ in self.fields: + name_lower = name.lower() + if name_lower not in seen: + seen.add(name_lower) + names.append(name) + return names def __eq__(self, other): - return self.lines == other.lines + if isinstance(other, Headers): + return self.fields == other.fields + return False def __ne__(self, other): return not self.__eq__(other) - def get_all(self, key, default=None): + def get_all(self, name, default=None): """ Like :py:meth:`get`, but does not fold multiple headers into a single one. This is useful for Set-Cookie headers, which do not support folding. See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 """ - key = key.lower() - values = [line[1] for line in self.lines if line[0].lower() == key] + name = name.lower() + values = [value for n, value in self.fields if n.lower() == name] return values or default - def set_all(self, key, values): + def set_all(self, name, values): """ Explicitly set multiple headers for the given key. See: :py:meth:`get_all` """ - if key in self: - del self[key] - self.lines.extend( - [key, value] for value in values + if name in self: + del self[name] + self.fields.extend( + [name, value] for value in values ) + # Implement the StateObject protocol from mitmproxy + def get_state(self, short=False): + return tuple(tuple(field) for field in self.fields) + + def load_state(self, state): + self.fields = [list(field) for field in state] + + @classmethod + def from_state(cls, state): + return cls([list(field) for field in state]) + class ProtocolMixin(object): def read_request(self, *args, **kwargs): # pragma: no cover diff --git a/test/http/test_semantics.py b/test/http/test_semantics.py index 2a799044..74743eff 100644 --- a/test/http/test_semantics.py +++ b/test/http/test_semantics.py @@ -445,3 +445,148 @@ class TestResponse(object): v = resp.get_cookies() assert len(v) == 1 assert v["foo"] == [["bar", odict.ODictCaseless()]] + + +class TestHeaders(object): + def _2host(self): + return semantics.Headers( + [ + ["Host", "example.com"], + ["host", "example.org"] + ] + ) + + def test_init(self): + h = semantics.Headers() + assert len(h) == 0 + + h = semantics.Headers([["Host", "example.com"]]) + assert len(h) == 1 + assert h["Host"] == "example.com" + + h = semantics.Headers(Host="example.com") + assert len(h) == 1 + assert h["Host"] == "example.com" + + h = semantics.Headers( + [["Host", "invalid"]], + Host="example.com" + ) + assert len(h) == 1 + assert h["Host"] == "example.com" + + h = semantics.Headers( + [["Host", "invalid"], ["Accept", "text/plain"]], + Host="example.com" + ) + assert len(h) == 2 + assert h["Host"] == "example.com" + assert h["Accept"] == "text/plain" + + def test_getitem(self): + h = semantics.Headers(Host="example.com") + assert h["Host"] == "example.com" + assert h["host"] == "example.com" + tutils.raises(KeyError, h.__getitem__, "Accept") + + h = self._2host() + assert h["Host"] == "example.com, example.org" + + def test_str(self): + h = semantics.Headers(Host="example.com") + assert str(h) == "Host: example.com" + + h = semantics.Headers([ + ["Host", "example.com"], + ["Accept", "text/plain"] + ]) + assert str(h) == "Host: example.com\r\nAccept: text/plain" + + def test_setitem(self): + h = semantics.Headers() + h["Host"] = "example.com" + assert "Host" in h + assert "host" in h + assert h["Host"] == "example.com" + + h["host"] = "example.org" + assert "Host" in h + assert "host" in h + assert h["Host"] == "example.org" + + h["accept"] = "text/plain" + assert len(h) == 2 + assert "Accept" in h + assert "Host" in h + + h = self._2host() + assert len(h.fields) == 2 + h["Host"] = "example.com" + assert len(h.fields) == 1 + assert "Host" in h + + def test_delitem(self): + h = semantics.Headers(Host="example.com") + assert len(h) == 1 + del h["host"] + assert len(h) == 0 + try: + del h["host"] + except KeyError: + assert True + else: + assert False + + h = self._2host() + del h["Host"] + assert len(h) == 0 + + def test_keys(self): + h = semantics.Headers(Host="example.com") + assert len(h.keys()) == 1 + assert h.keys()[0] == "Host" + + h = self._2host() + assert len(h.keys()) == 1 + assert h.keys()[0] == "Host" + + def test_eq_ne(self): + h1 = semantics.Headers(Host="example.com") + h2 = semantics.Headers(host="example.com") + assert not (h1 == h2) + assert h1 != h2 + + h1 = semantics.Headers(Host="example.com") + h2 = semantics.Headers(Host="example.com") + assert h1 == h2 + assert not (h1 != h2) + + assert h1 != None + + def test_get_all(self): + h = self._2host() + assert h.get_all("host") == ["example.com", "example.org"] + assert h.get_all("accept", 42) is 42 + + def test_set_all(self): + h = semantics.Headers(Host="example.com") + h.set_all("Accept", ["text/plain"]) + assert len(h) == 2 + assert "accept" in h + + h = self._2host() + h.set_all("Host", ["example.org"]) + assert h["host"] == "example.org" + + h.set_all("Host", ["example.org", "example.net"]) + assert h["host"] == "example.org, example.net" + + def test_state(self): + h = self._2host() + assert len(h.get_state()) == 2 + assert h == semantics.Headers.from_state(h.get_state()) + + h2 = semantics.Headers() + assert h != h2 + h2.load_state(h.get_state()) + assert h == h2 -- cgit v1.2.3 From 66ee1f465f6c492d5a4ff5659e6f0346fb243d67 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 5 Sep 2015 18:15:47 +0200 Subject: headers: adjust everything --- netlib/http/authentication.py | 4 +- netlib/http/exceptions.py | 18 --- netlib/http/http1/protocol.py | 41 +++--- netlib/http/http2/protocol.py | 44 +++--- netlib/http/semantics.py | 148 ++++++++++---------- netlib/tutils.py | 10 +- netlib/utils.py | 13 +- netlib/websockets/protocol.py | 28 ++-- netlib/wsgi.py | 22 +-- test/http/http1/test_protocol.py | 127 +++++++++-------- test/http/http2/test_protocol.py | 22 +-- test/http/test_authentication.py | 44 +++--- test/http/test_exceptions.py | 20 --- test/http/test_semantics.py | 273 +++++++++++++++++-------------------- test/test_utils.py | 25 ++-- test/test_wsgi.py | 8 +- test/websockets/test_websockets.py | 12 +- 17 files changed, 400 insertions(+), 459 deletions(-) diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index 29b9eb3c..fe1f0d14 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -62,10 +62,10 @@ class BasicProxyAuth(NullProxyAuth): del headers[self.AUTH_HEADER] def authenticate(self, headers): - auth_value = headers.get(self.AUTH_HEADER, []) + auth_value = headers.get(self.AUTH_HEADER) if not auth_value: return False - parts = parse_http_basic_auth(auth_value[0]) + parts = parse_http_basic_auth(auth_value) if not parts: return False scheme, username, password = parts diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py index 987a7908..8a2bbebc 100644 --- a/netlib/http/exceptions.py +++ b/netlib/http/exceptions.py @@ -1,6 +1,3 @@ -from netlib import odict - - class HttpError(Exception): def __init__(self, code, message): @@ -10,18 +7,3 @@ class HttpError(Exception): class HttpErrorConnClosed(HttpError): pass - - -class HttpAuthenticationError(Exception): - - def __init__(self, auth_headers=None): - super(HttpAuthenticationError, self).__init__( - "Proxy Authentication Required" - ) - if isinstance(auth_headers, dict): - auth_headers = odict.ODictCaseless(auth_headers.items()) - self.headers = auth_headers - self.code = 407 - - def __repr__(self): - return "Proxy Authentication Required" diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 50975818..bf33a18e 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -3,8 +3,8 @@ import string import sys import time -from netlib import odict, utils, tcp, http -from netlib.http import semantics +from ... import utils, tcp, http +from .. import semantics, Headers from ..exceptions import * @@ -96,7 +96,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): if headers is None: raise HttpError(400, "Invalid headers") - expect_header = headers.get_first("expect", "").lower() + expect_header = headers.get("expect", "").lower() if expect_header == "100-continue" and httpversion == (1, 1): self.tcp_handler.wfile.write( 'HTTP/1.1 100 Continue\r\n' @@ -232,10 +232,9 @@ class HTTP1Protocol(semantics.ProtocolMixin): Read a set of headers. Stop once a blank line is reached. - Return a ODictCaseless object, or None if headers are invalid. + Return a Header object, or None if headers are invalid. """ ret = [] - name = '' while True: line = self.tcp_handler.rfile.readline() if not line or line == '\r\n' or line == '\n': @@ -254,7 +253,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): ret.append([name, value]) else: return None - return odict.ODictCaseless(ret) + return Headers(ret) def read_http_body(self, *args, **kwargs): @@ -272,7 +271,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): ): """ Read an HTTP message body: - headers: An ODictCaseless object + headers: A Header object limit: Size limit. is_request: True if the body to read belongs to a request, False otherwise @@ -356,7 +355,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None if "content-length" in headers: try: - size = int(headers["content-length"][0]) + size = int(headers["content-length"]) if size < 0: raise ValueError() return size @@ -369,9 +368,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): @classmethod def has_chunked_encoding(self, headers): - return "chunked" in [ - i.lower() for i in utils.get_header_tokens(headers, "transfer-encoding") - ] + return "chunked" in headers.get("transfer-encoding", "").lower() def _get_request_line(self): @@ -547,18 +544,20 @@ class HTTP1Protocol(semantics.ProtocolMixin): def _assemble_request_headers(self, request): headers = request.headers.copy() for k in request._headers_to_strip_off: - del headers[k] + headers.pop(k, None) if 'host' not in headers and request.scheme and request.host and request.port: - headers["Host"] = [utils.hostport(request.scheme, - request.host, - request.port)] + headers["Host"] = utils.hostport( + request.scheme, + request.host, + request.port + ) # If content is defined (i.e. not None or CONTENT_MISSING), we always # add a content-length header. if request.body or request.body == "": - headers["Content-Length"] = [str(len(request.body))] + headers["Content-Length"] = str(len(request.body)) - return headers.format() + return str(headers) def _assemble_response_first_line(self, response): return 'HTTP/%s.%s %s %s' % ( @@ -575,13 +574,13 @@ class HTTP1Protocol(semantics.ProtocolMixin): ): headers = response.headers.copy() for k in response._headers_to_strip_off: - del headers[k] + headers.pop(k, None) if not preserve_transfer_encoding: - del headers['Transfer-Encoding'] + headers.pop('Transfer-Encoding', None) # If body is defined (i.e. not None or CONTENT_MISSING), we always # add a content-length header. if response.body or response.body == "": - headers["Content-Length"] = [str(len(response.body))] + headers["Content-Length"] = str(len(response.body)) - return headers.format() + return str(headers) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index b297e0b8..f3254caa 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -3,7 +3,7 @@ import itertools import time from hpack.hpack import Encoder, Decoder -from netlib import http, utils, odict +from netlib import http, utils from netlib.http import semantics from . import frame @@ -85,10 +85,10 @@ class HTTP2Protocol(semantics.ProtocolMixin): timestamp_end = time.time() - authority = headers.get_first(':authority', '') - method = headers.get_first(':method', 'GET') - scheme = headers.get_first(':scheme', 'https') - path = headers.get_first(':path', '/') + authority = headers.get(':authority', '') + method = headers.get(':method', 'GET') + scheme = headers.get(':scheme', 'https') + path = headers.get(':path', '/') host = None port = None @@ -161,7 +161,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): response = http.Response( (2, 0), - int(headers.get_first(':status')), + int(headers.get(':status', 502)), "", headers, body, @@ -181,16 +181,14 @@ class HTTP2Protocol(semantics.ProtocolMixin): headers = request.headers.copy() - if ':authority' not in headers.keys(): - headers.add(':authority', bytes(authority), prepend=True) - if ':scheme' not in headers.keys(): - headers.add(':scheme', bytes(request.scheme), prepend=True) - if ':path' not in headers.keys(): - headers.add(':path', bytes(request.path), prepend=True) - if ':method' not in headers.keys(): - headers.add(':method', bytes(request.method), prepend=True) - - headers = headers.items() + if ':authority' not in headers: + headers.fields.insert(0, (':authority', bytes(authority))) + if ':scheme' not in headers: + headers.fields.insert(0, (':scheme', bytes(request.scheme))) + if ':path' not in headers: + headers.fields.insert(0, (':path', bytes(request.path))) + if ':method' not in headers: + headers.fields.insert(0, (':method', bytes(request.method))) if hasattr(request, 'stream_id'): stream_id = request.stream_id @@ -206,10 +204,8 @@ class HTTP2Protocol(semantics.ProtocolMixin): headers = response.headers.copy() - if ':status' not in headers.keys(): - headers.add(':status', bytes(str(response.status_code)), prepend=True) - - headers = headers.items() + if ':status' not in headers: + headers.fields.insert(0, (':status', bytes(str(response.status_code)))) if hasattr(response, 'stream_id'): stream_id = response.stream_id @@ -329,7 +325,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): else: yield frame.ContinuationFrame, i - header_block_fragment = self.encoder.encode(headers) + header_block_fragment = self.encoder.encode(headers.fields) chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] chunks = range(0, len(header_block_fragment), chunk_size) @@ -402,8 +398,8 @@ class HTTP2Protocol(semantics.ProtocolMixin): else: self._handle_unexpected_frame(frm) - headers = odict.ODictCaseless() - for header, value in self.decoder.decode(header_block_fragment): - headers.add(header, value) + headers = http.Headers( + [[str(k), str(v)] for k, v in self.decoder.decode(header_block_fragment)] + ) return stream_id, headers, body diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 2fadf2c4..edf5fc07 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -1,9 +1,10 @@ from __future__ import (absolute_import, print_function, division) import UserDict +import copy import urllib import urlparse -from .. import utils, odict +from .. import odict from . import cookies, exceptions from netlib import utils, encoding @@ -77,11 +78,11 @@ class Headers(UserDict.DictMixin): headers = { name.replace("_", "-"): value for name, value in headers.iteritems() - } + } self.update(headers) def __str__(self): - return "\r\n".join(": ".join(field) for field in self.fields) + return "\r\n".join(": ".join(field) for field in self.fields) + "\r\n" def __getitem__(self, name): values = self.get_all(name) @@ -107,7 +108,7 @@ class Headers(UserDict.DictMixin): self.fields = [ field for field in self.fields if name != field[0].lower() - ] + ] def _index(self, name): name = name.lower() @@ -134,7 +135,7 @@ class Headers(UserDict.DictMixin): def __ne__(self, other): return not self.__eq__(other) - def get_all(self, name, default=None): + def get_all(self, name, default=[]): """ Like :py:meth:`get`, but does not fold multiple headers into a single one. This is useful for Set-Cookie headers, which do not support folding. @@ -156,6 +157,9 @@ class Headers(UserDict.DictMixin): [name, value] for value in values ) + def copy(self): + return Headers(copy.copy(self.fields)) + # Implement the StateObject protocol from mitmproxy def get_state(self, short=False): return tuple(tuple(field) for field in self.fields) @@ -202,23 +206,23 @@ class Request(object): ] def __init__( - self, - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers=None, - body=None, - timestamp_start=None, - timestamp_end=None, - form_out=None + self, + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers=None, + body=None, + timestamp_start=None, + timestamp_end=None, + form_out=None ): if not headers: - headers = odict.ODictCaseless() - assert isinstance(headers, odict.ODictCaseless) + headers = Headers() + assert isinstance(headers, Headers) self.form_in = form_in self.method = method @@ -235,8 +239,10 @@ class Request(object): def __eq__(self, other): try: - self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] - other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + self_d = [self.__dict__[k] for k in self.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] return self_d == other_d except: return False @@ -289,30 +295,35 @@ class Request(object): "if-none-match", ] for i in delheaders: - del self.headers[i] + self.headers.pop(i, None) def anticomp(self): """ Modifies this request to remove headers that will compress the resource's data. """ - self.headers["accept-encoding"] = ["identity"] + self.headers["accept-encoding"] = "identity" def constrain_encoding(self): """ Limits the permissible Accept-Encoding values, based on what we can decode appropriately. """ - if self.headers["accept-encoding"]: - self.headers["accept-encoding"] = [ + accept_encoding = self.headers.get("accept-encoding") + if accept_encoding: + self.headers["accept-encoding"] = ( ', '.join( - e for e in encoding.ENCODINGS if e in self.headers.get_first("accept-encoding"))] + e + for e in encoding.ENCODINGS + if e in accept_encoding + ) + ) def update_host_header(self): """ Update the host header to reflect the current target. """ - self.headers["Host"] = [self.host] + self.headers["Host"] = self.host def get_form(self): """ @@ -321,9 +332,9 @@ class Request(object): indicates non-form data. """ if self.body: - if self.headers.in_any("content-type", HDR_FORM_URLENCODED, True): + if HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): return self.get_form_urlencoded() - elif self.headers.in_any("content-type", HDR_FORM_MULTIPART, True): + elif HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): return self.get_form_multipart() return odict.ODict([]) @@ -333,18 +344,12 @@ class Request(object): Returns an empty ODict if there is no data or the content-type indicates non-form data. """ - if self.body and self.headers.in_any( - "content-type", - HDR_FORM_URLENCODED, - True): + if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): return odict.ODict(utils.urldecode(self.body)) return odict.ODict([]) def get_form_multipart(self): - if self.body and self.headers.in_any( - "content-type", - HDR_FORM_MULTIPART, - True): + if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): return odict.ODict( utils.multipartdecode( self.headers, @@ -359,7 +364,7 @@ class Request(object): """ # FIXME: If there's an existing content-type header indicating a # url-encoded form, leave it alone. - self.headers["Content-Type"] = [HDR_FORM_URLENCODED] + self.headers["Content-Type"] = HDR_FORM_URLENCODED self.body = utils.urlencode(odict.lst) def get_path_components(self): @@ -418,7 +423,7 @@ class Request(object): """ host = None if hostheader: - host = self.headers.get_first("host") + host = self.headers.get("Host") if not host: host = self.host if host: @@ -442,7 +447,7 @@ class Request(object): Returns a possibly empty netlib.odict.ODict object. """ ret = odict.ODict() - for i in self.headers["cookie"]: + for i in self.headers.get_all("cookie"): ret.extend(cookies.parse_cookie_header(i)) return ret @@ -452,7 +457,7 @@ class Request(object): headers. """ v = cookies.format_cookie_header(odict) - self.headers["Cookie"] = [v] + self.headers["Cookie"] = v @property def url(self): @@ -491,18 +496,17 @@ class Request(object): class EmptyRequest(Request): - def __init__( - self, - form_in="", - method="", - scheme="", - host="", - port="", - path="", - httpversion=(0, 0), - headers=None, - body="" + self, + form_in="", + method="", + scheme="", + host="", + port="", + path="", + httpversion=(0, 0), + headers=None, + body="" ): super(EmptyRequest, self).__init__( form_in=form_in, @@ -512,7 +516,7 @@ class EmptyRequest(Request): port=port, path=path, httpversion=httpversion, - headers=(headers or odict.ODictCaseless()), + headers=headers, body=body, ) @@ -525,19 +529,19 @@ class Response(object): ] def __init__( - self, - httpversion, - status_code, - msg=None, - headers=None, - body=None, - sslinfo=None, - timestamp_start=None, - timestamp_end=None, + self, + httpversion, + status_code, + msg=None, + headers=None, + body=None, + sslinfo=None, + timestamp_start=None, + timestamp_end=None, ): if not headers: - headers = odict.ODictCaseless() - assert isinstance(headers, odict.ODictCaseless) + headers = Headers() + assert isinstance(headers, Headers) self.httpversion = httpversion self.status_code = status_code @@ -550,8 +554,10 @@ class Response(object): def __eq__(self, other): try: - self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] - other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + self_d = [self.__dict__[k] for k in self.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] return self_d == other_d except: return False @@ -567,9 +573,7 @@ class Response(object): return "".format( status_code=self.status_code, msg=self.msg, - contenttype=self.headers.get_first( - "content-type", - "unknown content type"), + contenttype=self.headers.get("content-type", "unknown content type"), size=size) def get_cookies(self): @@ -582,7 +586,7 @@ class Response(object): attributes (e.g. HTTPOnly) are indicated by a Null value. """ ret = [] - for header in self.headers["set-cookie"]: + for header in self.headers.get_all("set-cookie"): v = cookies.parse_set_cookie_header(header) if v: name, value, attrs = v @@ -605,7 +609,7 @@ class Response(object): i[1][1] ) ) - self.headers["Set-Cookie"] = values + self.headers.set_all("Set-Cookie", values) @property def content(self): # pragma: no cover diff --git a/netlib/tutils.py b/netlib/tutils.py index 7434c108..951ef3d9 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -5,7 +5,7 @@ import time import shutil from contextlib import contextmanager -from netlib import tcp, utils, odict, http +from netlib import tcp, utils, http def treader(bytes): @@ -73,8 +73,8 @@ def treq(content="content", scheme="http", host="address", port=22): """ @return: libmproxy.protocol.http.HTTPRequest """ - headers = odict.ODictCaseless() - headers["header"] = ["qvalue"] + headers = http.Headers() + headers["header"] = "qvalue" req = http.Request( "relative", "GET", @@ -108,8 +108,8 @@ def tresp(content="message"): @return: libmproxy.protocol.http.HTTPResponse """ - headers = odict.ODictCaseless() - headers["header_response"] = ["svalue"] + headers = http.Headers() + headers["header_response"] = "svalue" resp = http.semantics.Response( (1, 1), diff --git a/netlib/utils.py b/netlib/utils.py index d6190673..aae187da 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -204,11 +204,10 @@ def get_header_tokens(headers, key): follow a pattern where each header line can containe comma-separated tokens, and headers can be set multiple times. """ - toks = [] - for i in headers[key]: - for j in i.split(","): - toks.append(j.strip()) - return toks + if key not in headers: + return [] + tokens = headers[key].split(",") + return [token.strip() for token in tokens] def hostport(scheme, host, port): @@ -270,11 +269,11 @@ def parse_content_type(c): return ts[0].lower(), ts[1].lower(), d -def multipartdecode(hdrs, content): +def multipartdecode(headers, content): """ Takes a multipart boundary encoded string and returns list of (key, value) tuples. """ - v = hdrs.get_first("content-type") + v = headers.get("content-type") if v: v = parse_content_type(v) if not v: diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index 6ce32eac..46c02875 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -1,10 +1,5 @@ -from __future__ import absolute_import -import base64 -import hashlib -import os -from netlib import odict -from netlib import utils + # Colleciton of utility functions that implement small portions of the RFC6455 # WebSockets Protocol Useful for building WebSocket clients and servers. @@ -18,6 +13,13 @@ from netlib import utils # The magic sha that websocket servers must know to prove they understand # RFC6455 +from __future__ import absolute_import +import base64 +import hashlib +import os +from ..http import Headers +from .. import utils + websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' VERSION = "13" @@ -66,11 +68,11 @@ class WebsocketsProtocol(object): specified, it is generated, and can be found in sec-websocket-key in the returned header set. - Returns an instance of ODictCaseless + Returns an instance of Headers """ if not key: key = base64.b64encode(os.urandom(16)).decode('utf-8') - return odict.ODictCaseless([ + return Headers([ ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), (HEADER_WEBSOCKET_KEY, key), @@ -82,7 +84,7 @@ class WebsocketsProtocol(object): """ The server response is a valid HTTP 101 response. """ - return odict.ODictCaseless( + return Headers( [ ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), @@ -93,16 +95,16 @@ class WebsocketsProtocol(object): @classmethod def check_client_handshake(self, headers): - if headers.get_first("upgrade", None) != "websocket": + if headers.get("upgrade") != "websocket": return - return headers.get_first(HEADER_WEBSOCKET_KEY) + return headers.get(HEADER_WEBSOCKET_KEY) @classmethod def check_server_handshake(self, headers): - if headers.get_first("upgrade", None) != "websocket": + if headers.get("upgrade") != "websocket": return - return headers.get_first(HEADER_WEBSOCKET_ACCEPT) + return headers.get(HEADER_WEBSOCKET_ACCEPT) @classmethod diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 99afe00e..8a98884a 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -3,7 +3,7 @@ import cStringIO import urllib import time import traceback -from . import odict, tcp +from . import http, tcp class ClientConn(object): @@ -68,8 +68,8 @@ class WSGIAdaptor(object): 'SCRIPT_NAME': '', 'PATH_INFO': urllib.unquote(path_info), 'QUERY_STRING': query, - 'CONTENT_TYPE': flow.request.headers.get('Content-Type', [''])[0], - 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', [''])[0], + 'CONTENT_TYPE': flow.request.headers.get('Content-Type', ''), + 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', ''), 'SERVER_NAME': self.domain, 'SERVER_PORT': str(self.port), # FIXME: We need to pick up the protocol read from the request. @@ -115,12 +115,12 @@ class WSGIAdaptor(object): def write(data): if not state["headers_sent"]: soc.write("HTTP/1.1 %s\r\n" % state["status"]) - h = state["headers"] - if 'server' not in h: - h["Server"] = [self.sversion] - if 'date' not in h: - h["Date"] = [date_time_string()] - soc.write(h.format()) + headers = state["headers"] + if 'server' not in headers: + headers["Server"] = self.sversion + if 'date' not in headers: + headers["Date"] = date_time_string() + soc.write(str(headers)) soc.write("\r\n") state["headers_sent"] = True if data: @@ -137,7 +137,7 @@ class WSGIAdaptor(object): elif state["status"]: raise AssertionError('Response already started') state["status"] = status - state["headers"] = odict.ODictCaseless(headers) + state["headers"] = http.Headers(headers) return write errs = cStringIO.StringIO() @@ -149,7 +149,7 @@ class WSGIAdaptor(object): write(i) if not state["headers_sent"]: write("") - except Exception: + except Exception as e: try: s = traceback.format_exc() errs.write(s) diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index 6704647f..f7c615bd 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -2,7 +2,7 @@ import cStringIO import textwrap from netlib import http, odict, tcp, tutils -from netlib.http import semantics +from netlib.http import semantics, Headers from netlib.http.http1 import HTTP1Protocol from ... import tservers @@ -29,164 +29,161 @@ def test_stripped_chunked_encoding_no_content(): """ r = tutils.treq(content="") - r.headers["Transfer-Encoding"] = ["chunked"] + r.headers["Transfer-Encoding"] = "chunked" assert "Content-Length" in mock_protocol()._assemble_request_headers(r) r = tutils.tresp(content="") - r.headers["Transfer-Encoding"] = ["chunked"] + r.headers["Transfer-Encoding"] = "chunked" assert "Content-Length" in mock_protocol()._assemble_response_headers(r) def test_has_chunked_encoding(): - h = odict.ODictCaseless() - assert not HTTP1Protocol.has_chunked_encoding(h) - h["transfer-encoding"] = ["chunked"] - assert HTTP1Protocol.has_chunked_encoding(h) + headers = http.Headers() + assert not HTTP1Protocol.has_chunked_encoding(headers) + headers["transfer-encoding"] = "chunked" + assert HTTP1Protocol.has_chunked_encoding(headers) def test_read_chunked(): - h = odict.ODictCaseless() - h["transfer-encoding"] = ["chunked"] + headers = http.Headers() + headers["transfer-encoding"] = "chunked" data = "1\r\na\r\n0\r\n" tutils.raises( "malformed chunked body", mock_protocol(data).read_http_body, - h, None, "GET", None, True + headers, None, "GET", None, True ) data = "1\r\na\r\n0\r\n\r\n" - assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "a" + assert mock_protocol(data).read_http_body(headers, None, "GET", None, True) == "a" data = "\r\n\r\n1\r\na\r\n0\r\n\r\n" - assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "a" + assert mock_protocol(data).read_http_body(headers, None, "GET", None, True) == "a" data = "\r\n" tutils.raises( "closed prematurely", mock_protocol(data).read_http_body, - h, None, "GET", None, True + headers, None, "GET", None, True ) data = "1\r\nfoo" tutils.raises( "malformed chunked body", mock_protocol(data).read_http_body, - h, None, "GET", None, True + headers, None, "GET", None, True ) data = "foo\r\nfoo" tutils.raises( http.HttpError, mock_protocol(data).read_http_body, - h, None, "GET", None, True + headers, None, "GET", None, True ) data = "5\r\naaaaa\r\n0\r\n\r\n" - tutils.raises("too large", mock_protocol(data).read_http_body, h, 2, "GET", None, True) + tutils.raises("too large", mock_protocol(data).read_http_body, headers, 2, "GET", None, True) def test_connection_close(): - h = odict.ODictCaseless() - assert HTTP1Protocol.connection_close((1, 0), h) - assert not HTTP1Protocol.connection_close((1, 1), h) + headers = Headers() + assert HTTP1Protocol.connection_close((1, 0), headers) + assert not HTTP1Protocol.connection_close((1, 1), headers) - h["connection"] = ["keep-alive"] - assert not HTTP1Protocol.connection_close((1, 1), h) + headers["connection"] = "keep-alive" + assert not HTTP1Protocol.connection_close((1, 1), headers) - h["connection"] = ["close"] - assert HTTP1Protocol.connection_close((1, 1), h) + headers["connection"] = "close" + assert HTTP1Protocol.connection_close((1, 1), headers) def test_read_http_body_request(): - h = odict.ODictCaseless() + headers = Headers() data = "testing" - assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "" + assert mock_protocol(data).read_http_body(headers, None, "GET", None, True) == "" def test_read_http_body_response(): - h = odict.ODictCaseless() + headers = Headers() data = "testing" - assert mock_protocol(data).read_http_body(h, None, "GET", 200, False) == "testing" + assert mock_protocol(data).read_http_body(headers, None, "GET", 200, False) == "testing" def test_read_http_body(): # test default case - h = odict.ODictCaseless() - h["content-length"] = [7] + headers = Headers() + headers["content-length"] = "7" data = "testing" - assert mock_protocol(data).read_http_body(h, None, "GET", 200, False) == "testing" + assert mock_protocol(data).read_http_body(headers, None, "GET", 200, False) == "testing" # test content length: invalid header - h["content-length"] = ["foo"] + headers["content-length"] = "foo" data = "testing" tutils.raises( http.HttpError, mock_protocol(data).read_http_body, - h, None, "GET", 200, False + headers, None, "GET", 200, False ) # test content length: invalid header #2 - h["content-length"] = [-1] + headers["content-length"] = "-1" data = "testing" tutils.raises( http.HttpError, mock_protocol(data).read_http_body, - h, None, "GET", 200, False + headers, None, "GET", 200, False ) # test content length: content length > actual content - h["content-length"] = [5] + headers["content-length"] = "5" data = "testing" tutils.raises( http.HttpError, mock_protocol(data).read_http_body, - h, 4, "GET", 200, False + headers, 4, "GET", 200, False ) # test content length: content length < actual content data = "testing" - assert len(mock_protocol(data).read_http_body(h, None, "GET", 200, False)) == 5 + assert len(mock_protocol(data).read_http_body(headers, None, "GET", 200, False)) == 5 # test no content length: limit > actual content - h = odict.ODictCaseless() + headers = Headers() data = "testing" - assert len(mock_protocol(data).read_http_body(h, 100, "GET", 200, False)) == 7 + assert len(mock_protocol(data).read_http_body(headers, 100, "GET", 200, False)) == 7 # test no content length: limit < actual content data = "testing" tutils.raises( http.HttpError, mock_protocol(data).read_http_body, - h, 4, "GET", 200, False + headers, 4, "GET", 200, False ) # test chunked - h = odict.ODictCaseless() - h["transfer-encoding"] = ["chunked"] + headers = Headers() + headers["transfer-encoding"] = "chunked" data = "5\r\naaaaa\r\n0\r\n\r\n" - assert mock_protocol(data).read_http_body(h, 100, "GET", 200, False) == "aaaaa" + assert mock_protocol(data).read_http_body(headers, 100, "GET", 200, False) == "aaaaa" def test_expected_http_body_size(): # gibber in the content-length field - h = odict.ODictCaseless() - h["content-length"] = ["foo"] - assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) is None + headers = Headers(content_length="foo") + assert HTTP1Protocol.expected_http_body_size(headers, False, "GET", 200) is None # negative number in the content-length field - h = odict.ODictCaseless() - h["content-length"] = ["-7"] - assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) is None + headers = Headers(content_length="-7") + assert HTTP1Protocol.expected_http_body_size(headers, False, "GET", 200) is None # explicit length - h = odict.ODictCaseless() - h["content-length"] = ["5"] - assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) == 5 + headers = Headers(content_length="5") + assert HTTP1Protocol.expected_http_body_size(headers, False, "GET", 200) == 5 # no length - h = odict.ODictCaseless() - assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) == -1 + headers = Headers() + assert HTTP1Protocol.expected_http_body_size(headers, False, "GET", 200) == -1 # no length request - h = odict.ODictCaseless() - assert HTTP1Protocol.expected_http_body_size(h, True, "GET", None) == 0 + headers = Headers() + assert HTTP1Protocol.expected_http_body_size(headers, True, "GET", None) == 0 def test_get_request_line(): @@ -265,8 +262,8 @@ class TestReadHeaders: Header2: two \r\n """ - h = self._read(data) - assert h.lst == [["Header", "one"], ["Header2", "two"]] + headers = self._read(data) + assert headers.fields == [["Header", "one"], ["Header2", "two"]] def test_read_multi(self): data = """ @@ -274,8 +271,8 @@ class TestReadHeaders: Header: two \r\n """ - h = self._read(data) - assert h.lst == [["Header", "one"], ["Header", "two"]] + headers = self._read(data) + assert headers.fields == [["Header", "one"], ["Header", "two"]] def test_read_continued(self): data = """ @@ -284,8 +281,8 @@ class TestReadHeaders: Header2: three \r\n """ - h = self._read(data) - assert h.lst == [["Header", "one\r\n two"], ["Header2", "three"]] + headers = self._read(data) + assert headers.fields == [["Header", "one\r\n two"], ["Header2", "three"]] def test_read_continued_err(self): data = "\tfoo: bar\r\n" @@ -389,7 +386,7 @@ class TestReadResponse(object): HTTP/1.1 200 """ assert self.tst(data, "GET", None) == http.Response( - (1, 1), 200, '', odict.ODictCaseless(), '' + (1, 1), 200, '', Headers(), '' ) def test_simple_message(self): @@ -397,7 +394,7 @@ class TestReadResponse(object): HTTP/1.1 200 OK """ assert self.tst(data, "GET", None) == http.Response( - (1, 1), 200, 'OK', odict.ODictCaseless(), '' + (1, 1), 200, 'OK', Headers(), '' ) def test_invalid_http_version(self): @@ -419,7 +416,7 @@ class TestReadResponse(object): HTTP/1.1 200 OK """ assert self.tst(data, "GET", None) == http.Response( - (1, 1), 100, 'CONTINUE', odict.ODictCaseless(), '' + (1, 1), 100, 'CONTINUE', Headers(), '' ) def test_simple_body(self): diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 8810894f..2b7d7958 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -1,8 +1,8 @@ import OpenSSL import mock -from netlib import tcp, odict, http, tutils -from netlib.http import http2 +from netlib import tcp, http, tutils +from netlib.http import http2, Headers from netlib.http.http2 import HTTP2Protocol from netlib.http.http2.frame import * from ... import tservers @@ -229,11 +229,11 @@ class TestCreateHeaders(): c = tcp.TCPClient(("127.0.0.1", 0)) def test_create_headers(self): - headers = [ + headers = http.Headers([ (b':method', b'GET'), (b':path', b'index.html'), (b':scheme', b'https'), - (b'foo', b'bar')] + (b'foo', b'bar')]) bytes = HTTP2Protocol(self.c)._create_headers( headers, 1, end_stream=True) @@ -248,12 +248,12 @@ class TestCreateHeaders(): .decode('hex') def test_create_headers_multiple_frames(self): - headers = [ + headers = http.Headers([ (b':method', b'GET'), (b':path', b'/'), (b':scheme', b'https'), (b'foo', b'bar'), - (b'server', b'version')] + (b'server', b'version')]) protocol = HTTP2Protocol(self.c) protocol.http2_settings[SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] = 8 @@ -309,7 +309,7 @@ class TestReadRequest(tservers.ServerTestBase): req = protocol.read_request() assert req.stream_id - assert req.headers.lst == [[u':method', u'GET'], [u':path', u'/'], [u':scheme', u'https']] + assert req.headers.fields == [[':method', 'GET'], [':path', '/'], [':scheme', 'https']] assert req.body == b'foobar' @@ -415,7 +415,7 @@ class TestReadResponse(tservers.ServerTestBase): assert resp.httpversion == (2, 0) assert resp.status_code == 200 assert resp.msg == "" - assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] + assert resp.headers.fields == [[':status', '200'], ['etag', 'foobar']] assert resp.body == b'foobar' assert resp.timestamp_end @@ -442,7 +442,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): assert resp.httpversion == (2, 0) assert resp.status_code == 200 assert resp.msg == "" - assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] + assert resp.headers.fields == [[':status', '200'], ['etag', 'foobar']] assert resp.body == b'' @@ -490,7 +490,7 @@ class TestAssembleRequest(object): '', '/', (2, 0), - odict.ODictCaseless([('foo', 'bar')]), + http.Headers([('foo', 'bar')]), 'foobar', )) assert len(bytes) == 2 @@ -528,7 +528,7 @@ class TestAssembleResponse(object): (2, 0), 200, '', - odict.ODictCaseless([('foo', 'bar')]), + Headers(foo="bar"), 'foobar' )) assert len(bytes) == 2 diff --git a/test/http/test_authentication.py b/test/http/test_authentication.py index 5261e029..17c91fe5 100644 --- a/test/http/test_authentication.py +++ b/test/http/test_authentication.py @@ -1,18 +1,18 @@ import binascii -from netlib import odict, http, tutils -from netlib.http import authentication +from netlib import tutils +from netlib.http import authentication, Headers def test_parse_http_basic_auth(): vals = ("basic", "foo", "bar") - assert http.authentication.parse_http_basic_auth( - http.authentication.assemble_http_basic_auth(*vals) + assert authentication.parse_http_basic_auth( + authentication.assemble_http_basic_auth(*vals) ) == vals - assert not http.authentication.parse_http_basic_auth("") - assert not http.authentication.parse_http_basic_auth("foo bar") + assert not authentication.parse_http_basic_auth("") + assert not authentication.parse_http_basic_auth("foo bar") v = "basic " + binascii.b2a_base64("foo") - assert not http.authentication.parse_http_basic_auth(v) + assert not authentication.parse_http_basic_auth(v) class TestPassManNonAnon: @@ -65,35 +65,35 @@ class TestBasicProxyAuth: def test_simple(self): ba = authentication.BasicProxyAuth(authentication.PassManNonAnon(), "test") - h = odict.ODictCaseless() + headers = Headers() assert ba.auth_challenge_headers() - assert not ba.authenticate(h) + assert not ba.authenticate(headers) def test_authenticate_clean(self): ba = authentication.BasicProxyAuth(authentication.PassManNonAnon(), "test") - hdrs = odict.ODictCaseless() + headers = Headers() vals = ("basic", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [authentication.assemble_http_basic_auth(*vals)] - assert ba.authenticate(hdrs) + headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals) + assert ba.authenticate(headers) - ba.clean(hdrs) - assert not ba.AUTH_HEADER in hdrs + ba.clean(headers) + assert not ba.AUTH_HEADER in headers - hdrs[ba.AUTH_HEADER] = [""] - assert not ba.authenticate(hdrs) + headers[ba.AUTH_HEADER] = "" + assert not ba.authenticate(headers) - hdrs[ba.AUTH_HEADER] = ["foo"] - assert not ba.authenticate(hdrs) + headers[ba.AUTH_HEADER] = "foo" + assert not ba.authenticate(headers) vals = ("foo", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [authentication.assemble_http_basic_auth(*vals)] - assert not ba.authenticate(hdrs) + headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals) + assert not ba.authenticate(headers) ba = authentication.BasicProxyAuth(authentication.PassMan(), "test") vals = ("basic", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [authentication.assemble_http_basic_auth(*vals)] - assert not ba.authenticate(hdrs) + headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals) + assert not ba.authenticate(headers) class Bunch: diff --git a/test/http/test_exceptions.py b/test/http/test_exceptions.py index d7c438f7..49588d0a 100644 --- a/test/http/test_exceptions.py +++ b/test/http/test_exceptions.py @@ -1,26 +1,6 @@ from netlib.http.exceptions import * -from netlib import odict class TestHttpError: def test_simple(self): e = HttpError(404, "Not found") assert str(e) - -class TestHttpAuthenticationError: - def test_init(self): - headers = odict.ODictCaseless([("foo", "bar")]) - x = HttpAuthenticationError(headers) - assert str(x) - assert isinstance(x.headers, odict.ODictCaseless) - assert x.code == 407 - assert x.headers == headers - assert "foo" in x.headers.keys() - - def test_header_conversion(self): - headers = {"foo": "bar"} - x = HttpAuthenticationError(headers) - assert isinstance(x.headers, odict.ODictCaseless) - assert x.headers.lst == headers.items() - - def test_repr(self): - assert repr(HttpAuthenticationError()) == "Proxy Authentication Required" diff --git a/test/http/test_semantics.py b/test/http/test_semantics.py index 74743eff..22fe992c 100644 --- a/test/http/test_semantics.py +++ b/test/http/test_semantics.py @@ -33,7 +33,7 @@ class TestRequest(object): r = tutils.treq() assert repr(r) - def test_headers_odict(self): + def test_headers(self): tutils.raises(AssertionError, semantics.Request, 'form_in', 'method', @@ -54,7 +54,7 @@ class TestRequest(object): 'path', (1, 1), ) - assert isinstance(req.headers, odict.ODictCaseless) + assert isinstance(req.headers, http.Headers) def test_equal(self): a = tutils.treq() @@ -76,30 +76,30 @@ class TestRequest(object): def test_anticache(self): req = tutils.treq() - req.headers.add("If-Modified-Since", "foo") - req.headers.add("If-None-Match", "bar") + req.headers["If-Modified-Since"] = "foo" + req.headers["If-None-Match"] = "bar" req.anticache() assert "If-Modified-Since" not in req.headers assert "If-None-Match" not in req.headers def test_anticomp(self): req = tutils.treq() - req.headers.add("Accept-Encoding", "foobar") + req.headers["Accept-Encoding"] = "foobar" req.anticomp() - assert req.headers["Accept-Encoding"] == ["identity"] + assert req.headers["Accept-Encoding"] == "identity" def test_constrain_encoding(self): req = tutils.treq() - req.headers.add("Accept-Encoding", "identity, gzip, foo") + req.headers["Accept-Encoding"] = "identity, gzip, foo" req.constrain_encoding() - assert "foo" not in req.headers.get_first("Accept-Encoding") + assert "foo" not in req.headers["Accept-Encoding"] def test_update_host(self): req = tutils.treq() - req.headers.add("Host", "") + req.headers["Host"] = "" req.host = "foobar" req.update_host_header() - assert req.headers.get_first("Host") == "foobar" + assert req.headers["Host"] == "foobar" def test_get_form(self): req = tutils.treq() @@ -113,7 +113,7 @@ class TestRequest(object): req = tutils.treq() req.body = "foobar" - req.headers["Content-Type"] = [semantics.HDR_FORM_URLENCODED] + req.headers["Content-Type"] = semantics.HDR_FORM_URLENCODED req.get_form() assert req.get_form_urlencoded.called assert not req.get_form_multipart.called @@ -123,7 +123,7 @@ class TestRequest(object): def test_get_form_with_multipart(self, mock_method_urlencoded, mock_method_multipart): req = tutils.treq() req.body = "foobar" - req.headers["Content-Type"] = [semantics.HDR_FORM_MULTIPART] + req.headers["Content-Type"] = semantics.HDR_FORM_MULTIPART req.get_form() assert not req.get_form_urlencoded.called assert req.get_form_multipart.called @@ -132,23 +132,25 @@ class TestRequest(object): req = tutils.treq("foobar") assert req.get_form_urlencoded() == odict.ODict() - req.headers["Content-Type"] = [semantics.HDR_FORM_URLENCODED] + req.headers["Content-Type"] = semantics.HDR_FORM_URLENCODED assert req.get_form_urlencoded() == odict.ODict(utils.urldecode(req.body)) def test_get_form_multipart(self): req = tutils.treq("foobar") assert req.get_form_multipart() == odict.ODict() - req.headers["Content-Type"] = [semantics.HDR_FORM_MULTIPART] + req.headers["Content-Type"] = semantics.HDR_FORM_MULTIPART assert req.get_form_multipart() == odict.ODict( utils.multipartdecode( req.headers, - req.body)) + req.body + ) + ) def test_set_form_urlencoded(self): req = tutils.treq() req.set_form_urlencoded(odict.ODict([('foo', 'bar'), ('rab', 'oof')])) - assert req.headers.get_first("Content-Type") == semantics.HDR_FORM_URLENCODED + assert req.headers["Content-Type"] == semantics.HDR_FORM_URLENCODED assert req.body def test_get_path_components(self): @@ -176,7 +178,7 @@ class TestRequest(object): r = tutils.treq() assert r.pretty_host(True) == "address" assert r.pretty_host(False) == "address" - r.headers["host"] = ["other"] + r.headers["host"] = "other" assert r.pretty_host(True) == "other" assert r.pretty_host(False) == "address" r.host = None @@ -187,7 +189,7 @@ class TestRequest(object): assert r.pretty_host(False) is None # Invalid IDNA - r.headers["host"] = [".disqus.com"] + r.headers["host"] = ".disqus.com" assert r.pretty_host(True) == ".disqus.com" def test_pretty_url(self): @@ -201,49 +203,37 @@ class TestRequest(object): assert req.pretty_url(False) == "http://address:22/path" def test_get_cookies_none(self): - h = odict.ODictCaseless() + headers = http.Headers() r = tutils.treq() - r.headers = h + r.headers = headers assert len(r.get_cookies()) == 0 def test_get_cookies_single(self): - h = odict.ODictCaseless() - h["Cookie"] = ["cookiename=cookievalue"] r = tutils.treq() - r.headers = h + r.headers = http.Headers(cookie="cookiename=cookievalue") result = r.get_cookies() assert len(result) == 1 assert result['cookiename'] == ['cookievalue'] def test_get_cookies_double(self): - h = odict.ODictCaseless() - h["Cookie"] = [ - "cookiename=cookievalue;othercookiename=othercookievalue" - ] r = tutils.treq() - r.headers = h + r.headers = http.Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue") result = r.get_cookies() assert len(result) == 2 assert result['cookiename'] == ['cookievalue'] assert result['othercookiename'] == ['othercookievalue'] def test_get_cookies_withequalsign(self): - h = odict.ODictCaseless() - h["Cookie"] = [ - "cookiename=coo=kievalue;othercookiename=othercookievalue" - ] r = tutils.treq() - r.headers = h + r.headers = http.Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue") result = r.get_cookies() assert len(result) == 2 assert result['cookiename'] == ['coo=kievalue'] assert result['othercookiename'] == ['othercookievalue'] def test_set_cookies(self): - h = odict.ODictCaseless() - h["Cookie"] = ["cookiename=cookievalue"] r = tutils.treq() - r.headers = h + r.headers = http.Headers(cookie="cookiename=cookievalue") result = r.get_cookies() result["cookiename"] = ["foo"] r.set_cookies(result) @@ -348,7 +338,7 @@ class TestEmptyRequest(object): assert req class TestResponse(object): - def test_headers_odict(self): + def test_headers(self): tutils.raises(AssertionError, semantics.Response, (1, 1), 200, @@ -359,7 +349,7 @@ class TestResponse(object): (1, 1), 200, ) - assert isinstance(resp.headers, odict.ODictCaseless) + assert isinstance(resp.headers, http.Headers) def test_equal(self): a = tutils.tresp() @@ -374,32 +364,26 @@ class TestResponse(object): def test_repr(self): r = tutils.tresp() assert "unknown content type" in repr(r) - r.headers["content-type"] = ["foo"] + r.headers["content-type"] = "foo" assert "foo" in repr(r) assert repr(tutils.tresp(content=CONTENT_MISSING)) def test_get_cookies_none(self): - h = odict.ODictCaseless() resp = tutils.tresp() - resp.headers = h + resp.headers = http.Headers() assert not resp.get_cookies() def test_get_cookies_simple(self): - h = odict.ODictCaseless() - h["Set-Cookie"] = ["cookiename=cookievalue"] resp = tutils.tresp() - resp.headers = h + resp.headers = http.Headers(set_cookie="cookiename=cookievalue") result = resp.get_cookies() assert len(result) == 1 assert "cookiename" in result assert result["cookiename"][0] == ["cookievalue", odict.ODict()] def test_get_cookies_with_parameters(self): - h = odict.ODictCaseless() - h["Set-Cookie"] = [ - "cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly"] resp = tutils.tresp() - resp.headers = h + resp.headers = http.Headers(set_cookie="cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly") result = resp.get_cookies() assert len(result) == 1 assert "cookiename" in result @@ -412,12 +396,8 @@ class TestResponse(object): assert attrs["httponly"] == [None] def test_get_cookies_no_value(self): - h = odict.ODictCaseless() - h["Set-Cookie"] = [ - "cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/" - ] resp = tutils.tresp() - resp.headers = h + resp.headers = http.Headers(set_cookie="cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/") result = resp.get_cookies() assert len(result) == 1 assert "cookiename" in result @@ -425,10 +405,11 @@ class TestResponse(object): assert len(result["cookiename"][0][1]) == 2 def test_get_cookies_twocookies(self): - h = odict.ODictCaseless() - h["Set-Cookie"] = ["cookiename=cookievalue", "othercookie=othervalue"] resp = tutils.tresp() - resp.headers = h + resp.headers = http.Headers([ + ["Set-Cookie", "cookiename=cookievalue"], + ["Set-Cookie", "othercookie=othervalue"] + ]) result = resp.get_cookies() assert len(result) == 2 assert "cookiename" in result @@ -453,140 +434,140 @@ class TestHeaders(object): [ ["Host", "example.com"], ["host", "example.org"] - ] + ] ) def test_init(self): - h = semantics.Headers() - assert len(h) == 0 + headers = semantics.Headers() + assert len(headers) == 0 - h = semantics.Headers([["Host", "example.com"]]) - assert len(h) == 1 - assert h["Host"] == "example.com" + headers = semantics.Headers([["Host", "example.com"]]) + assert len(headers) == 1 + assert headers["Host"] == "example.com" - h = semantics.Headers(Host="example.com") - assert len(h) == 1 - assert h["Host"] == "example.com" + headers = semantics.Headers(Host="example.com") + assert len(headers) == 1 + assert headers["Host"] == "example.com" - h = semantics.Headers( + headers = semantics.Headers( [["Host", "invalid"]], Host="example.com" ) - assert len(h) == 1 - assert h["Host"] == "example.com" + assert len(headers) == 1 + assert headers["Host"] == "example.com" - h = semantics.Headers( + headers = semantics.Headers( [["Host", "invalid"], ["Accept", "text/plain"]], Host="example.com" ) - assert len(h) == 2 - assert h["Host"] == "example.com" - assert h["Accept"] == "text/plain" + assert len(headers) == 2 + assert headers["Host"] == "example.com" + assert headers["Accept"] == "text/plain" def test_getitem(self): - h = semantics.Headers(Host="example.com") - assert h["Host"] == "example.com" - assert h["host"] == "example.com" - tutils.raises(KeyError, h.__getitem__, "Accept") + headers = semantics.Headers(Host="example.com") + assert headers["Host"] == "example.com" + assert headers["host"] == "example.com" + tutils.raises(KeyError, headers.__getitem__, "Accept") - h = self._2host() - assert h["Host"] == "example.com, example.org" + headers = self._2host() + assert headers["Host"] == "example.com, example.org" def test_str(self): - h = semantics.Headers(Host="example.com") - assert str(h) == "Host: example.com" + headers = semantics.Headers(Host="example.com") + assert str(headers) == "Host: example.com\r\n" - h = semantics.Headers([ + headers = semantics.Headers([ ["Host", "example.com"], ["Accept", "text/plain"] ]) - assert str(h) == "Host: example.com\r\nAccept: text/plain" + assert str(headers) == "Host: example.com\r\nAccept: text/plain\r\n" def test_setitem(self): - h = semantics.Headers() - h["Host"] = "example.com" - assert "Host" in h - assert "host" in h - assert h["Host"] == "example.com" - - h["host"] = "example.org" - assert "Host" in h - assert "host" in h - assert h["Host"] == "example.org" - - h["accept"] = "text/plain" - assert len(h) == 2 - assert "Accept" in h - assert "Host" in h - - h = self._2host() - assert len(h.fields) == 2 - h["Host"] = "example.com" - assert len(h.fields) == 1 - assert "Host" in h + headers = semantics.Headers() + headers["Host"] = "example.com" + assert "Host" in headers + assert "host" in headers + assert headers["Host"] == "example.com" + + headers["host"] = "example.org" + assert "Host" in headers + assert "host" in headers + assert headers["Host"] == "example.org" + + headers["accept"] = "text/plain" + assert len(headers) == 2 + assert "Accept" in headers + assert "Host" in headers + + headers = self._2host() + assert len(headers.fields) == 2 + headers["Host"] = "example.com" + assert len(headers.fields) == 1 + assert "Host" in headers def test_delitem(self): - h = semantics.Headers(Host="example.com") - assert len(h) == 1 - del h["host"] - assert len(h) == 0 + headers = semantics.Headers(Host="example.com") + assert len(headers) == 1 + del headers["host"] + assert len(headers) == 0 try: - del h["host"] + del headers["host"] except KeyError: assert True else: assert False - h = self._2host() - del h["Host"] - assert len(h) == 0 + headers = self._2host() + del headers["Host"] + assert len(headers) == 0 def test_keys(self): - h = semantics.Headers(Host="example.com") - assert len(h.keys()) == 1 - assert h.keys()[0] == "Host" + headers = semantics.Headers(Host="example.com") + assert len(headers.keys()) == 1 + assert headers.keys()[0] == "Host" - h = self._2host() - assert len(h.keys()) == 1 - assert h.keys()[0] == "Host" + headers = self._2host() + assert len(headers.keys()) == 1 + assert headers.keys()[0] == "Host" def test_eq_ne(self): - h1 = semantics.Headers(Host="example.com") - h2 = semantics.Headers(host="example.com") - assert not (h1 == h2) - assert h1 != h2 + headers1 = semantics.Headers(Host="example.com") + headers2 = semantics.Headers(host="example.com") + assert not (headers1 == headers2) + assert headers1 != headers2 - h1 = semantics.Headers(Host="example.com") - h2 = semantics.Headers(Host="example.com") - assert h1 == h2 - assert not (h1 != h2) + headers1 = semantics.Headers(Host="example.com") + headers2 = semantics.Headers(Host="example.com") + assert headers1 == headers2 + assert not (headers1 != headers2) - assert h1 != None + assert headers1 != 42 def test_get_all(self): - h = self._2host() - assert h.get_all("host") == ["example.com", "example.org"] - assert h.get_all("accept", 42) is 42 + headers = self._2host() + assert headers.get_all("host") == ["example.com", "example.org"] + assert headers.get_all("accept", 42) is 42 def test_set_all(self): - h = semantics.Headers(Host="example.com") - h.set_all("Accept", ["text/plain"]) - assert len(h) == 2 - assert "accept" in h + headers = semantics.Headers(Host="example.com") + headers.set_all("Accept", ["text/plain"]) + assert len(headers) == 2 + assert "accept" in headers - h = self._2host() - h.set_all("Host", ["example.org"]) - assert h["host"] == "example.org" + headers = self._2host() + headers.set_all("Host", ["example.org"]) + assert headers["host"] == "example.org" - h.set_all("Host", ["example.org", "example.net"]) - assert h["host"] == "example.org, example.net" + headers.set_all("Host", ["example.org", "example.net"]) + assert headers["host"] == "example.org, example.net" def test_state(self): - h = self._2host() - assert len(h.get_state()) == 2 - assert h == semantics.Headers.from_state(h.get_state()) - - h2 = semantics.Headers() - assert h != h2 - h2.load_state(h.get_state()) - assert h == h2 + headers = self._2host() + assert len(headers.get_state()) == 2 + assert headers == semantics.Headers.from_state(headers.get_state()) + + headers2 = semantics.Headers() + assert headers != headers2 + headers2.load_state(headers.get_state()) + assert headers == headers2 diff --git a/test/test_utils.py b/test/test_utils.py index fc7174d6..374d09ba 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,5 +1,5 @@ -from netlib import utils, odict, tutils - +from netlib import utils, tutils +from netlib.http import Headers def test_bidi(): b = utils.BiDi(a=1, b=2) @@ -88,20 +88,21 @@ def test_urldecode(): def test_get_header_tokens(): - h = odict.ODictCaseless() - assert utils.get_header_tokens(h, "foo") == [] - h["foo"] = ["bar"] - assert utils.get_header_tokens(h, "foo") == ["bar"] - h["foo"] = ["bar, voing"] - assert utils.get_header_tokens(h, "foo") == ["bar", "voing"] - h["foo"] = ["bar, voing", "oink"] - assert utils.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] + headers = Headers() + assert utils.get_header_tokens(headers, "foo") == [] + headers["foo"] = "bar" + assert utils.get_header_tokens(headers, "foo") == ["bar"] + headers["foo"] = "bar, voing" + assert utils.get_header_tokens(headers, "foo") == ["bar", "voing"] + headers.set_all("foo", ["bar, voing", "oink"]) + assert utils.get_header_tokens(headers, "foo") == ["bar", "voing", "oink"] def test_multipartdecode(): boundary = 'somefancyboundary' - headers = odict.ODict( - [('content-type', ('multipart/form-data; boundary=%s' % boundary))]) + headers = Headers( + content_type='multipart/form-data; boundary=%s' % boundary + ) content = "--{0}\n" \ "Content-Disposition: form-data; name=\"field1\"\n\n" \ "value1\n" \ diff --git a/test/test_wsgi.py b/test/test_wsgi.py index 41572d49..e26e1413 100644 --- a/test/test_wsgi.py +++ b/test/test_wsgi.py @@ -1,12 +1,12 @@ import cStringIO import sys -from netlib import wsgi, odict +from netlib import wsgi +from netlib.http import Headers def tflow(): - h = odict.ODictCaseless() - h["test"] = ["value"] - req = wsgi.Request("http", "GET", "/", h, "") + headers = Headers(test="value") + req = wsgi.Request("http", "GET", "/", headers, "") return wsgi.Flow(("127.0.0.1", 8888), req) diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index be87b20a..57cfd166 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -42,7 +42,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) self.wfile.write(preamble + "\r\n") headers = self.protocol.server_handshake_headers(key) - self.wfile.write(headers.format() + "\r\n") + self.wfile.write(str(headers) + "\r\n") self.wfile.flush() self.handshake_done = True @@ -66,8 +66,8 @@ class WebSocketsClient(tcp.TCPClient): preamble = 'GET / HTTP/1.1' self.wfile.write(preamble + "\r\n") headers = self.protocol.client_handshake_headers() - self.client_nonce = headers.get_first("sec-websocket-key") - self.wfile.write(headers.format() + "\r\n") + self.client_nonce = headers["sec-websocket-key"] + self.wfile.write(str(headers) + "\r\n") self.wfile.flush() resp = http1_protocol.read_response("GET", None) @@ -145,13 +145,13 @@ class TestWebSockets(tservers.ServerTestBase): def test_check_server_handshake(self): headers = self.protocol.server_handshake_headers("key") assert self.protocol.check_server_handshake(headers) - headers["Upgrade"] = ["not_websocket"] + headers["Upgrade"] = "not_websocket" assert not self.protocol.check_server_handshake(headers) def test_check_client_handshake(self): headers = self.protocol.client_handshake_headers("key") assert self.protocol.check_client_handshake(headers) == "key" - headers["Upgrade"] = ["not_websocket"] + headers["Upgrade"] = "not_websocket" assert not self.protocol.check_client_handshake(headers) @@ -166,7 +166,7 @@ class BadHandshakeHandler(WebSocketsEchoHandler): preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) self.wfile.write(preamble + "\r\n") headers = self.protocol.server_handshake_headers("malformed key") - self.wfile.write(headers.format() + "\r\n") + self.wfile.write(str(headers) + "\r\n") self.wfile.flush() self.handshake_done = True -- cgit v1.2.3 From fc86bbd03e7806bf5d3dc0d226b607192642c810 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 8 Sep 2015 15:16:25 +0200 Subject: let Headers inherit from object fixes mitmproxy/mitmproxy#753 --- netlib/http/semantics.py | 6 +++--- test/http/test_semantics.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index edf5fc07..5bb098a7 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -14,7 +14,7 @@ HDR_FORM_MULTIPART = "multipart/form-data" CONTENT_MISSING = 0 -class Headers(UserDict.DictMixin): +class Headers(object, UserDict.DictMixin): """ Header class which allows both convenient access to individual headers as well as direct access to the underlying raw data. Provides a full dictionary interface. @@ -135,7 +135,7 @@ class Headers(UserDict.DictMixin): def __ne__(self, other): return not self.__eq__(other) - def get_all(self, name, default=[]): + def get_all(self, name): """ Like :py:meth:`get`, but does not fold multiple headers into a single one. This is useful for Set-Cookie headers, which do not support folding. @@ -144,7 +144,7 @@ class Headers(UserDict.DictMixin): """ name = name.lower() values = [value for n, value in self.fields if n.lower() == name] - return values or default + return values def set_all(self, name, values): """ diff --git a/test/http/test_semantics.py b/test/http/test_semantics.py index 22fe992c..6dcbbe07 100644 --- a/test/http/test_semantics.py +++ b/test/http/test_semantics.py @@ -547,7 +547,7 @@ class TestHeaders(object): def test_get_all(self): headers = self._2host() assert headers.get_all("host") == ["example.com", "example.org"] - assert headers.get_all("accept", 42) is 42 + assert headers.get_all("accept") == [] def test_set_all(self): headers = semantics.Headers(Host="example.com") -- cgit v1.2.3 From 32b3c32138847cb1f5b0c1958fc9ad0a49f8810f Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 8 Sep 2015 21:31:27 +0200 Subject: add tcp.Address.__hash__ --- netlib/tcp.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/netlib/tcp.py b/netlib/tcp.py index 0d83816b..5c9d26de 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -352,6 +352,9 @@ class Address(object): def __ne__(self, other): return not self.__eq__(other) + def __hash__(self): + return hash(self.address) ^ 42 # different hash than the tuple alone. + def close_socket(sock): """ -- cgit v1.2.3 From a5f7752cf18a9c6b34916107abc89bbdf0050566 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 10 Sep 2015 11:30:17 +0200 Subject: add ssl_read_select --- netlib/tcp.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/netlib/tcp.py b/netlib/tcp.py index 5c9d26de..e9610099 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -356,6 +356,27 @@ class Address(object): return hash(self.address) ^ 42 # different hash than the tuple alone. +def ssl_read_select(rlist, timeout): + """ + This is a wrapper around select.select() which also works for SSL.Connections + by taking ssl_connection.pending() into account. + + Caveats: + If .pending() > 0 for any of the connections in rlist, we avoid the select syscall + and **will not include any other connections which may or may not be ready**. + + Args: + rlist: wait until ready for reading + + Returns: + subset of rlist which is ready for reading. + """ + return [ + conn for conn in rlist + if isinstance(conn, SSL.Connection) and conn.pending() > 0 + ] or select.select(rlist, (), (), timeout)[0] + + def close_socket(sock): """ Does a hard close of a socket, without emitting a RST. -- cgit v1.2.3 From 92c763f469fdf721f3d981346f8a40e33b06de23 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 10 Sep 2015 12:32:38 +0200 Subject: fix mitmproxy/mitmproxy#759 --- netlib/version_check.py | 23 +++++++++++++++++------ test/test_version_check.py | 12 +++++++++++- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/netlib/version_check.py b/netlib/version_check.py index aae4e8c7..1d7e025c 100644 --- a/netlib/version_check.py +++ b/netlib/version_check.py @@ -3,10 +3,11 @@ Having installed a wrong version of pyOpenSSL or netlib is unfortunately a very common source of error. Check before every start that both versions are somewhat okay. """ -from __future__ import division, absolute_import, print_function, unicode_literals +from __future__ import division, absolute_import, print_function import sys import inspect import os.path + import OpenSSL from . import version @@ -28,19 +29,29 @@ def check_mitmproxy_version(mitmproxy_version, fp=sys.stderr): def check_pyopenssl_version(min_version=PYOPENSSL_MIN_VERSION, fp=sys.stderr): - v = tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) + min_version_str = ".".join(str(x) for x in min_version) + try: + v = tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) + except ValueError: + print( + "Cannot parse pyOpenSSL version: {}" + "mitmproxy requires pyOpenSSL {} or greater.".format( + OpenSSL.__version__, min_version_str + ), + file=fp + ) + return if v < min_version: print( - "You are using an outdated version of pyOpenSSL:" - " mitmproxy requires pyOpenSSL %s or greater." % - str(min_version), + "You are using an outdated version of pyOpenSSL: " + "mitmproxy requires pyOpenSSL {} or greater.".format(min_version_str), file=fp ) # Some users apparently have multiple versions of pyOpenSSL installed. # Report which one we got. pyopenssl_path = os.path.dirname(inspect.getfile(OpenSSL)) print( - "Your pyOpenSSL %s installation is located at %s" % ( + "Your pyOpenSSL {} installation is located at {}".format( OpenSSL.__version__, pyopenssl_path ), file=fp diff --git a/test/test_version_check.py b/test/test_version_check.py index a16969d2..9a127814 100644 --- a/test/test_version_check.py +++ b/test/test_version_check.py @@ -24,5 +24,15 @@ def test_check_pyopenssl_version(sexit): assert not sexit.called version_check.check_pyopenssl_version((9999,), fp=fp) - assert fp.getvalue() + assert "outdated" in fp.getvalue() assert sexit.called + + +@mock.patch("sys.exit") +@mock.patch("OpenSSL.__version__") +def test_unparseable_pyopenssl_version(version, sexit): + version.split.return_value = ["foo", "bar"] + fp = cStringIO.StringIO() + version_check.check_pyopenssl_version(fp=fp) + assert "Cannot parse" in fp.getvalue() + assert not sexit.called -- cgit v1.2.3 From a38142d5950a899c6e3f854841a45f4785515761 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 11 Sep 2015 01:17:39 +0200 Subject: don't yield empty chunks --- netlib/http/http1/protocol.py | 2 +- netlib/tcp.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index bf33a18e..cf1dffa3 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -413,9 +413,9 @@ class HTTP1Protocol(semantics.ProtocolMixin): suffix = self.tcp_handler.rfile.readline(5) if suffix != '\r\n': raise HttpError(code, "Malformed chunked body") - yield chunk if length == 0: return + yield chunk @classmethod def _parse_http_protocol(self, line): diff --git a/netlib/tcp.py b/netlib/tcp.py index e9610099..4a7f6153 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -8,6 +8,7 @@ import time import traceback import certifi +import six import OpenSSL from OpenSSL import SSL @@ -295,7 +296,7 @@ class Reader(_FileLike): self.o._raise_ssl_error(self.o._ssl, result) return SSL._ffi.buffer(buf, result)[:] except SSL.Error as e: - raise NetLibSSLError(repr(e)) + six.reraise(NetLibSSLError, NetLibSSLError(str(e)), sys.exc_info()[2]) else: raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") -- cgit v1.2.3 From 997fcde8ce94be9d8decddd4bc783106dbb41ab3 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 12 Sep 2015 17:03:09 +0200 Subject: make clean_bin unicode-aware --- netlib/utils.py | 39 +++++++++++++++++++++++++-------------- netlib/websockets/frame.py | 2 +- setup.py | 1 + test/test_utils.py | 15 +++++++++++---- 4 files changed, 38 insertions(+), 19 deletions(-) diff --git a/netlib/utils.py b/netlib/utils.py index aae187da..d6774419 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -5,6 +5,8 @@ import urllib import urlparse import string import re +import six +import unicodedata def isascii(s): @@ -20,22 +22,31 @@ def bytes_to_int(i): return int(i.encode('hex'), 16) -def cleanBin(s, fixspacing=False): +def clean_bin(s, keep_spacing=True): """ - Cleans binary data to make it safe to display. If fixspacing is True, - tabs, newlines and so forth will be maintained, if not, they will be - replaced with a placeholder. + Cleans binary data to make it safe to display. + + Args: + keep_spacing: If False, tabs and newlines will also be replaced. """ - parts = [] - for i in s: - o = ord(i) - if (o > 31 and o < 127): - parts.append(i) - elif i in "\n\t" and not fixspacing: - parts.append(i) + if isinstance(s, six.text_type): + if keep_spacing: + keep = u" \n\r\t" + else: + keep = u" " + return u"".join( + ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else u"." + for ch in s + ) + else: + if keep_spacing: + keep = b"\n\r\t" else: - parts.append(".") - return "".join(parts) + keep = b"" + return b"".join( + ch if (31 < ord(ch) < 127 or ch in keep) else b"." + for ch in s + ) def hexdump(s): @@ -52,7 +63,7 @@ def hexdump(s): x += " " x += " ".join(" " for i in range(16 - len(part))) parts.append( - (o, x, cleanBin(part, True)) + (o, x, clean_bin(part, False)) ) return parts diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index 1c4a03b2..e3ff1405 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -236,7 +236,7 @@ class Frame(object): def human_readable(self): ret = self.header.human_readable() if self.payload: - ret = ret + "\nPayload:\n" + utils.cleanBin(self.payload) + ret = ret + "\nPayload:\n" + utils.clean_bin(self.payload) return ret def __repr__(self): diff --git a/setup.py b/setup.py index a4da6e69..c24d37c0 100644 --- a/setup.py +++ b/setup.py @@ -46,6 +46,7 @@ setup( "cryptography>=1.0", "passlib>=1.6.2", "hpack>=1.0.1", + "six>=1.9.0", "certifi" ], extras_require={ diff --git a/test/test_utils.py b/test/test_utils.py index 374d09ba..9dba5d35 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -15,10 +15,17 @@ def test_hexdump(): def test_cleanBin(): - assert utils.cleanBin("one") == "one" - assert utils.cleanBin("\00ne") == ".ne" - assert utils.cleanBin("\nne") == "\nne" - assert utils.cleanBin("\nne", True) == ".ne" + assert utils.clean_bin(b"one") == b"one" + assert utils.clean_bin(b"\00ne") == b".ne" + assert utils.clean_bin(b"\nne") == b"\nne" + assert utils.clean_bin(b"\nne", False) == b".ne" + assert utils.clean_bin(u"\u2605".encode("utf8")) == b"..." + + assert utils.clean_bin(u"one") == u"one" + assert utils.clean_bin(u"\00ne") == u".ne" + assert utils.clean_bin(u"\nne") == u"\nne" + assert utils.clean_bin(u"\nne", False) == u".ne" + assert utils.clean_bin(u"\u2605") == u"\u2605" def test_pretty_size(): -- cgit v1.2.3 From 2f9c566e480c377566a0ae044d698a75b45cd54c Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 13 Sep 2015 14:33:45 +0200 Subject: remove pathod as dependency --- requirements.txt | 3 +-- setup.py | 3 --- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index e3ef3a23..aefbcb6d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1 @@ --e git+https://github.com/mitmproxy/pathod.git#egg=pathod --e .[dev] \ No newline at end of file +-e .[dev] diff --git a/setup.py b/setup.py index c24d37c0..d3c09ceb 100644 --- a/setup.py +++ b/setup.py @@ -58,9 +58,6 @@ setup( "autopep8>=1.0.3", "autoflake>=0.6.6", "wheel>=0.24.0", - "pathod>=%s, <%s" % - (version.MINORVERSION, - version.NEXT_MINORVERSION) ] }, ) -- cgit v1.2.3 From 11e7f476bd4bbcd6d072fa3659f628ae3a19705d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 15 Sep 2015 19:12:15 +0200 Subject: wip --- netlib/encoding.py | 8 +- netlib/exceptions.py | 31 ++ netlib/http/__init__.py | 9 +- netlib/http/authentication.py | 4 +- netlib/http/exceptions.py | 9 - netlib/http/http1/__init__.py | 23 +- netlib/http/http1/assemble.py | 105 +++++++ netlib/http/http1/protocol.py | 586 ------------------------------------ netlib/http/http1/read.py | 346 +++++++++++++++++++++ netlib/http/http2/__init__.py | 2 - netlib/http/http2/connections.py | 412 +++++++++++++++++++++++++ netlib/http/http2/frame.py | 633 --------------------------------------- netlib/http/http2/frames.py | 633 +++++++++++++++++++++++++++++++++++++++ netlib/http/http2/protocol.py | 412 ------------------------- netlib/http/models.py | 571 +++++++++++++++++++++++++++++++++++ netlib/http/semantics.py | 632 -------------------------------------- netlib/tcp.py | 8 +- netlib/tutils.py | 70 +++-- netlib/utils.py | 162 ++++++---- netlib/version_check.py | 17 +- netlib/websockets/__init__.py | 4 +- test/http/http1/test_protocol.py | 159 ++++------ test/http/http2/test_frames.py | 4 +- test/http/test_authentication.py | 2 +- test/http/test_semantics.py | 2 +- test/test_encoding.py | 24 +- test/test_utils.py | 75 ++--- test/test_version_check.py | 8 +- test/tservers.py | 8 +- 29 files changed, 2426 insertions(+), 2533 deletions(-) create mode 100644 netlib/exceptions.py delete mode 100644 netlib/http/exceptions.py create mode 100644 netlib/http/http1/assemble.py delete mode 100644 netlib/http/http1/protocol.py create mode 100644 netlib/http/http1/read.py create mode 100644 netlib/http/http2/connections.py delete mode 100644 netlib/http/http2/frame.py create mode 100644 netlib/http/http2/frames.py delete mode 100644 netlib/http/http2/protocol.py create mode 100644 netlib/http/models.py delete mode 100644 netlib/http/semantics.py diff --git a/netlib/encoding.py b/netlib/encoding.py index f107eb5f..06830f2c 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -2,13 +2,13 @@ Utility functions for decoding response bodies. """ from __future__ import absolute_import -import cStringIO +from io import BytesIO import gzip import zlib __ALL__ = ["ENCODINGS"] -ENCODINGS = set(["identity", "gzip", "deflate"]) +ENCODINGS = {"identity", "gzip", "deflate"} def decode(e, content): @@ -42,7 +42,7 @@ def identity(content): def decode_gzip(content): - gfile = gzip.GzipFile(fileobj=cStringIO.StringIO(content)) + gfile = gzip.GzipFile(fileobj=BytesIO(content)) try: return gfile.read() except (IOError, EOFError): @@ -50,7 +50,7 @@ def decode_gzip(content): def encode_gzip(content): - s = cStringIO.StringIO() + s = BytesIO() gf = gzip.GzipFile(fileobj=s, mode='wb') gf.write(content) gf.close() diff --git a/netlib/exceptions.py b/netlib/exceptions.py new file mode 100644 index 00000000..637be3df --- /dev/null +++ b/netlib/exceptions.py @@ -0,0 +1,31 @@ +""" +We try to be very hygienic regarding the exceptions we throw: +Every Exception netlib raises shall be a subclass of NetlibException. + + +See also: http://lucumr.pocoo.org/2014/10/16/on-error-handling/ +""" +from __future__ import absolute_import, print_function, division + + +class NetlibException(Exception): + """ + Base class for all exceptions thrown by netlib. + """ + def __init__(self, message=None): + super(NetlibException, self).__init__(message) + + +class ReadDisconnect(object): + """Immediate EOF""" + + +class HttpException(NetlibException): + pass + + +class HttpReadDisconnect(HttpException, ReadDisconnect): + pass + +class HttpSyntaxException(HttpException): + pass diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 9b4b0e6b..0b1a0bc5 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,2 +1,7 @@ -from exceptions import * -from semantics import * +from .models import Request, Response, Headers, CONTENT_MISSING +from . import http1, http2 + +__all__ = [ + "Request", "Response", "Headers", "CONTENT_MISSING" + "http1", "http2" +] diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index fe1f0d14..2055f843 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -19,8 +19,8 @@ def parse_http_basic_auth(s): def assemble_http_basic_auth(scheme, username, password): - v = binascii.b2a_base64(username + ":" + password) - return scheme + " " + v + v = binascii.b2a_base64(username + b":" + password) + return scheme + b" " + v class NullProxyAuth(object): diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py deleted file mode 100644 index 8a2bbebc..00000000 --- a/netlib/http/exceptions.py +++ /dev/null @@ -1,9 +0,0 @@ -class HttpError(Exception): - - def __init__(self, code, message): - super(HttpError, self).__init__(message) - self.code = code - - -class HttpErrorConnClosed(HttpError): - pass diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py index 6b5043af..4d223f97 100644 --- a/netlib/http/http1/__init__.py +++ b/netlib/http/http1/__init__.py @@ -1 +1,22 @@ -from protocol import * +from .read import ( + read_request, read_request_head, + read_response, read_response_head, + read_message_body, read_message_body_chunked, + connection_close, + expected_http_body_size, +) +from .assemble import ( + assemble_request, assemble_request_head, + assemble_response, assemble_response_head, +) + + +__all__ = [ + "read_request", "read_request_head", + "read_response", "read_response_head", + "read_message_body", "read_message_body_chunked", + "connection_close", + "expected_http_body_size", + "assemble_request", "assemble_request_head", + "assemble_response", "assemble_response_head", +] diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py new file mode 100644 index 00000000..a3269eed --- /dev/null +++ b/netlib/http/http1/assemble.py @@ -0,0 +1,105 @@ +from __future__ import absolute_import, print_function, division + +from ... import utils +from ...exceptions import HttpException +from .. import CONTENT_MISSING + + +def assemble_request(request): + if request.body == CONTENT_MISSING: + raise HttpException("Cannot assemble flow with CONTENT_MISSING") + head = assemble_request_head(request) + return head + request.body + + +def assemble_request_head(request): + first_line = _assemble_request_line(request) + headers = _assemble_request_headers(request) + return b"%s\r\n%s\r\n" % (first_line, headers) + + +def assemble_response(response): + if response.body == CONTENT_MISSING: + raise HttpException("Cannot assemble flow with CONTENT_MISSING") + head = assemble_response_head(response) + return head + response.body + + +def assemble_response_head(response): + first_line = _assemble_response_line(response) + headers = _assemble_response_headers(response) + return b"%s\r\n%s\r\n" % (first_line, headers) + + + + +def _assemble_request_line(request, form=None): + if form is None: + form = request.form_out + if form == "relative": + return b"%s %s %s" % ( + request.method, + request.path, + request.httpversion + ) + elif form == "authority": + return b"%s %s:%d %s" % ( + request.method, + request.host, + request.port, + request.httpversion + ) + elif form == "absolute": + return b"%s %s://%s:%s%s %s" % ( + request.method, + request.scheme, + request.host, + request.port, + request.path, + request.httpversion + ) + else: # pragma: nocover + raise RuntimeError("Invalid request form") + + +def _assemble_request_headers(request): + headers = request.headers.copy() + for k in request._headers_to_strip_off: + headers.pop(k, None) + if b"host" not in headers and request.scheme and request.host and request.port: + headers[b"Host"] = utils.hostport( + request.scheme, + request.host, + request.port + ) + + # If content is defined (i.e. not None or CONTENT_MISSING), we always + # add a content-length header. + if request.body or request.body == b"": + headers[b"Content-Length"] = str(len(request.body)).encode("ascii") + + return str(headers) + + +def _assemble_response_line(response): + return b"%s %s %s" % ( + response.httpversion, + response.status_code, + response.msg, + ) + + +def _assemble_response_headers(response, preserve_transfer_encoding=False): + # TODO: Remove preserve_transfer_encoding + headers = response.headers.copy() + for k in response._headers_to_strip_off: + headers.pop(k, None) + if not preserve_transfer_encoding: + headers.pop(b"Transfer-Encoding", None) + + # If body is defined (i.e. not None or CONTENT_MISSING), we always + # add a content-length header. + if response.body or response.body == b"": + headers[b"Content-Length"] = str(len(response.body)).encode("ascii") + + return bytes(headers) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py deleted file mode 100644 index cf1dffa3..00000000 --- a/netlib/http/http1/protocol.py +++ /dev/null @@ -1,586 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import string -import sys -import time - -from ... import utils, tcp, http -from .. import semantics, Headers -from ..exceptions import * - - -class TCPHandler(object): - - def __init__(self, rfile, wfile=None): - self.rfile = rfile - self.wfile = wfile - - -class HTTP1Protocol(semantics.ProtocolMixin): - - ALPN_PROTO_HTTP1 = 'http/1.1' - - def __init__(self, tcp_handler=None, rfile=None, wfile=None): - self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) - - def read_request( - self, - include_body=True, - body_size_limit=None, - allow_empty=False, - ): - """ - Parse an HTTP request from a file stream - - Args: - include_body (bool): Read response body as well - body_size_limit (bool): Maximum body size - wfile (file): If specified, HTTP Expect headers are handled - automatically, by writing a HTTP 100 CONTINUE response to the stream. - - Returns: - Request: The HTTP request - - Raises: - HttpError: If the input is invalid. - """ - timestamp_start = time.time() - if hasattr(self.tcp_handler.rfile, "reset_timestamps"): - self.tcp_handler.rfile.reset_timestamps() - - httpversion, host, port, scheme, method, path, headers, body = ( - None, None, None, None, None, None, None, None) - - request_line = self._get_request_line() - if not request_line: - if allow_empty: - return http.EmptyRequest() - else: - raise tcp.NetLibDisconnect() - - request_line_parts = self._parse_init(request_line) - if not request_line_parts: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - method, path, httpversion = request_line_parts - - if path == '*' or path.startswith("/"): - form_in = "relative" - if not utils.isascii(path): - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - elif method == 'CONNECT': - form_in = "authority" - r = self._parse_init_connect(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - host, port, httpversion = r - path = None - else: - form_in = "absolute" - r = self._parse_init_proxy(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - _, scheme, host, port, path, _ = r - - headers = self.read_headers() - if headers is None: - raise HttpError(400, "Invalid headers") - - expect_header = headers.get("expect", "").lower() - if expect_header == "100-continue" and httpversion == (1, 1): - self.tcp_handler.wfile.write( - 'HTTP/1.1 100 Continue\r\n' - '\r\n' - ) - self.tcp_handler.wfile.flush() - del headers['expect'] - - if include_body: - body = self.read_http_body( - headers, - body_size_limit, - method, - None, - True - ) - - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - - timestamp_end = time.time() - - return http.Request( - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers, - body, - timestamp_start, - timestamp_end, - ) - - def read_response( - self, - request_method, - body_size_limit=None, - include_body=True, - ): - """ - Returns an http.Response - - By default, both response header and body are read. - If include_body=False is specified, body may be one of the - following: - - None, if the response is technically allowed to have a response body - - "", if the response must not have a response body (e.g. it's a - response to a HEAD request) - """ - timestamp_start = time.time() - if hasattr(self.tcp_handler.rfile, "reset_timestamps"): - self.tcp_handler.rfile.reset_timestamps() - - line = self.tcp_handler.rfile.readline() - # Possible leftover from previous message - if line == "\r\n" or line == "\n": - line = self.tcp_handler.rfile.readline() - if not line: - raise HttpErrorConnClosed(502, "Server disconnect.") - parts = self.parse_response_line(line) - if not parts: - raise HttpError(502, "Invalid server response: %s" % repr(line)) - proto, code, msg = parts - httpversion = self._parse_http_protocol(proto) - if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) - headers = self.read_headers() - if headers is None: - raise HttpError(502, "Invalid headers.") - - if include_body: - body = self.read_http_body( - headers, - body_size_limit, - request_method, - code, - False - ) - else: - # if include_body==False then a None body means the body should be - # read separately - body = None - - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - - if include_body: - timestamp_end = time.time() - else: - timestamp_end = None - - return http.Response( - httpversion, - code, - msg, - headers, - body, - timestamp_start=timestamp_start, - timestamp_end=timestamp_end, - ) - - def assemble_request(self, request): - assert isinstance(request, semantics.Request) - - if request.body == semantics.CONTENT_MISSING: - raise http.HttpError( - 502, - "Cannot assemble flow with CONTENT_MISSING" - ) - first_line = self._assemble_request_first_line(request) - headers = self._assemble_request_headers(request) - return "%s\r\n%s\r\n%s" % (first_line, headers, request.body) - - def assemble_response(self, response): - assert isinstance(response, semantics.Response) - - if response.body == semantics.CONTENT_MISSING: - raise http.HttpError( - 502, - "Cannot assemble flow with CONTENT_MISSING" - ) - first_line = self._assemble_response_first_line(response) - headers = self._assemble_response_headers(response) - return "%s\r\n%s\r\n%s" % (first_line, headers, response.body) - - def read_headers(self): - """ - Read a set of headers. - Stop once a blank line is reached. - - Return a Header object, or None if headers are invalid. - """ - ret = [] - while True: - line = self.tcp_handler.rfile.readline() - if not line or line == '\r\n' or line == '\n': - break - if line[0] in ' \t': - if not ret: - return None - # continued header - ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() - else: - i = line.find(':') - # We're being liberal in what we accept, here. - if i > 0: - name = line[:i] - value = line[i + 1:].strip() - ret.append([name, value]) - else: - return None - return Headers(ret) - - - def read_http_body(self, *args, **kwargs): - return "".join(self.read_http_body_chunked(*args, **kwargs)) - - - def read_http_body_chunked( - self, - headers, - limit, - request_method, - response_code, - is_request, - max_chunk_size=None - ): - """ - Read an HTTP message body: - headers: A Header object - limit: Size limit. - is_request: True if the body to read belongs to a request, False - otherwise - """ - if max_chunk_size is None: - max_chunk_size = limit or sys.maxsize - - expected_size = self.expected_http_body_size( - headers, is_request, request_method, response_code - ) - - if expected_size is None: - if self.has_chunked_encoding(headers): - # Python 3: yield from - for x in self._read_chunked(limit, is_request): - yield x - else: # pragma: nocover - raise HttpError( - 400 if is_request else 502, - "Content-Length unknown but no chunked encoding" - ) - elif expected_size >= 0: - if limit is not None and expected_size > limit: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s, content-length was %s" % ( - limit, expected_size - ) - ) - bytes_left = expected_size - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = self.tcp_handler.rfile.read(chunk_size) - yield content - bytes_left -= chunk_size - else: - bytes_left = limit or -1 - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = self.tcp_handler.rfile.read(chunk_size) - if not content: - return - yield content - bytes_left -= chunk_size - not_done = self.tcp_handler.rfile.read(1) - if not_done: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s," % limit - ) - - @classmethod - def expected_http_body_size( - self, - headers, - is_request, - request_method, - response_code, - ): - """ - Returns the expected body length: - - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding or invalid - data) - - -1, if all data should be read until end of stream. - - May raise HttpError. - """ - # Determine response size according to - # http://tools.ietf.org/html/rfc7230#section-3.3 - if request_method: - request_method = request_method.upper() - - if (not is_request and ( - request_method == "HEAD" or - (request_method == "CONNECT" and response_code == 200) or - response_code in [204, 304] or - 100 <= response_code <= 199)): - return 0 - if self.has_chunked_encoding(headers): - return None - if "content-length" in headers: - try: - size = int(headers["content-length"]) - if size < 0: - raise ValueError() - return size - except ValueError: - return None - if is_request: - return 0 - return -1 - - - @classmethod - def has_chunked_encoding(self, headers): - return "chunked" in headers.get("transfer-encoding", "").lower() - - - def _get_request_line(self): - """ - Get a line, possibly preceded by a blank. - """ - line = self.tcp_handler.rfile.readline() - if line == "\r\n" or line == "\n": - # Possible leftover from previous message - line = self.tcp_handler.rfile.readline() - return line - - def _read_chunked(self, limit, is_request): - """ - Read a chunked HTTP body. - - May raise HttpError. - """ - # FIXME: Should check if chunked is the final encoding in the headers - # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 - # 3.3 2. - total = 0 - code = 400 if is_request else 502 - while True: - line = self.tcp_handler.rfile.readline(128) - if line == "": - raise HttpErrorConnClosed(code, "Connection closed prematurely") - if line != '\r\n' and line != '\n': - try: - length = int(line, 16) - except ValueError: - raise HttpError( - code, - "Invalid chunked encoding length: %s" % line - ) - total += length - if limit is not None and total > limit: - msg = "HTTP Body too large. Limit is %s," \ - " chunked content longer than %s" % (limit, total) - raise HttpError(code, msg) - chunk = self.tcp_handler.rfile.read(length) - suffix = self.tcp_handler.rfile.readline(5) - if suffix != '\r\n': - raise HttpError(code, "Malformed chunked body") - if length == 0: - return - yield chunk - - @classmethod - def _parse_http_protocol(self, line): - """ - Parse an HTTP protocol declaration. - Returns a (major, minor) tuple, or None. - """ - if not line.startswith("HTTP/"): - return None - _, version = line.split('/', 1) - if "." not in version: - return None - major, minor = version.split('.', 1) - try: - major = int(major) - minor = int(minor) - except ValueError: - return None - return major, minor - - @classmethod - def _parse_init(self, line): - try: - method, url, protocol = string.split(line) - except ValueError: - return None - httpversion = self._parse_http_protocol(protocol) - if not httpversion: - return None - if not utils.isascii(method): - return None - return method, url, httpversion - - @classmethod - def _parse_init_connect(self, line): - """ - Returns (host, port, httpversion) if line is a valid CONNECT line. - http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 - """ - v = self._parse_init(line) - if not v: - return None - method, url, httpversion = v - - if method.upper() != 'CONNECT': - return None - try: - host, port = url.split(":") - except ValueError: - return None - try: - port = int(port) - except ValueError: - return None - if not utils.is_valid_port(port): - return None - if not utils.is_valid_host(host): - return None - return host, port, httpversion - - @classmethod - def _parse_init_proxy(self, line): - v = self._parse_init(line) - if not v: - return None - method, url, httpversion = v - - parts = utils.parse_url(url) - if not parts: - return None - scheme, host, port, path = parts - return method, scheme, host, port, path, httpversion - - @classmethod - def _parse_init_http(self, line): - """ - Returns (method, url, httpversion) - """ - v = self._parse_init(line) - if not v: - return None - method, url, httpversion = v - if not utils.isascii(url): - return None - if not (url.startswith("/") or url == "*"): - return None - return method, url, httpversion - - @classmethod - def connection_close(self, httpversion, headers): - """ - Checks the message to see if the client connection should be closed - according to RFC 2616 Section 8.1 Note that a connection should be - closed as well if the response has been read until end of the stream. - """ - # At first, check if we have an explicit Connection header. - if "connection" in headers: - toks = utils.get_header_tokens(headers, "connection") - if "close" in toks: - return True - elif "keep-alive" in toks: - return False - - # If we don't have a Connection header, HTTP 1.1 connections are assumed to - # be persistent - return httpversion != (1, 1) - - @classmethod - def parse_response_line(self, line): - parts = line.strip().split(" ", 2) - if len(parts) == 2: # handle missing message gracefully - parts.append("") - if len(parts) != 3: - return None - proto, code, msg = parts - try: - code = int(code) - except ValueError: - return None - return (proto, code, msg) - - @classmethod - def _assemble_request_first_line(self, request): - return request.legacy_first_line() - - def _assemble_request_headers(self, request): - headers = request.headers.copy() - for k in request._headers_to_strip_off: - headers.pop(k, None) - if 'host' not in headers and request.scheme and request.host and request.port: - headers["Host"] = utils.hostport( - request.scheme, - request.host, - request.port - ) - - # If content is defined (i.e. not None or CONTENT_MISSING), we always - # add a content-length header. - if request.body or request.body == "": - headers["Content-Length"] = str(len(request.body)) - - return str(headers) - - def _assemble_response_first_line(self, response): - return 'HTTP/%s.%s %s %s' % ( - response.httpversion[0], - response.httpversion[1], - response.status_code, - response.msg, - ) - - def _assemble_response_headers( - self, - response, - preserve_transfer_encoding=False, - ): - headers = response.headers.copy() - for k in response._headers_to_strip_off: - headers.pop(k, None) - if not preserve_transfer_encoding: - headers.pop('Transfer-Encoding', None) - - # If body is defined (i.e. not None or CONTENT_MISSING), we always - # add a content-length header. - if response.body or response.body == "": - headers["Content-Length"] = str(len(response.body)) - - return str(headers) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py new file mode 100644 index 00000000..573bc739 --- /dev/null +++ b/netlib/http/http1/read.py @@ -0,0 +1,346 @@ +from __future__ import absolute_import, print_function, division +import time +import sys +import re + +from ... import utils +from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException +from .. import Request, Response, Headers + +ALPN_PROTO_HTTP1 = 'http/1.1' + + +def read_request(rfile, body_size_limit=None): + request = read_request_head(rfile) + request.body = read_message_body(rfile, request, limit=body_size_limit) + request.timestamp_end = time.time() + return request + + +def read_request_head(rfile): + """ + Parse an HTTP request head (request line + headers) from an input stream + + Args: + rfile: The input stream + body_size_limit (bool): Maximum body size + + Returns: + The HTTP request object + + Raises: + HttpReadDisconnect: If no bytes can be read from rfile. + HttpSyntaxException: If the input is invalid. + HttpException: A different error occured. + """ + timestamp_start = time.time() + if hasattr(rfile, "reset_timestamps"): + rfile.reset_timestamps() + + form, method, scheme, host, port, path, http_version = _read_request_line(rfile) + headers = _read_headers(rfile) + + if hasattr(rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = rfile.first_byte_timestamp + + return Request( + form, method, scheme, host, port, path, http_version, headers, None, timestamp_start + ) + + +def read_response(rfile, request, body_size_limit=None): + response = read_response_head(rfile) + response.body = read_message_body(rfile, request, response, body_size_limit) + response.timestamp_end = time.time() + return response + + +def read_response_head(rfile): + timestamp_start = time.time() + if hasattr(rfile, "reset_timestamps"): + rfile.reset_timestamps() + + http_version, status_code, message = _read_response_line(rfile) + headers = _read_headers(rfile) + + if hasattr(rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = rfile.first_byte_timestamp + + return Response( + http_version, + status_code, + message, + headers, + None, + timestamp_start + ) + + +def read_message_body(*args, **kwargs): + chunks = read_message_body_chunked(*args, **kwargs) + return b"".join(chunks) + + +def read_message_body_chunked(rfile, request, response=None, limit=None, max_chunk_size=None): + """ + Read an HTTP message body: + + Args: + If a request body should be read, only request should be passed. + If a response body should be read, both request and response should be passed. + + Raises: + HttpException + """ + if not response: + headers = request.headers + response_code = None + is_request = True + else: + headers = response.headers + response_code = response.status_code + is_request = False + + if not limit or limit < 0: + limit = sys.maxsize + if not max_chunk_size: + max_chunk_size = limit + + expected_size = expected_http_body_size( + headers, is_request, request.method, response_code + ) + + if expected_size is None: + for x in _read_chunked(rfile, limit): + yield x + elif expected_size >= 0: + if limit is not None and expected_size > limit: + raise HttpException( + "HTTP Body too large. " + "Limit is {}, content length was advertised as {}".format(limit, expected_size) + ) + bytes_left = expected_size + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = rfile.read(chunk_size) + yield content + bytes_left -= chunk_size + else: + bytes_left = limit + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = rfile.read(chunk_size) + if not content: + return + yield content + bytes_left -= chunk_size + not_done = rfile.read(1) + if not_done: + raise HttpException("HTTP body too large. Limit is {}.".format(limit)) + + +def connection_close(http_version, headers): + """ + Checks the message to see if the client connection should be closed + according to RFC 2616 Section 8.1. + """ + # At first, check if we have an explicit Connection header. + if b"connection" in headers: + toks = utils.get_header_tokens(headers, "connection") + if b"close" in toks: + return True + elif b"keep-alive" in toks: + return False + + # If we don't have a Connection header, HTTP 1.1 connections are assumed to + # be persistent + return http_version != (1, 1) + + +def expected_http_body_size( + headers, + is_request, + request_method, + response_code, +): + """ + Returns the expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding) + - -1, if all data should be read until end of stream. + + Raises: + HttpSyntaxException, if the content length header is invalid + """ + # Determine response size according to + # http://tools.ietf.org/html/rfc7230#section-3.3 + if request_method: + request_method = request_method.upper() + + is_empty_response = (not is_request and ( + request_method == b"HEAD" or + 100 <= response_code <= 199 or + (response_code == 200 and request_method == b"CONNECT") or + response_code in (204, 304) + )) + + if is_empty_response: + return 0 + if is_request and headers.get(b"expect", b"").lower() == b"100-continue": + return 0 + if b"chunked" in headers.get(b"transfer-encoding", b"").lower(): + return None + if b"content-length" in headers: + try: + size = int(headers[b"content-length"]) + if size < 0: + raise ValueError() + return size + except ValueError: + raise HttpSyntaxException("Unparseable Content Length") + if is_request: + return 0 + return -1 + + +def _get_first_line(rfile): + line = rfile.readline() + if line == b"\r\n" or line == b"\n": + # Possible leftover from previous message + line = rfile.readline() + if not line: + raise HttpReadDisconnect() + return line + + +def _read_request_line(rfile): + line = _get_first_line(rfile) + + try: + method, path, http_version = line.strip().split(b" ") + + if path == b"*" or path.startswith(b"/"): + form = "relative" + path.decode("ascii") # should not raise a ValueError + scheme, host, port = None, None, None + elif method == b"CONNECT": + form = "authority" + host, port = _parse_authority_form(path) + scheme, path = None, None + else: + form = "absolute" + scheme, host, port, path = utils.parse_url(path) + + except ValueError: + raise HttpSyntaxException("Bad HTTP request line: {}".format(line)) + + return form, method, scheme, host, port, path, http_version + + +def _parse_authority_form(hostport): + """ + Returns (host, port) if hostport is a valid authority-form host specification. + http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 + + Raises: + ValueError, if the input is malformed + """ + try: + host, port = hostport.split(b":") + port = int(port) + if not utils.is_valid_host(host) or not utils.is_valid_port(port): + raise ValueError() + except ValueError: + raise ValueError("Invalid host specification: {}".format(hostport)) + + return host, port + + +def _read_response_line(rfile): + line = _get_first_line(rfile) + + try: + + parts = line.strip().split(b" ") + if len(parts) == 2: # handle missing message gracefully + parts.append(b"") + + http_version, status_code, message = parts + status_code = int(status_code) + _check_http_version(http_version) + + except ValueError: + raise HttpSyntaxException("Bad HTTP response line: {}".format(line)) + + return http_version, status_code, message + + +def _check_http_version(http_version): + if not re.match(rb"^HTTP/\d\.\d$", http_version): + raise HttpSyntaxException("Unknown HTTP version: {}".format(http_version)) + + +def _read_headers(rfile): + """ + Read a set of headers. + Stop once a blank line is reached. + + Returns: + A headers object + + Raises: + HttpSyntaxException + """ + ret = [] + while True: + line = rfile.readline() + if not line or line == b"\r\n" or line == b"\n": + break + if line[0] in b" \t": + if not ret: + raise HttpSyntaxException("Invalid headers") + # continued header + ret[-1][1] = ret[-1][1] + b'\r\n ' + line.strip() + else: + try: + name, value = line.split(b":", 1) + value = value.strip() + ret.append([name, value]) + except ValueError: + raise HttpSyntaxException("Invalid headers") + return Headers(ret) + + +def _read_chunked(rfile, limit): + """ + Read a HTTP body with chunked transfer encoding. + + Args: + rfile: the input file + limit: A positive integer + """ + total = 0 + while True: + line = rfile.readline(128) + if line == b"": + raise HttpException("Connection closed prematurely") + if line != b"\r\n" and line != b"\n": + try: + length = int(line, 16) + except ValueError: + raise HttpSyntaxException("Invalid chunked encoding length: {}".format(line)) + total += length + if total > limit: + raise HttpException( + "HTTP Body too large. Limit is {}, " + "chunked content longer than {}".format(limit, total) + ) + chunk = rfile.read(length) + suffix = rfile.readline(5) + if suffix != b"\r\n": + raise HttpSyntaxException("Malformed chunked body") + if length == 0: + return + yield chunk diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py index 5acf7696..e69de29b 100644 --- a/netlib/http/http2/__init__.py +++ b/netlib/http/http2/__init__.py @@ -1,2 +0,0 @@ -from frame import * -from protocol import * diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py new file mode 100644 index 00000000..b6d376d3 --- /dev/null +++ b/netlib/http/http2/connections.py @@ -0,0 +1,412 @@ +from __future__ import (absolute_import, print_function, division) +import itertools +import time + +from hpack.hpack import Encoder, Decoder +from netlib import http, utils +from netlib.http import semantics +from . import frame + + +class TCPHandler(object): + + def __init__(self, rfile, wfile=None): + self.rfile = rfile + self.wfile = wfile + + +class HTTP2Protocol(semantics.ProtocolMixin): + + ERROR_CODES = utils.BiDi( + NO_ERROR=0x0, + PROTOCOL_ERROR=0x1, + INTERNAL_ERROR=0x2, + FLOW_CONTROL_ERROR=0x3, + SETTINGS_TIMEOUT=0x4, + STREAM_CLOSED=0x5, + FRAME_SIZE_ERROR=0x6, + REFUSED_STREAM=0x7, + CANCEL=0x8, + COMPRESSION_ERROR=0x9, + CONNECT_ERROR=0xa, + ENHANCE_YOUR_CALM=0xb, + INADEQUATE_SECURITY=0xc, + HTTP_1_1_REQUIRED=0xd + ) + + CLIENT_CONNECTION_PREFACE = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + + ALPN_PROTO_H2 = 'h2' + + def __init__( + self, + tcp_handler=None, + rfile=None, + wfile=None, + is_server=False, + dump_frames=False, + encoder=None, + decoder=None, + unhandled_frame_cb=None, + ): + self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) + self.is_server = is_server + self.dump_frames = dump_frames + self.encoder = encoder or Encoder() + self.decoder = decoder or Decoder() + self.unhandled_frame_cb = unhandled_frame_cb + + self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() + self.current_stream_id = None + self.connection_preface_performed = False + + def read_request( + self, + include_body=True, + body_size_limit=None, + allow_empty=False, + ): + if body_size_limit is not None: + raise NotImplementedError() + + self.perform_connection_preface() + + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + + stream_id, headers, body = self._receive_transmission( + include_body=include_body, + ) + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + timestamp_end = time.time() + + authority = headers.get(':authority', '') + method = headers.get(':method', 'GET') + scheme = headers.get(':scheme', 'https') + path = headers.get(':path', '/') + host = None + port = None + + if path == '*' or path.startswith("/"): + form_in = "relative" + elif method == 'CONNECT': + form_in = "authority" + if ":" in authority: + host, port = authority.split(":", 1) + else: + host = authority + else: + form_in = "absolute" + # FIXME: verify if path or :host contains what we need + scheme, host, port, _ = utils.parse_url(path) + + if host is None: + host = 'localhost' + if port is None: + port = 80 if scheme == 'http' else 443 + port = int(port) + + request = http.Request( + form_in, + method, + scheme, + host, + port, + path, + (2, 0), + headers, + body, + timestamp_start, + timestamp_end, + ) + # FIXME: We should not do this. + request.stream_id = stream_id + + return request + + def read_response( + self, + request_method='', + body_size_limit=None, + include_body=True, + stream_id=None, + ): + if body_size_limit is not None: + raise NotImplementedError() + + self.perform_connection_preface() + + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + + stream_id, headers, body = self._receive_transmission( + stream_id=stream_id, + include_body=include_body, + ) + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + if include_body: + timestamp_end = time.time() + else: + timestamp_end = None + + response = http.Response( + (2, 0), + int(headers.get(':status', 502)), + "", + headers, + body, + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + ) + response.stream_id = stream_id + + return response + + def assemble_request(self, request): + assert isinstance(request, semantics.Request) + + authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host + if self.tcp_handler.address.port != 443: + authority += ":%d" % self.tcp_handler.address.port + + headers = request.headers.copy() + + if ':authority' not in headers: + headers.fields.insert(0, (':authority', bytes(authority))) + if ':scheme' not in headers: + headers.fields.insert(0, (':scheme', bytes(request.scheme))) + if ':path' not in headers: + headers.fields.insert(0, (':path', bytes(request.path))) + if ':method' not in headers: + headers.fields.insert(0, (':method', bytes(request.method))) + + if hasattr(request, 'stream_id'): + stream_id = request.stream_id + else: + stream_id = self._next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(request.body is None or len(request.body) == 0)), + self._create_body(request.body, stream_id))) + + def assemble_response(self, response): + assert isinstance(response, semantics.Response) + + headers = response.headers.copy() + + if ':status' not in headers: + headers.fields.insert(0, (':status', bytes(str(response.status_code)))) + + if hasattr(response, 'stream_id'): + stream_id = response.stream_id + else: + stream_id = self._next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(response.body is None or len(response.body) == 0)), + self._create_body(response.body, stream_id), + )) + + def perform_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + if self.is_server: + self.perform_server_connection_preface(force) + else: + self.perform_client_connection_preface(force) + + def perform_server_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + magic_length = len(self.CLIENT_CONNECTION_PREFACE) + magic = self.tcp_handler.rfile.safe_read(magic_length) + assert magic == self.CLIENT_CONNECTION_PREFACE + + frm = frame.SettingsFrame(state=self, settings={ + frame.SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 0, + frame.SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 1, + }) + self.send_frame(frm, hide=True) + self._receive_settings(hide=True) + + def perform_client_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) + + self.send_frame(frame.SettingsFrame(state=self), hide=True) + self._receive_settings(hide=True) # server announces own settings + self._receive_settings(hide=True) # server acks my settings + + def send_frame(self, frm, hide=False): + raw_bytes = frm.to_bytes() + self.tcp_handler.wfile.write(raw_bytes) + self.tcp_handler.wfile.flush() + if not hide and self.dump_frames: # pragma no cover + print(frm.human_readable(">>")) + + def read_frame(self, hide=False): + while True: + frm = frame.Frame.from_file(self.tcp_handler.rfile, self) + if not hide and self.dump_frames: # pragma no cover + print(frm.human_readable("<<")) + + if isinstance(frm, frame.PingFrame): + raw_bytes = frame.PingFrame(flags=frame.Frame.FLAG_ACK, payload=frm.payload).to_bytes() + self.tcp_handler.wfile.write(raw_bytes) + self.tcp_handler.wfile.flush() + continue + if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: + self._apply_settings(frm.settings, hide) + if isinstance(frm, frame.DataFrame) and frm.length > 0: + self._update_flow_control_window(frm.stream_id, frm.length) + return frm + + def check_alpn(self): + alp = self.tcp_handler.get_alpn_proto_negotiated() + if alp != self.ALPN_PROTO_H2: + raise NotImplementedError( + "HTTP2Protocol can not handle unknown ALP: %s" % alp) + return True + + def _handle_unexpected_frame(self, frm): + if isinstance(frm, frame.SettingsFrame): + return + if self.unhandled_frame_cb: + self.unhandled_frame_cb(frm) + + def _receive_settings(self, hide=False): + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + break + else: + self._handle_unexpected_frame(frm) + + def _next_stream_id(self): + if self.current_stream_id is None: + if self.is_server: + # servers must use even stream ids + self.current_stream_id = 2 + else: + # clients must use odd stream ids + self.current_stream_id = 1 + else: + self.current_stream_id += 2 + return self.current_stream_id + + def _apply_settings(self, settings, hide=False): + for setting, value in settings.items(): + old_value = self.http2_settings[setting] + if not old_value: + old_value = '-' + self.http2_settings[setting] = value + + frm = frame.SettingsFrame( + state=self, + flags=frame.Frame.FLAG_ACK) + self.send_frame(frm, hide) + + def _update_flow_control_window(self, stream_id, increment): + frm = frame.WindowUpdateFrame(stream_id=0, window_size_increment=increment) + self.send_frame(frm) + frm = frame.WindowUpdateFrame(stream_id=stream_id, window_size_increment=increment) + self.send_frame(frm) + + def _create_headers(self, headers, stream_id, end_stream=True): + def frame_cls(chunks): + for i in chunks: + if i == 0: + yield frame.HeadersFrame, i + else: + yield frame.ContinuationFrame, i + + header_block_fragment = self.encoder.encode(headers.fields) + + chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + chunks = range(0, len(header_block_fragment), chunk_size) + frms = [frm_cls( + state=self, + flags=frame.Frame.FLAG_NO_FLAGS, + stream_id=stream_id, + header_block_fragment=header_block_fragment[i:i+chunk_size]) for frm_cls, i in frame_cls(chunks)] + + last_flags = frame.Frame.FLAG_END_HEADERS + if end_stream: + last_flags |= frame.Frame.FLAG_END_STREAM + frms[-1].flags = last_flags + + if self.dump_frames: # pragma no cover + for frm in frms: + print(frm.human_readable(">>")) + + return [frm.to_bytes() for frm in frms] + + def _create_body(self, body, stream_id): + if body is None or len(body) == 0: + return b'' + + chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + chunks = range(0, len(body), chunk_size) + frms = [frame.DataFrame( + state=self, + flags=frame.Frame.FLAG_NO_FLAGS, + stream_id=stream_id, + payload=body[i:i+chunk_size]) for i in chunks] + frms[-1].flags = frame.Frame.FLAG_END_STREAM + + if self.dump_frames: # pragma no cover + for frm in frms: + print(frm.human_readable(">>")) + + return [frm.to_bytes() for frm in frms] + + def _receive_transmission(self, stream_id=None, include_body=True): + if not include_body: + raise NotImplementedError() + + body_expected = True + + header_block_fragment = b'' + body = b'' + + while True: + frm = self.read_frame() + if ( + (isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame)) and + (stream_id is None or frm.stream_id == stream_id) + ): + stream_id = frm.stream_id + header_block_fragment += frm.header_block_fragment + if frm.flags & frame.Frame.FLAG_END_STREAM: + body_expected = False + if frm.flags & frame.Frame.FLAG_END_HEADERS: + break + else: + self._handle_unexpected_frame(frm) + + while body_expected: + frm = self.read_frame() + if isinstance(frm, frame.DataFrame) and frm.stream_id == stream_id: + body += frm.payload + if frm.flags & frame.Frame.FLAG_END_STREAM: + break + else: + self._handle_unexpected_frame(frm) + + headers = http.Headers( + [[str(k), str(v)] for k, v in self.decoder.decode(header_block_fragment)] + ) + + return stream_id, headers, body diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py deleted file mode 100644 index b36b3adf..00000000 --- a/netlib/http/http2/frame.py +++ /dev/null @@ -1,633 +0,0 @@ -import sys -import struct -from hpack.hpack import Encoder, Decoder - -from .. import utils - - -class FrameSizeError(Exception): - pass - - -class Frame(object): - - """ - Baseclass Frame - contains header - payload is defined in subclasses - """ - - FLAG_NO_FLAGS = 0x0 - FLAG_ACK = 0x1 - FLAG_END_STREAM = 0x1 - FLAG_END_HEADERS = 0x4 - FLAG_PADDED = 0x8 - FLAG_PRIORITY = 0x20 - - def __init__( - self, - state=None, - length=0, - flags=FLAG_NO_FLAGS, - stream_id=0x0): - valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) - if flags | valid_flags != valid_flags: - raise ValueError('invalid flags detected.') - - if state is None: - class State(object): - pass - - state = State() - state.http2_settings = HTTP2_DEFAULT_SETTINGS.copy() - state.encoder = Encoder() - state.decoder = Decoder() - - self.state = state - - self.length = length - self.type = self.TYPE - self.flags = flags - self.stream_id = stream_id - - @classmethod - def _check_frame_size(cls, length, state): - if state: - settings = state.http2_settings - else: - settings = HTTP2_DEFAULT_SETTINGS.copy() - - max_frame_size = settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] - - if length > max_frame_size: - raise FrameSizeError( - "Frame size exceeded: %d, but only %d allowed." % ( - length, max_frame_size)) - - @classmethod - def from_file(cls, fp, state=None): - """ - read a HTTP/2 frame sent by a server or client - fp is a "file like" object that could be backed by a network - stream or a disk or an in memory stream reader - """ - raw_header = fp.safe_read(9) - - fields = struct.unpack("!HBBBL", raw_header) - length = (fields[0] << 8) + fields[1] - flags = fields[3] - stream_id = fields[4] - - if raw_header[:4] == b'HTTP': # pragma no cover - print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" - - cls._check_frame_size(length, state) - - payload = fp.safe_read(length) - return FRAMES[fields[2]].from_bytes( - state, - length, - flags, - stream_id, - payload) - - def to_bytes(self): - payload = self.payload_bytes() - self.length = len(payload) - - self._check_frame_size(self.length, self.state) - - b = struct.pack('!HB', (self.length & 0xFFFF00) >> 8, self.length & 0x0000FF) - b += struct.pack('!B', self.TYPE) - b += struct.pack('!B', self.flags) - b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) - b += payload - - return b - - def payload_bytes(self): # pragma: no cover - raise NotImplementedError() - - def payload_human_readable(self): # pragma: no cover - raise NotImplementedError() - - def human_readable(self, direction="-"): - self.length = len(self.payload_bytes()) - - return "\n".join([ - "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( - direction, self.__class__.__name__, self.length, self.flags, self.stream_id), - self.payload_human_readable(), - "===============================================================", - ]) - - def __eq__(self, other): - return self.to_bytes() == other.to_bytes() - - -class DataFrame(Frame): - TYPE = 0x0 - VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b'', - pad_length=0): - super(DataFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - self.pad_length = pad_length - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.payload = payload[1:-f.pad_length] - else: - f.payload = payload - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('DATA frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += bytes(self.payload) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - return "payload: %s" % str(self.payload) - - -class HeadersFrame(Frame): - TYPE = 0x1 - VALID_FLAGS = [ - Frame.FLAG_END_STREAM, - Frame.FLAG_END_HEADERS, - Frame.FLAG_PADDED, - Frame.FLAG_PRIORITY] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b'', - pad_length=0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(HeadersFrame, self).__init__(state, length, flags, stream_id) - - self.header_block_fragment = header_block_fragment - self.pad_length = pad_length - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.header_block_fragment = payload[1:-f.pad_length] - else: - f.header_block_fragment = payload[0:] - - if f.flags & Frame.FLAG_PRIORITY: - f.stream_dependency, f.weight = struct.unpack( - '!LB', f.header_block_fragment[:5]) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - f.header_block_fragment = f.header_block_fragment[5:] - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('HEADERS frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - if self.flags & self.FLAG_PRIORITY: - b += struct.pack('!LB', - (int(self.exclusive) << 31) | self.stream_dependency, - self.weight) - - b += self.header_block_fragment - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PRIORITY: - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PriorityFrame(Frame): - TYPE = 0x2 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(PriorityFrame, self).__init__(state, length, flags, stream_id) - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.stream_dependency, f.weight = struct.unpack('!LB', payload) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PRIORITY frames MUST be associated with a stream.') - - return struct.pack( - '!LB', - (int( - self.exclusive) << 31) | self.stream_dependency, - self.weight) - - def payload_human_readable(self): - s = [] - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - return "\n".join(s) - - -class RstStreamFrame(Frame): - TYPE = 0x3 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - error_code=0x0): - super(RstStreamFrame, self).__init__(state, length, flags, stream_id) - self.error_code = error_code - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.error_code = struct.unpack('!L', payload)[0] - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'RST_STREAM frames MUST be associated with a stream.') - - return struct.pack('!L', self.error_code) - - def payload_human_readable(self): - return "error code: %#x" % self.error_code - - -class SettingsFrame(Frame): - TYPE = 0x4 - VALID_FLAGS = [Frame.FLAG_ACK] - - SETTINGS = utils.BiDi( - SETTINGS_HEADER_TABLE_SIZE=0x1, - SETTINGS_ENABLE_PUSH=0x2, - SETTINGS_MAX_CONCURRENT_STREAMS=0x3, - SETTINGS_INITIAL_WINDOW_SIZE=0x4, - SETTINGS_MAX_FRAME_SIZE=0x5, - SETTINGS_MAX_HEADER_LIST_SIZE=0x6, - ) - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings=None): - super(SettingsFrame, self).__init__(state, length, flags, stream_id) - - if settings is None: - settings = {} - - self.settings = settings - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - for i in xrange(0, len(payload), 6): - identifier, value = struct.unpack("!HL", payload[i:i + 6]) - f.settings[identifier] = value - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'SETTINGS frames MUST NOT be associated with a stream.') - - b = b'' - for identifier, value in self.settings.items(): - b += struct.pack("!HL", identifier & 0xFF, value) - - return b - - def payload_human_readable(self): - s = [] - - for identifier, value in self.settings.items(): - s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) - - if not s: - return "settings: None" - else: - return "\n".join(s) - - -class PushPromiseFrame(Frame): - TYPE = 0x5 - VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - promised_stream=0x0, - header_block_fragment=b'', - pad_length=0): - super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) - self.pad_length = pad_length - self.promised_stream = promised_stream - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) - f.header_block_fragment = payload[5:-f.pad_length] - else: - f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) - f.header_block_fragment = payload[4:] - - f.promised_stream &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PUSH_PROMISE frames MUST be associated with a stream.') - - if self.promised_stream == 0x0: - raise ValueError('Promised stream id not valid.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) - b += bytes(self.header_block_fragment) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append("promised stream: %#x" % self.promised_stream) - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PingFrame(Frame): - TYPE = 0x6 - VALID_FLAGS = [Frame.FLAG_ACK] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b''): - super(PingFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.payload = payload - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'PING frames MUST NOT be associated with a stream.') - - b = self.payload[0:8] - b += b'\0' * (8 - len(b)) - return b - - def payload_human_readable(self): - return "opaque data: %s" % str(self.payload) - - -class GoAwayFrame(Frame): - TYPE = 0x7 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x0, - error_code=0x0, - data=b''): - super(GoAwayFrame, self).__init__(state, length, flags, stream_id) - self.last_stream = last_stream - self.error_code = error_code - self.data = data - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) - f.last_stream &= 0x7FFFFFFF - f.data = payload[8:] - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'GOAWAY frames MUST NOT be associated with a stream.') - - b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) - b += bytes(self.data) - return b - - def payload_human_readable(self): - s = [] - s.append("last stream: %#x" % self.last_stream) - s.append("error code: %d" % self.error_code) - s.append("debug data: %s" % str(self.data)) - return "\n".join(s) - - -class WindowUpdateFrame(Frame): - TYPE = 0x8 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0x0): - super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) - self.window_size_increment = window_size_increment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.window_size_increment = struct.unpack("!L", payload)[0] - f.window_size_increment &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: - raise ValueError( - 'Window Size Increment MUST be greater than 0 and less than 2^31.') - - return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) - - def payload_human_readable(self): - return "window size increment: %#x" % self.window_size_increment - - -class ContinuationFrame(Frame): - TYPE = 0x9 - VALID_FLAGS = [Frame.FLAG_END_HEADERS] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b''): - super(ContinuationFrame, self).__init__(state, length, flags, stream_id) - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.header_block_fragment = payload - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'CONTINUATION frames MUST be associated with a stream.') - - return self.header_block_fragment - - def payload_human_readable(self): - s = [] - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - return "\n".join(s) - -_FRAME_CLASSES = [ - DataFrame, - HeadersFrame, - PriorityFrame, - RstStreamFrame, - SettingsFrame, - PushPromiseFrame, - PingFrame, - GoAwayFrame, - WindowUpdateFrame, - ContinuationFrame -] -FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} - - -HTTP2_DEFAULT_SETTINGS = { - SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, - SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, -} diff --git a/netlib/http/http2/frames.py b/netlib/http/http2/frames.py new file mode 100644 index 00000000..b36b3adf --- /dev/null +++ b/netlib/http/http2/frames.py @@ -0,0 +1,633 @@ +import sys +import struct +from hpack.hpack import Encoder, Decoder + +from .. import utils + + +class FrameSizeError(Exception): + pass + + +class Frame(object): + + """ + Baseclass Frame + contains header + payload is defined in subclasses + """ + + FLAG_NO_FLAGS = 0x0 + FLAG_ACK = 0x1 + FLAG_END_STREAM = 0x1 + FLAG_END_HEADERS = 0x4 + FLAG_PADDED = 0x8 + FLAG_PRIORITY = 0x20 + + def __init__( + self, + state=None, + length=0, + flags=FLAG_NO_FLAGS, + stream_id=0x0): + valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) + if flags | valid_flags != valid_flags: + raise ValueError('invalid flags detected.') + + if state is None: + class State(object): + pass + + state = State() + state.http2_settings = HTTP2_DEFAULT_SETTINGS.copy() + state.encoder = Encoder() + state.decoder = Decoder() + + self.state = state + + self.length = length + self.type = self.TYPE + self.flags = flags + self.stream_id = stream_id + + @classmethod + def _check_frame_size(cls, length, state): + if state: + settings = state.http2_settings + else: + settings = HTTP2_DEFAULT_SETTINGS.copy() + + max_frame_size = settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + + if length > max_frame_size: + raise FrameSizeError( + "Frame size exceeded: %d, but only %d allowed." % ( + length, max_frame_size)) + + @classmethod + def from_file(cls, fp, state=None): + """ + read a HTTP/2 frame sent by a server or client + fp is a "file like" object that could be backed by a network + stream or a disk or an in memory stream reader + """ + raw_header = fp.safe_read(9) + + fields = struct.unpack("!HBBBL", raw_header) + length = (fields[0] << 8) + fields[1] + flags = fields[3] + stream_id = fields[4] + + if raw_header[:4] == b'HTTP': # pragma no cover + print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" + + cls._check_frame_size(length, state) + + payload = fp.safe_read(length) + return FRAMES[fields[2]].from_bytes( + state, + length, + flags, + stream_id, + payload) + + def to_bytes(self): + payload = self.payload_bytes() + self.length = len(payload) + + self._check_frame_size(self.length, self.state) + + b = struct.pack('!HB', (self.length & 0xFFFF00) >> 8, self.length & 0x0000FF) + b += struct.pack('!B', self.TYPE) + b += struct.pack('!B', self.flags) + b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) + b += payload + + return b + + def payload_bytes(self): # pragma: no cover + raise NotImplementedError() + + def payload_human_readable(self): # pragma: no cover + raise NotImplementedError() + + def human_readable(self, direction="-"): + self.length = len(self.payload_bytes()) + + return "\n".join([ + "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( + direction, self.__class__.__name__, self.length, self.flags, self.stream_id), + self.payload_human_readable(), + "===============================================================", + ]) + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() + + +class DataFrame(Frame): + TYPE = 0x0 + VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b'', + pad_length=0): + super(DataFrame, self).__init__(state, length, flags, stream_id) + self.payload = payload + self.pad_length = pad_length + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.payload = payload[1:-f.pad_length] + else: + f.payload = payload + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('DATA frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += bytes(self.payload) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + return "payload: %s" % str(self.payload) + + +class HeadersFrame(Frame): + TYPE = 0x1 + VALID_FLAGS = [ + Frame.FLAG_END_STREAM, + Frame.FLAG_END_HEADERS, + Frame.FLAG_PADDED, + Frame.FLAG_PRIORITY] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b'', + pad_length=0, + exclusive=False, + stream_dependency=0x0, + weight=0): + super(HeadersFrame, self).__init__(state, length, flags, stream_id) + + self.header_block_fragment = header_block_fragment + self.pad_length = pad_length + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.header_block_fragment = payload[1:-f.pad_length] + else: + f.header_block_fragment = payload[0:] + + if f.flags & Frame.FLAG_PRIORITY: + f.stream_dependency, f.weight = struct.unpack( + '!LB', f.header_block_fragment[:5]) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + f.header_block_fragment = f.header_block_fragment[5:] + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('HEADERS frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + if self.flags & self.FLAG_PRIORITY: + b += struct.pack('!LB', + (int(self.exclusive) << 31) | self.stream_dependency, + self.weight) + + b += self.header_block_fragment + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PRIORITY: + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + + return "\n".join(s) + + +class PriorityFrame(Frame): + TYPE = 0x2 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + exclusive=False, + stream_dependency=0x0, + weight=0): + super(PriorityFrame, self).__init__(state, length, flags, stream_id) + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.stream_dependency, f.weight = struct.unpack('!LB', payload) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'PRIORITY frames MUST be associated with a stream.') + + return struct.pack( + '!LB', + (int( + self.exclusive) << 31) | self.stream_dependency, + self.weight) + + def payload_human_readable(self): + s = [] + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + return "\n".join(s) + + +class RstStreamFrame(Frame): + TYPE = 0x3 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + error_code=0x0): + super(RstStreamFrame, self).__init__(state, length, flags, stream_id) + self.error_code = error_code + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.error_code = struct.unpack('!L', payload)[0] + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'RST_STREAM frames MUST be associated with a stream.') + + return struct.pack('!L', self.error_code) + + def payload_human_readable(self): + return "error code: %#x" % self.error_code + + +class SettingsFrame(Frame): + TYPE = 0x4 + VALID_FLAGS = [Frame.FLAG_ACK] + + SETTINGS = utils.BiDi( + SETTINGS_HEADER_TABLE_SIZE=0x1, + SETTINGS_ENABLE_PUSH=0x2, + SETTINGS_MAX_CONCURRENT_STREAMS=0x3, + SETTINGS_INITIAL_WINDOW_SIZE=0x4, + SETTINGS_MAX_FRAME_SIZE=0x5, + SETTINGS_MAX_HEADER_LIST_SIZE=0x6, + ) + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings=None): + super(SettingsFrame, self).__init__(state, length, flags, stream_id) + + if settings is None: + settings = {} + + self.settings = settings + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + for i in xrange(0, len(payload), 6): + identifier, value = struct.unpack("!HL", payload[i:i + 6]) + f.settings[identifier] = value + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'SETTINGS frames MUST NOT be associated with a stream.') + + b = b'' + for identifier, value in self.settings.items(): + b += struct.pack("!HL", identifier & 0xFF, value) + + return b + + def payload_human_readable(self): + s = [] + + for identifier, value in self.settings.items(): + s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) + + if not s: + return "settings: None" + else: + return "\n".join(s) + + +class PushPromiseFrame(Frame): + TYPE = 0x5 + VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + promised_stream=0x0, + header_block_fragment=b'', + pad_length=0): + super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) + self.pad_length = pad_length + self.promised_stream = promised_stream + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) + f.header_block_fragment = payload[5:-f.pad_length] + else: + f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) + f.header_block_fragment = payload[4:] + + f.promised_stream &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'PUSH_PROMISE frames MUST be associated with a stream.') + + if self.promised_stream == 0x0: + raise ValueError('Promised stream id not valid.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) + b += bytes(self.header_block_fragment) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append("promised stream: %#x" % self.promised_stream) + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + + return "\n".join(s) + + +class PingFrame(Frame): + TYPE = 0x6 + VALID_FLAGS = [Frame.FLAG_ACK] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b''): + super(PingFrame, self).__init__(state, length, flags, stream_id) + self.payload = payload + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.payload = payload + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'PING frames MUST NOT be associated with a stream.') + + b = self.payload[0:8] + b += b'\0' * (8 - len(b)) + return b + + def payload_human_readable(self): + return "opaque data: %s" % str(self.payload) + + +class GoAwayFrame(Frame): + TYPE = 0x7 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x0, + error_code=0x0, + data=b''): + super(GoAwayFrame, self).__init__(state, length, flags, stream_id) + self.last_stream = last_stream + self.error_code = error_code + self.data = data + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) + f.last_stream &= 0x7FFFFFFF + f.data = payload[8:] + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'GOAWAY frames MUST NOT be associated with a stream.') + + b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) + b += bytes(self.data) + return b + + def payload_human_readable(self): + s = [] + s.append("last stream: %#x" % self.last_stream) + s.append("error code: %d" % self.error_code) + s.append("debug data: %s" % str(self.data)) + return "\n".join(s) + + +class WindowUpdateFrame(Frame): + TYPE = 0x8 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + window_size_increment=0x0): + super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) + self.window_size_increment = window_size_increment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.window_size_increment = struct.unpack("!L", payload)[0] + f.window_size_increment &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: + raise ValueError( + 'Window Size Increment MUST be greater than 0 and less than 2^31.') + + return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) + + def payload_human_readable(self): + return "window size increment: %#x" % self.window_size_increment + + +class ContinuationFrame(Frame): + TYPE = 0x9 + VALID_FLAGS = [Frame.FLAG_END_HEADERS] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b''): + super(ContinuationFrame, self).__init__(state, length, flags, stream_id) + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.header_block_fragment = payload + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'CONTINUATION frames MUST be associated with a stream.') + + return self.header_block_fragment + + def payload_human_readable(self): + s = [] + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + return "\n".join(s) + +_FRAME_CLASSES = [ + DataFrame, + HeadersFrame, + PriorityFrame, + RstStreamFrame, + SettingsFrame, + PushPromiseFrame, + PingFrame, + GoAwayFrame, + WindowUpdateFrame, + ContinuationFrame +] +FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} + + +HTTP2_DEFAULT_SETTINGS = { + SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, + SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, +} diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py deleted file mode 100644 index b6d376d3..00000000 --- a/netlib/http/http2/protocol.py +++ /dev/null @@ -1,412 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import itertools -import time - -from hpack.hpack import Encoder, Decoder -from netlib import http, utils -from netlib.http import semantics -from . import frame - - -class TCPHandler(object): - - def __init__(self, rfile, wfile=None): - self.rfile = rfile - self.wfile = wfile - - -class HTTP2Protocol(semantics.ProtocolMixin): - - ERROR_CODES = utils.BiDi( - NO_ERROR=0x0, - PROTOCOL_ERROR=0x1, - INTERNAL_ERROR=0x2, - FLOW_CONTROL_ERROR=0x3, - SETTINGS_TIMEOUT=0x4, - STREAM_CLOSED=0x5, - FRAME_SIZE_ERROR=0x6, - REFUSED_STREAM=0x7, - CANCEL=0x8, - COMPRESSION_ERROR=0x9, - CONNECT_ERROR=0xa, - ENHANCE_YOUR_CALM=0xb, - INADEQUATE_SECURITY=0xc, - HTTP_1_1_REQUIRED=0xd - ) - - CLIENT_CONNECTION_PREFACE = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - - ALPN_PROTO_H2 = 'h2' - - def __init__( - self, - tcp_handler=None, - rfile=None, - wfile=None, - is_server=False, - dump_frames=False, - encoder=None, - decoder=None, - unhandled_frame_cb=None, - ): - self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) - self.is_server = is_server - self.dump_frames = dump_frames - self.encoder = encoder or Encoder() - self.decoder = decoder or Decoder() - self.unhandled_frame_cb = unhandled_frame_cb - - self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() - self.current_stream_id = None - self.connection_preface_performed = False - - def read_request( - self, - include_body=True, - body_size_limit=None, - allow_empty=False, - ): - if body_size_limit is not None: - raise NotImplementedError() - - self.perform_connection_preface() - - timestamp_start = time.time() - if hasattr(self.tcp_handler.rfile, "reset_timestamps"): - self.tcp_handler.rfile.reset_timestamps() - - stream_id, headers, body = self._receive_transmission( - include_body=include_body, - ) - - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - - timestamp_end = time.time() - - authority = headers.get(':authority', '') - method = headers.get(':method', 'GET') - scheme = headers.get(':scheme', 'https') - path = headers.get(':path', '/') - host = None - port = None - - if path == '*' or path.startswith("/"): - form_in = "relative" - elif method == 'CONNECT': - form_in = "authority" - if ":" in authority: - host, port = authority.split(":", 1) - else: - host = authority - else: - form_in = "absolute" - # FIXME: verify if path or :host contains what we need - scheme, host, port, _ = utils.parse_url(path) - - if host is None: - host = 'localhost' - if port is None: - port = 80 if scheme == 'http' else 443 - port = int(port) - - request = http.Request( - form_in, - method, - scheme, - host, - port, - path, - (2, 0), - headers, - body, - timestamp_start, - timestamp_end, - ) - # FIXME: We should not do this. - request.stream_id = stream_id - - return request - - def read_response( - self, - request_method='', - body_size_limit=None, - include_body=True, - stream_id=None, - ): - if body_size_limit is not None: - raise NotImplementedError() - - self.perform_connection_preface() - - timestamp_start = time.time() - if hasattr(self.tcp_handler.rfile, "reset_timestamps"): - self.tcp_handler.rfile.reset_timestamps() - - stream_id, headers, body = self._receive_transmission( - stream_id=stream_id, - include_body=include_body, - ) - - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - - if include_body: - timestamp_end = time.time() - else: - timestamp_end = None - - response = http.Response( - (2, 0), - int(headers.get(':status', 502)), - "", - headers, - body, - timestamp_start=timestamp_start, - timestamp_end=timestamp_end, - ) - response.stream_id = stream_id - - return response - - def assemble_request(self, request): - assert isinstance(request, semantics.Request) - - authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host - if self.tcp_handler.address.port != 443: - authority += ":%d" % self.tcp_handler.address.port - - headers = request.headers.copy() - - if ':authority' not in headers: - headers.fields.insert(0, (':authority', bytes(authority))) - if ':scheme' not in headers: - headers.fields.insert(0, (':scheme', bytes(request.scheme))) - if ':path' not in headers: - headers.fields.insert(0, (':path', bytes(request.path))) - if ':method' not in headers: - headers.fields.insert(0, (':method', bytes(request.method))) - - if hasattr(request, 'stream_id'): - stream_id = request.stream_id - else: - stream_id = self._next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(request.body is None or len(request.body) == 0)), - self._create_body(request.body, stream_id))) - - def assemble_response(self, response): - assert isinstance(response, semantics.Response) - - headers = response.headers.copy() - - if ':status' not in headers: - headers.fields.insert(0, (':status', bytes(str(response.status_code)))) - - if hasattr(response, 'stream_id'): - stream_id = response.stream_id - else: - stream_id = self._next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(response.body is None or len(response.body) == 0)), - self._create_body(response.body, stream_id), - )) - - def perform_connection_preface(self, force=False): - if force or not self.connection_preface_performed: - if self.is_server: - self.perform_server_connection_preface(force) - else: - self.perform_client_connection_preface(force) - - def perform_server_connection_preface(self, force=False): - if force or not self.connection_preface_performed: - self.connection_preface_performed = True - - magic_length = len(self.CLIENT_CONNECTION_PREFACE) - magic = self.tcp_handler.rfile.safe_read(magic_length) - assert magic == self.CLIENT_CONNECTION_PREFACE - - frm = frame.SettingsFrame(state=self, settings={ - frame.SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 0, - frame.SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 1, - }) - self.send_frame(frm, hide=True) - self._receive_settings(hide=True) - - def perform_client_connection_preface(self, force=False): - if force or not self.connection_preface_performed: - self.connection_preface_performed = True - - self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) - - self.send_frame(frame.SettingsFrame(state=self), hide=True) - self._receive_settings(hide=True) # server announces own settings - self._receive_settings(hide=True) # server acks my settings - - def send_frame(self, frm, hide=False): - raw_bytes = frm.to_bytes() - self.tcp_handler.wfile.write(raw_bytes) - self.tcp_handler.wfile.flush() - if not hide and self.dump_frames: # pragma no cover - print(frm.human_readable(">>")) - - def read_frame(self, hide=False): - while True: - frm = frame.Frame.from_file(self.tcp_handler.rfile, self) - if not hide and self.dump_frames: # pragma no cover - print(frm.human_readable("<<")) - - if isinstance(frm, frame.PingFrame): - raw_bytes = frame.PingFrame(flags=frame.Frame.FLAG_ACK, payload=frm.payload).to_bytes() - self.tcp_handler.wfile.write(raw_bytes) - self.tcp_handler.wfile.flush() - continue - if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: - self._apply_settings(frm.settings, hide) - if isinstance(frm, frame.DataFrame) and frm.length > 0: - self._update_flow_control_window(frm.stream_id, frm.length) - return frm - - def check_alpn(self): - alp = self.tcp_handler.get_alpn_proto_negotiated() - if alp != self.ALPN_PROTO_H2: - raise NotImplementedError( - "HTTP2Protocol can not handle unknown ALP: %s" % alp) - return True - - def _handle_unexpected_frame(self, frm): - if isinstance(frm, frame.SettingsFrame): - return - if self.unhandled_frame_cb: - self.unhandled_frame_cb(frm) - - def _receive_settings(self, hide=False): - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - break - else: - self._handle_unexpected_frame(frm) - - def _next_stream_id(self): - if self.current_stream_id is None: - if self.is_server: - # servers must use even stream ids - self.current_stream_id = 2 - else: - # clients must use odd stream ids - self.current_stream_id = 1 - else: - self.current_stream_id += 2 - return self.current_stream_id - - def _apply_settings(self, settings, hide=False): - for setting, value in settings.items(): - old_value = self.http2_settings[setting] - if not old_value: - old_value = '-' - self.http2_settings[setting] = value - - frm = frame.SettingsFrame( - state=self, - flags=frame.Frame.FLAG_ACK) - self.send_frame(frm, hide) - - def _update_flow_control_window(self, stream_id, increment): - frm = frame.WindowUpdateFrame(stream_id=0, window_size_increment=increment) - self.send_frame(frm) - frm = frame.WindowUpdateFrame(stream_id=stream_id, window_size_increment=increment) - self.send_frame(frm) - - def _create_headers(self, headers, stream_id, end_stream=True): - def frame_cls(chunks): - for i in chunks: - if i == 0: - yield frame.HeadersFrame, i - else: - yield frame.ContinuationFrame, i - - header_block_fragment = self.encoder.encode(headers.fields) - - chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] - chunks = range(0, len(header_block_fragment), chunk_size) - frms = [frm_cls( - state=self, - flags=frame.Frame.FLAG_NO_FLAGS, - stream_id=stream_id, - header_block_fragment=header_block_fragment[i:i+chunk_size]) for frm_cls, i in frame_cls(chunks)] - - last_flags = frame.Frame.FLAG_END_HEADERS - if end_stream: - last_flags |= frame.Frame.FLAG_END_STREAM - frms[-1].flags = last_flags - - if self.dump_frames: # pragma no cover - for frm in frms: - print(frm.human_readable(">>")) - - return [frm.to_bytes() for frm in frms] - - def _create_body(self, body, stream_id): - if body is None or len(body) == 0: - return b'' - - chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] - chunks = range(0, len(body), chunk_size) - frms = [frame.DataFrame( - state=self, - flags=frame.Frame.FLAG_NO_FLAGS, - stream_id=stream_id, - payload=body[i:i+chunk_size]) for i in chunks] - frms[-1].flags = frame.Frame.FLAG_END_STREAM - - if self.dump_frames: # pragma no cover - for frm in frms: - print(frm.human_readable(">>")) - - return [frm.to_bytes() for frm in frms] - - def _receive_transmission(self, stream_id=None, include_body=True): - if not include_body: - raise NotImplementedError() - - body_expected = True - - header_block_fragment = b'' - body = b'' - - while True: - frm = self.read_frame() - if ( - (isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame)) and - (stream_id is None or frm.stream_id == stream_id) - ): - stream_id = frm.stream_id - header_block_fragment += frm.header_block_fragment - if frm.flags & frame.Frame.FLAG_END_STREAM: - body_expected = False - if frm.flags & frame.Frame.FLAG_END_HEADERS: - break - else: - self._handle_unexpected_frame(frm) - - while body_expected: - frm = self.read_frame() - if isinstance(frm, frame.DataFrame) and frm.stream_id == stream_id: - body += frm.payload - if frm.flags & frame.Frame.FLAG_END_STREAM: - break - else: - self._handle_unexpected_frame(frm) - - headers = http.Headers( - [[str(k), str(v)] for k, v in self.decoder.decode(header_block_fragment)] - ) - - return stream_id, headers, body diff --git a/netlib/http/models.py b/netlib/http/models.py new file mode 100644 index 00000000..bd5863b1 --- /dev/null +++ b/netlib/http/models.py @@ -0,0 +1,571 @@ +from __future__ import absolute_import, print_function, division +import copy + +from ..odict import ODict +from .. import utils, encoding +from ..utils import always_bytes, always_byte_args +from . import cookies + +import six +from six.moves import urllib +try: + from collections import MutableMapping +except ImportError: + from collections.abc import MutableMapping + +HDR_FORM_URLENCODED = b"application/x-www-form-urlencoded" +HDR_FORM_MULTIPART = b"multipart/form-data" + +CONTENT_MISSING = 0 + + +class Headers(MutableMapping, object): + """ + Header class which allows both convenient access to individual headers as well as + direct access to the underlying raw data. Provides a full dictionary interface. + + Example: + + .. code-block:: python + + # Create header from a list of (header_name, header_value) tuples + >>> h = Headers([ + ["Host","example.com"], + ["Accept","text/html"], + ["accept","application/xml"] + ]) + + # Headers mostly behave like a normal dict. + >>> h["Host"] + "example.com" + + # HTTP Headers are case insensitive + >>> h["host"] + "example.com" + + # Multiple headers are folded into a single header as per RFC7230 + >>> h["Accept"] + "text/html, application/xml" + + # Setting a header removes all existing headers with the same name. + >>> h["Accept"] = "application/text" + >>> h["Accept"] + "application/text" + + # str(h) returns a HTTP1 header block. + >>> print(h) + Host: example.com + Accept: application/text + + # For full control, the raw header fields can be accessed + >>> h.fields + + # Headers can also be crated from keyword arguments + >>> h = Headers(host="example.com", content_type="application/xml") + + Caveats: + For use with the "Set-Cookie" header, see :py:meth:`get_all`. + """ + + @always_byte_args("ascii") + def __init__(self, fields=None, **headers): + """ + Args: + fields: (optional) list of ``(name, value)`` header tuples, + e.g. ``[("Host","example.com")]``. All names and values must be bytes. + **headers: Additional headers to set. Will overwrite existing values from `fields`. + For convenience, underscores in header names will be transformed to dashes - + this behaviour does not extend to other methods. + If ``**headers`` contains multiple keys that have equal ``.lower()`` s, + the behavior is undefined. + """ + self.fields = fields or [] + + # content_type -> content-type + headers = { + name.encode("ascii").replace(b"_", b"-"): value + for name, value in six.iteritems(headers) + } + self.update(headers) + + def __bytes__(self): + return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" + + if six.PY2: + __str__ = __bytes__ + + @always_byte_args("ascii") + def __getitem__(self, name): + values = self.get_all(name) + if not values: + raise KeyError(name) + return b", ".join(values) + + @always_byte_args("ascii") + def __setitem__(self, name, value): + idx = self._index(name) + + # To please the human eye, we insert at the same position the first existing header occured. + if idx is not None: + del self[name] + self.fields.insert(idx, [name, value]) + else: + self.fields.append([name, value]) + + @always_byte_args("ascii") + def __delitem__(self, name): + if name not in self: + raise KeyError(name) + name = name.lower() + self.fields = [ + field for field in self.fields + if name != field[0].lower() + ] + + def __iter__(self): + seen = set() + for name, _ in self.fields: + name_lower = name.lower() + if name_lower not in seen: + seen.add(name_lower) + yield name + + def __len__(self): + return len(set(name.lower() for name, _ in self.fields)) + + #__hash__ = object.__hash__ + + def _index(self, name): + name = name.lower() + for i, field in enumerate(self.fields): + if field[0].lower() == name: + return i + return None + + def __eq__(self, other): + if isinstance(other, Headers): + return self.fields == other.fields + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @always_byte_args("ascii") + def get_all(self, name): + """ + Like :py:meth:`get`, but does not fold multiple headers into a single one. + This is useful for Set-Cookie headers, which do not support folding. + + See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 + """ + name_lower = name.lower() + values = [value for n, value in self.fields if n.lower() == name_lower] + return values + + def set_all(self, name, values): + """ + Explicitly set multiple headers for the given key. + See: :py:meth:`get_all` + """ + name = always_bytes(name, "ascii") + values = (always_bytes(value, "ascii") for value in values) + if name in self: + del self[name] + self.fields.extend( + [name, value] for value in values + ) + + def copy(self): + return Headers(copy.copy(self.fields)) + + # Implement the StateObject protocol from mitmproxy + def get_state(self, short=False): + return tuple(tuple(field) for field in self.fields) + + def load_state(self, state): + self.fields = [list(field) for field in state] + + @classmethod + def from_state(cls, state): + return cls([list(field) for field in state]) + + +class Request(object): + # This list is adopted legacy code. + # We probably don't need to strip off keep-alive. + _headers_to_strip_off = [ + 'Proxy-Connection', + 'Keep-Alive', + 'Connection', + 'Transfer-Encoding', + 'Upgrade', + ] + + def __init__( + self, + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers=None, + body=None, + timestamp_start=None, + timestamp_end=None, + form_out=None + ): + if not headers: + headers = Headers() + assert isinstance(headers, Headers) + + self.form_in = form_in + self.method = method + self.scheme = scheme + self.host = host + self.port = port + self.path = path + self.httpversion = httpversion + self.headers = headers + self.body = body + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + self.form_out = form_out or form_in + + def __eq__(self, other): + try: + self_d = [self.__dict__[k] for k in self.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] + return self_d == other_d + except: + return False + + def __repr__(self): + if self.host and self.port: + hostport = "{}:{}".format(self.host, self.port) + else: + hostport = "" + path = self.path or "" + return "HTTPRequest({} {}{})".format( + self.method, hostport, path + ) + + def anticache(self): + """ + Modifies this request to remove headers that might produce a cached + response. That is, we remove ETags and If-Modified-Since headers. + """ + delheaders = [ + "if-modified-since", + "if-none-match", + ] + for i in delheaders: + self.headers.pop(i, None) + + def anticomp(self): + """ + Modifies this request to remove headers that will compress the + resource's data. + """ + self.headers["accept-encoding"] = "identity" + + def constrain_encoding(self): + """ + Limits the permissible Accept-Encoding values, based on what we can + decode appropriately. + """ + accept_encoding = self.headers.get("accept-encoding") + if accept_encoding: + self.headers["accept-encoding"] = ( + ', '.join( + e + for e in encoding.ENCODINGS + if e in accept_encoding + ) + ) + + def update_host_header(self): + """ + Update the host header to reflect the current target. + """ + self.headers["Host"] = self.host + + def get_form(self): + """ + Retrieves the URL-encoded or multipart form data, returning an ODict object. + Returns an empty ODict if there is no data or the content-type + indicates non-form data. + """ + if self.body: + if HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): + return self.get_form_urlencoded() + elif HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): + return self.get_form_multipart() + return ODict([]) + + def get_form_urlencoded(self): + """ + Retrieves the URL-encoded form data, returning an ODict object. + Returns an empty ODict if there is no data or the content-type + indicates non-form data. + """ + if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): + return ODict(utils.urldecode(self.body)) + return ODict([]) + + def get_form_multipart(self): + if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): + return ODict( + utils.multipartdecode( + self.headers, + self.body)) + return ODict([]) + + def set_form_urlencoded(self, odict): + """ + Sets the body to the URL-encoded form data, and adds the + appropriate content-type header. Note that this will destory the + existing body if there is one. + """ + # FIXME: If there's an existing content-type header indicating a + # url-encoded form, leave it alone. + self.headers["Content-Type"] = HDR_FORM_URLENCODED + self.body = utils.urlencode(odict.lst) + + def get_path_components(self): + """ + Returns the path components of the URL as a list of strings. + + Components are unquoted. + """ + _, _, path, _, _, _ = urllib.parse.urlparse(self.url) + return [urllib.parse.unquote(i) for i in path.split(b"/") if i] + + def set_path_components(self, lst): + """ + Takes a list of strings, and sets the path component of the URL. + + Components are quoted. + """ + lst = [urllib.parse.quote(i, safe="") for i in lst] + path = b"/" + b"/".join(lst) + scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) + self.url = urllib.parse.urlunparse( + [scheme, netloc, path, params, query, fragment] + ) + + def get_query(self): + """ + Gets the request query string. Returns an ODict object. + """ + _, _, _, _, query, _ = urllib.parse.urlparse(self.url) + if query: + return ODict(utils.urldecode(query)) + return ODict([]) + + def set_query(self, odict): + """ + Takes an ODict object, and sets the request query string. + """ + scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) + query = utils.urlencode(odict.lst) + self.url = urllib.parse.urlunparse( + [scheme, netloc, path, params, query, fragment] + ) + + def pretty_host(self, hostheader): + """ + Heuristic to get the host of the request. + + Note that pretty_host() does not always return the TCP destination + of the request, e.g. if an upstream proxy is in place + + If hostheader is set to True, the Host: header will be used as + additional (and preferred) data source. This is handy in + transparent mode, where only the IO of the destination is known, + but not the resolved name. This is disabled by default, as an + attacker may spoof the host header to confuse an analyst. + """ + if hostheader and b"Host" in self.headers: + try: + return self.headers[b"Host"].decode("idna") + except ValueError: + pass + if self.host: + return self.host.decode("idna") + + def pretty_url(self, hostheader): + if self.form_out == "authority": # upstream proxy mode + return "%s:%s" % (self.pretty_host(hostheader), self.port) + return utils.unparse_url(self.scheme, + self.pretty_host(hostheader), + self.port, + self.path).encode('ascii') + + def get_cookies(self): + """ + Returns a possibly empty netlib.odict.ODict object. + """ + ret = ODict() + for i in self.headers.get_all("cookie"): + ret.extend(cookies.parse_cookie_header(i)) + return ret + + def set_cookies(self, odict): + """ + Takes an netlib.odict.ODict object. Over-writes any existing Cookie + headers. + """ + v = cookies.format_cookie_header(odict) + self.headers["Cookie"] = v + + @property + def url(self): + """ + Returns a URL string, constructed from the Request's URL components. + """ + return utils.unparse_url( + self.scheme, + self.host, + self.port, + self.path + ).encode('ascii') + + @url.setter + def url(self, url): + """ + Parses a URL specification, and updates the Request's information + accordingly. + + Raises: + ValueError if the URL was invalid + """ + # TODO: Should handle incoming unicode here. + parts = utils.parse_url(url) + if not parts: + raise ValueError("Invalid URL: %s" % url) + self.scheme, self.host, self.port, self.path = parts + + @property + def content(self): # pragma: no cover + # TODO: remove deprecated getter + return self.body + + @content.setter + def content(self, content): # pragma: no cover + # TODO: remove deprecated setter + self.body = content + + +class Response(object): + _headers_to_strip_off = [ + 'Proxy-Connection', + 'Alternate-Protocol', + 'Alt-Svc', + ] + + def __init__( + self, + httpversion, + status_code, + msg=None, + headers=None, + body=None, + sslinfo=None, + timestamp_start=None, + timestamp_end=None, + ): + if not headers: + headers = Headers() + assert isinstance(headers, Headers) + + self.httpversion = httpversion + self.status_code = status_code + self.msg = msg + self.headers = headers + self.body = body + self.sslinfo = sslinfo + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + + def __eq__(self, other): + try: + self_d = [self.__dict__[k] for k in self.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] + return self_d == other_d + except: + return False + + def __repr__(self): + # return "Response(%s - %s)" % (self.status_code, self.msg) + + if self.body: + size = utils.pretty_size(len(self.body)) + else: + size = "content missing" + # TODO: Remove "(unknown content type, content missing)" edge-case + return "".format( + status_code=self.status_code, + msg=self.msg, + contenttype=self.headers.get("content-type", "unknown content type"), + size=size) + + def get_cookies(self): + """ + Get the contents of all Set-Cookie headers. + + Returns a possibly empty ODict, where keys are cookie name strings, + and values are [value, attr] lists. Value is a string, and attr is + an ODictCaseless containing cookie attributes. Within attrs, unary + attributes (e.g. HTTPOnly) are indicated by a Null value. + """ + ret = [] + for header in self.headers.get_all("set-cookie"): + v = cookies.parse_set_cookie_header(header) + if v: + name, value, attrs = v + ret.append([name, [value, attrs]]) + return ODict(ret) + + def set_cookies(self, odict): + """ + Set the Set-Cookie headers on this response, over-writing existing + headers. + + Accepts an ODict of the same format as that returned by get_cookies. + """ + values = [] + for i in odict.lst: + values.append( + cookies.format_set_cookie_header( + i[0], + i[1][0], + i[1][1] + ) + ) + self.headers.set_all("Set-Cookie", values) + + @property + def content(self): # pragma: no cover + # TODO: remove deprecated getter + return self.body + + @content.setter + def content(self, content): # pragma: no cover + # TODO: remove deprecated setter + self.body = content + + @property + def code(self): # pragma: no cover + # TODO: remove deprecated getter + return self.status_code + + @code.setter + def code(self, code): # pragma: no cover + # TODO: remove deprecated setter + self.status_code = code diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py deleted file mode 100644 index 5bb098a7..00000000 --- a/netlib/http/semantics.py +++ /dev/null @@ -1,632 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import UserDict -import copy -import urllib -import urlparse - -from .. import odict -from . import cookies, exceptions -from netlib import utils, encoding - -HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" -HDR_FORM_MULTIPART = "multipart/form-data" - -CONTENT_MISSING = 0 - - -class Headers(object, UserDict.DictMixin): - """ - Header class which allows both convenient access to individual headers as well as - direct access to the underlying raw data. Provides a full dictionary interface. - - Example: - - .. code-block:: python - - # Create header from a list of (header_name, header_value) tuples - >>> h = Headers([ - ["Host","example.com"], - ["Accept","text/html"], - ["accept","application/xml"] - ]) - - # Headers mostly behave like a normal dict. - >>> h["Host"] - "example.com" - - # HTTP Headers are case insensitive - >>> h["host"] - "example.com" - - # Multiple headers are folded into a single header as per RFC7230 - >>> h["Accept"] - "text/html, application/xml" - - # Setting a header removes all existing headers with the same name. - >>> h["Accept"] = "application/text" - >>> h["Accept"] - "application/text" - - # str(h) returns a HTTP1 header block. - >>> print(h) - Host: example.com - Accept: application/text - - # For full control, the raw header fields can be accessed - >>> h.fields - - # Headers can also be crated from keyword arguments - >>> h = Headers(host="example.com", content_type="application/xml") - - Caveats: - For use with the "Set-Cookie" header, see :py:meth:`get_all`. - """ - - def __init__(self, fields=None, **headers): - """ - Args: - fields: (optional) list of ``(name, value)`` header tuples, e.g. ``[("Host","example.com")]`` - **headers: Additional headers to set. Will overwrite existing values from `fields`. - For convenience, underscores in header names will be transformed to dashes - - this behaviour does not extend to other methods. - If ``**headers`` contains multiple keys that have equal ``.lower()`` s, - the behavior is undefined. - """ - self.fields = fields or [] - - # content_type -> content-type - headers = { - name.replace("_", "-"): value - for name, value in headers.iteritems() - } - self.update(headers) - - def __str__(self): - return "\r\n".join(": ".join(field) for field in self.fields) + "\r\n" - - def __getitem__(self, name): - values = self.get_all(name) - if not values: - raise KeyError(name) - else: - return ", ".join(values) - - def __setitem__(self, name, value): - idx = self._index(name) - - # To please the human eye, we insert at the same position the first existing header occured. - if idx is not None: - del self[name] - self.fields.insert(idx, [name, value]) - else: - self.fields.append([name, value]) - - def __delitem__(self, name): - if name not in self: - raise KeyError(name) - name = name.lower() - self.fields = [ - field for field in self.fields - if name != field[0].lower() - ] - - def _index(self, name): - name = name.lower() - for i, field in enumerate(self.fields): - if field[0].lower() == name: - return i - return None - - def keys(self): - seen = set() - names = [] - for name, _ in self.fields: - name_lower = name.lower() - if name_lower not in seen: - seen.add(name_lower) - names.append(name) - return names - - def __eq__(self, other): - if isinstance(other, Headers): - return self.fields == other.fields - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def get_all(self, name): - """ - Like :py:meth:`get`, but does not fold multiple headers into a single one. - This is useful for Set-Cookie headers, which do not support folding. - - See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 - """ - name = name.lower() - values = [value for n, value in self.fields if n.lower() == name] - return values - - def set_all(self, name, values): - """ - Explicitly set multiple headers for the given key. - See: :py:meth:`get_all` - """ - if name in self: - del self[name] - self.fields.extend( - [name, value] for value in values - ) - - def copy(self): - return Headers(copy.copy(self.fields)) - - # Implement the StateObject protocol from mitmproxy - def get_state(self, short=False): - return tuple(tuple(field) for field in self.fields) - - def load_state(self, state): - self.fields = [list(field) for field in state] - - @classmethod - def from_state(cls, state): - return cls([list(field) for field in state]) - - -class ProtocolMixin(object): - def read_request(self, *args, **kwargs): # pragma: no cover - raise NotImplementedError - - def read_response(self, *args, **kwargs): # pragma: no cover - raise NotImplementedError - - def assemble(self, message): - if isinstance(message, Request): - return self.assemble_request(message) - elif isinstance(message, Response): - return self.assemble_response(message) - else: - raise ValueError("HTTP message not supported.") - - def assemble_request(self, *args, **kwargs): # pragma: no cover - raise NotImplementedError - - def assemble_response(self, *args, **kwargs): # pragma: no cover - raise NotImplementedError - - -class Request(object): - # This list is adopted legacy code. - # We probably don't need to strip off keep-alive. - _headers_to_strip_off = [ - 'Proxy-Connection', - 'Keep-Alive', - 'Connection', - 'Transfer-Encoding', - 'Upgrade', - ] - - def __init__( - self, - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers=None, - body=None, - timestamp_start=None, - timestamp_end=None, - form_out=None - ): - if not headers: - headers = Headers() - assert isinstance(headers, Headers) - - self.form_in = form_in - self.method = method - self.scheme = scheme - self.host = host - self.port = port - self.path = path - self.httpversion = httpversion - self.headers = headers - self.body = body - self.timestamp_start = timestamp_start - self.timestamp_end = timestamp_end - self.form_out = form_out or form_in - - def __eq__(self, other): - try: - self_d = [self.__dict__[k] for k in self.__dict__ if - k not in ('timestamp_start', 'timestamp_end')] - other_d = [other.__dict__[k] for k in other.__dict__ if - k not in ('timestamp_start', 'timestamp_end')] - return self_d == other_d - except: - return False - - def __repr__(self): - # return "Request(%s - %s, %s)" % (self.method, self.host, self.path) - - return "".format( - self.legacy_first_line()[:-9] - ) - - def legacy_first_line(self, form=None): - if form is None: - form = self.form_out - if form == "relative": - return '%s %s HTTP/%s.%s' % ( - self.method, - self.path, - self.httpversion[0], - self.httpversion[1], - ) - elif form == "authority": - return '%s %s:%s HTTP/%s.%s' % ( - self.method, - self.host, - self.port, - self.httpversion[0], - self.httpversion[1], - ) - elif form == "absolute": - return '%s %s://%s:%s%s HTTP/%s.%s' % ( - self.method, - self.scheme, - self.host, - self.port, - self.path, - self.httpversion[0], - self.httpversion[1], - ) - else: - raise exceptions.HttpError(400, "Invalid request form") - - def anticache(self): - """ - Modifies this request to remove headers that might produce a cached - response. That is, we remove ETags and If-Modified-Since headers. - """ - delheaders = [ - "if-modified-since", - "if-none-match", - ] - for i in delheaders: - self.headers.pop(i, None) - - def anticomp(self): - """ - Modifies this request to remove headers that will compress the - resource's data. - """ - self.headers["accept-encoding"] = "identity" - - def constrain_encoding(self): - """ - Limits the permissible Accept-Encoding values, based on what we can - decode appropriately. - """ - accept_encoding = self.headers.get("accept-encoding") - if accept_encoding: - self.headers["accept-encoding"] = ( - ', '.join( - e - for e in encoding.ENCODINGS - if e in accept_encoding - ) - ) - - def update_host_header(self): - """ - Update the host header to reflect the current target. - """ - self.headers["Host"] = self.host - - def get_form(self): - """ - Retrieves the URL-encoded or multipart form data, returning an ODict object. - Returns an empty ODict if there is no data or the content-type - indicates non-form data. - """ - if self.body: - if HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): - return self.get_form_urlencoded() - elif HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): - return self.get_form_multipart() - return odict.ODict([]) - - def get_form_urlencoded(self): - """ - Retrieves the URL-encoded form data, returning an ODict object. - Returns an empty ODict if there is no data or the content-type - indicates non-form data. - """ - if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): - return odict.ODict(utils.urldecode(self.body)) - return odict.ODict([]) - - def get_form_multipart(self): - if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): - return odict.ODict( - utils.multipartdecode( - self.headers, - self.body)) - return odict.ODict([]) - - def set_form_urlencoded(self, odict): - """ - Sets the body to the URL-encoded form data, and adds the - appropriate content-type header. Note that this will destory the - existing body if there is one. - """ - # FIXME: If there's an existing content-type header indicating a - # url-encoded form, leave it alone. - self.headers["Content-Type"] = HDR_FORM_URLENCODED - self.body = utils.urlencode(odict.lst) - - def get_path_components(self): - """ - Returns the path components of the URL as a list of strings. - - Components are unquoted. - """ - _, _, path, _, _, _ = urlparse.urlparse(self.url) - return [urllib.unquote(i) for i in path.split("/") if i] - - def set_path_components(self, lst): - """ - Takes a list of strings, and sets the path component of the URL. - - Components are quoted. - """ - lst = [urllib.quote(i, safe="") for i in lst] - path = "/" + "/".join(lst) - scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.url) - self.url = urlparse.urlunparse( - [scheme, netloc, path, params, query, fragment] - ) - - def get_query(self): - """ - Gets the request query string. Returns an ODict object. - """ - _, _, _, _, query, _ = urlparse.urlparse(self.url) - if query: - return odict.ODict(utils.urldecode(query)) - return odict.ODict([]) - - def set_query(self, odict): - """ - Takes an ODict object, and sets the request query string. - """ - scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.url) - query = utils.urlencode(odict.lst) - self.url = urlparse.urlunparse( - [scheme, netloc, path, params, query, fragment] - ) - - def pretty_host(self, hostheader): - """ - Heuristic to get the host of the request. - - Note that pretty_host() does not always return the TCP destination - of the request, e.g. if an upstream proxy is in place - - If hostheader is set to True, the Host: header will be used as - additional (and preferred) data source. This is handy in - transparent mode, where only the IO of the destination is known, - but not the resolved name. This is disabled by default, as an - attacker may spoof the host header to confuse an analyst. - """ - host = None - if hostheader: - host = self.headers.get("Host") - if not host: - host = self.host - if host: - try: - return host.encode("idna") - except ValueError: - return host - else: - return None - - def pretty_url(self, hostheader): - if self.form_out == "authority": # upstream proxy mode - return "%s:%s" % (self.pretty_host(hostheader), self.port) - return utils.unparse_url(self.scheme, - self.pretty_host(hostheader), - self.port, - self.path).encode('ascii') - - def get_cookies(self): - """ - Returns a possibly empty netlib.odict.ODict object. - """ - ret = odict.ODict() - for i in self.headers.get_all("cookie"): - ret.extend(cookies.parse_cookie_header(i)) - return ret - - def set_cookies(self, odict): - """ - Takes an netlib.odict.ODict object. Over-writes any existing Cookie - headers. - """ - v = cookies.format_cookie_header(odict) - self.headers["Cookie"] = v - - @property - def url(self): - """ - Returns a URL string, constructed from the Request's URL components. - """ - return utils.unparse_url( - self.scheme, - self.host, - self.port, - self.path - ).encode('ascii') - - @url.setter - def url(self, url): - """ - Parses a URL specification, and updates the Request's information - accordingly. - - Returns False if the URL was invalid, True if the request succeeded. - """ - parts = utils.parse_url(url) - if not parts: - raise ValueError("Invalid URL: %s" % url) - self.scheme, self.host, self.port, self.path = parts - - @property - def content(self): # pragma: no cover - # TODO: remove deprecated getter - return self.body - - @content.setter - def content(self, content): # pragma: no cover - # TODO: remove deprecated setter - self.body = content - - -class EmptyRequest(Request): - def __init__( - self, - form_in="", - method="", - scheme="", - host="", - port="", - path="", - httpversion=(0, 0), - headers=None, - body="" - ): - super(EmptyRequest, self).__init__( - form_in=form_in, - method=method, - scheme=scheme, - host=host, - port=port, - path=path, - httpversion=httpversion, - headers=headers, - body=body, - ) - - -class Response(object): - _headers_to_strip_off = [ - 'Proxy-Connection', - 'Alternate-Protocol', - 'Alt-Svc', - ] - - def __init__( - self, - httpversion, - status_code, - msg=None, - headers=None, - body=None, - sslinfo=None, - timestamp_start=None, - timestamp_end=None, - ): - if not headers: - headers = Headers() - assert isinstance(headers, Headers) - - self.httpversion = httpversion - self.status_code = status_code - self.msg = msg - self.headers = headers - self.body = body - self.sslinfo = sslinfo - self.timestamp_start = timestamp_start - self.timestamp_end = timestamp_end - - def __eq__(self, other): - try: - self_d = [self.__dict__[k] for k in self.__dict__ if - k not in ('timestamp_start', 'timestamp_end')] - other_d = [other.__dict__[k] for k in other.__dict__ if - k not in ('timestamp_start', 'timestamp_end')] - return self_d == other_d - except: - return False - - def __repr__(self): - # return "Response(%s - %s)" % (self.status_code, self.msg) - - if self.body: - size = utils.pretty_size(len(self.body)) - else: - size = "content missing" - # TODO: Remove "(unknown content type, content missing)" edge-case - return "".format( - status_code=self.status_code, - msg=self.msg, - contenttype=self.headers.get("content-type", "unknown content type"), - size=size) - - def get_cookies(self): - """ - Get the contents of all Set-Cookie headers. - - Returns a possibly empty ODict, where keys are cookie name strings, - and values are [value, attr] lists. Value is a string, and attr is - an ODictCaseless containing cookie attributes. Within attrs, unary - attributes (e.g. HTTPOnly) are indicated by a Null value. - """ - ret = [] - for header in self.headers.get_all("set-cookie"): - v = cookies.parse_set_cookie_header(header) - if v: - name, value, attrs = v - ret.append([name, [value, attrs]]) - return odict.ODict(ret) - - def set_cookies(self, odict): - """ - Set the Set-Cookie headers on this response, over-writing existing - headers. - - Accepts an ODict of the same format as that returned by get_cookies. - """ - values = [] - for i in odict.lst: - values.append( - cookies.format_set_cookie_header( - i[0], - i[1][0], - i[1][1] - ) - ) - self.headers.set_all("Set-Cookie", values) - - @property - def content(self): # pragma: no cover - # TODO: remove deprecated getter - return self.body - - @content.setter - def content(self, content): # pragma: no cover - # TODO: remove deprecated setter - self.body = content - - @property - def code(self): # pragma: no cover - # TODO: remove deprecated getter - return self.status_code - - @code.setter - def code(self, code): # pragma: no cover - # TODO: remove deprecated setter - self.status_code = code diff --git a/netlib/tcp.py b/netlib/tcp.py index 4a7f6153..1eb417b4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -834,14 +834,14 @@ class TCPServer(object): # If a thread has persisted after interpreter exit, the module might be # none. if traceback: - exc = traceback.format_exc() - print('-' * 40, file=fp) + exc = six.text_type(traceback.format_exc()) + print(u'-' * 40, file=fp) print( - "Error in processing of request from %s:%s" % ( + u"Error in processing of request from %s:%s" % ( client_address.host, client_address.port ), file=fp) print(exc, file=fp) - print('-' * 40, file=fp) + print(u'-' * 40, file=fp) def handle_client_connection(self, conn, client_address): # pragma: no cover """ diff --git a/netlib/tutils.py b/netlib/tutils.py index 951ef3d9..65c4a313 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -1,9 +1,11 @@ -import cStringIO +from io import BytesIO import tempfile import os import time import shutil from contextlib import contextmanager +import six +import sys from netlib import tcp, utils, http @@ -12,7 +14,7 @@ def treader(bytes): """ Construct a tcp.Read object from bytes. """ - fp = cStringIO.StringIO(bytes) + fp = BytesIO(bytes) return tcp.Reader(fp) @@ -28,7 +30,24 @@ def tmpdir(*args, **kwargs): shutil.rmtree(temp_workdir) -def raises(exc, obj, *args, **kwargs): +def _check_exception(expected, actual, exc_tb): + if isinstance(expected, six.string_types): + if expected.lower() not in str(actual).lower(): + six.reraise(AssertionError, AssertionError( + "Expected %s, but caught %s" % ( + repr(str(expected)), actual + ) + ), exc_tb) + else: + if not isinstance(actual, expected): + six.reraise(AssertionError, AssertionError( + "Expected %s, but caught %s %s" % ( + expected.__name__, actual.__class__.__name__, str(actual) + ) + ), exc_tb) + + +def raises(expected_exception, obj=None, *args, **kwargs): """ Assert that a callable raises a specified exception. @@ -43,28 +62,31 @@ def raises(exc, obj, *args, **kwargs): :kwargs Arguments to be passed to the callable. """ - try: - ret = obj(*args, **kwargs) - except Exception as v: - if isinstance(exc, basestring): - if exc.lower() in str(v).lower(): - return - else: - raise AssertionError( - "Expected %s, but caught %s" % ( - repr(str(exc)), v - ) - ) + if obj is None: + return RaisesContext(expected_exception) + else: + try: + ret = obj(*args, **kwargs) + except Exception as actual: + _check_exception(expected_exception, actual, sys.exc_info()[2]) else: - if isinstance(v, exc): - return - else: - raise AssertionError( - "Expected %s, but caught %s %s" % ( - exc.__name__, v.__class__.__name__, str(v) - ) - ) - raise AssertionError("No exception raised. Return value: {}".format(ret)) + raise AssertionError("No exception raised. Return value: {}".format(ret)) + + +class RaisesContext(object): + def __init__(self, expected_exception): + self.expected_exception = expected_exception + + def __enter__(self): + return + + def __exit__(self, exc_type, exc_val, exc_tb): + if not exc_type: + raise AssertionError("No exception raised.") + else: + _check_exception(self.expected_exception, exc_val, exc_tb) + return True + test_data = utils.Data(__name__) diff --git a/netlib/utils.py b/netlib/utils.py index d6774419..fb579cac 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,17 +1,17 @@ -from __future__ import (absolute_import, print_function, division) +from __future__ import absolute_import, print_function, division import os.path -import cgi -import urllib -import urlparse -import string import re -import six +import string import unicodedata +import six + +from six.moves import urllib + -def isascii(s): +def isascii(bytes): try: - s.decode("ascii") + bytes.decode("ascii") except ValueError: return False return True @@ -44,8 +44,8 @@ def clean_bin(s, keep_spacing=True): else: keep = b"" return b"".join( - ch if (31 < ord(ch) < 127 or ch in keep) else b"." - for ch in s + six.int2byte(ch) if (31 < ch < 127 or ch in keep) else b"." + for ch in six.iterbytes(s) ) @@ -149,10 +149,7 @@ class Data(object): return fullpath -def is_valid_port(port): - if not 0 <= port <= 65535: - return False - return True +_label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(? 255: + return False + if host[-1] == ".": + host = host[:-1] + return all(_label_valid.match(x) for x in host.split(b".")) + + +def is_valid_port(port): + return 0 <= port <= 65535 + + +# PY2 workaround +def decode_parse_result(result, enc): + if hasattr(result, "decode"): + return result.decode(enc) + else: + return urllib.parse.ParseResult(*[x.decode(enc) for x in result]) + + +# PY2 workaround +def encode_parse_result(result, enc): + if hasattr(result, "encode"): + return result.encode(enc) + else: + return urllib.parse.ParseResult(*[x.encode(enc) for x in result]) def parse_url(url): """ - Returns a (scheme, host, port, path) tuple, or None on error. + URL-parsing function that checks that + - port is an integer 0-65535 + - host is a valid IDNA-encoded hostname with no null-bytes + - path is valid ASCII - Checks that: - port is an integer 0-65535 - host is a valid IDNA-encoded hostname with no null-bytes - path is valid ASCII + Args: + A URL (as bytes or as unicode) + + Returns: + A (scheme, host, port, path) tuple + + Raises: + ValueError, if the URL is not properly formatted. """ - try: - scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) - except ValueError: - return None - if not scheme: - return None - if '@' in netloc: - # FIXME: Consider what to do with the discarded credentials here Most - # probably we should extend the signature to return these as a separate - # value. - _, netloc = string.rsplit(netloc, '@', maxsplit=1) - if ':' in netloc: - host, port = string.rsplit(netloc, ':', maxsplit=1) - try: - port = int(port) - except ValueError: - return None + parsed = urllib.parse.urlparse(url) + + if not parsed.hostname: + raise ValueError("No hostname given") + + if isinstance(url, six.binary_type): + host = parsed.hostname + + # this should not raise a ValueError + decode_parse_result(parsed, "ascii") else: - host = netloc - if scheme.endswith("https"): - port = 443 - else: - port = 80 - path = urlparse.urlunparse(('', '', path, params, query, fragment)) - if not path.startswith("/"): - path = "/" + path + host = parsed.hostname.encode("idna") + parsed = encode_parse_result(parsed, "ascii") + + port = parsed.port + if not port: + port = 443 if parsed.scheme == b"https" else 80 + + full_path = urllib.parse.urlunparse( + (b"", b"", parsed.path, parsed.params, parsed.query, parsed.fragment) + ) + if not full_path.startswith(b"/"): + full_path = b"/" + full_path + if not is_valid_host(host): - return None - if not isascii(path): - return None + raise ValueError("Invalid Host") if not is_valid_port(port): - return None - return scheme, host, port, path + raise ValueError("Invalid Port") + + return parsed.scheme, host, port, full_path def get_header_tokens(headers, key): @@ -217,7 +240,7 @@ def get_header_tokens(headers, key): """ if key not in headers: return [] - tokens = headers[key].split(",") + tokens = headers[key].split(b",") return [token.strip() for token in tokens] @@ -228,7 +251,7 @@ def hostport(scheme, host, port): if (port, scheme) in [(80, "http"), (443, "https")]: return host else: - return "%s:%s" % (host, port) + return b"%s:%s" % (host, port) def unparse_url(scheme, host, port, path=""): @@ -243,14 +266,14 @@ def urlencode(s): Takes a list of (key, value) tuples and returns a urlencoded string. """ s = [tuple(i) for i in s] - return urllib.urlencode(s, False) + return urllib.parse.urlencode(s, False) def urldecode(s): """ Takes a urlencoded string and returns a list of (key, value) tuples. """ - return cgi.parse_qsl(s, keep_blank_values=True) + return urllib.parse.parse_qsl(s, keep_blank_values=True) def parse_content_type(c): @@ -267,14 +290,14 @@ def parse_content_type(c): ("text", "html", {"charset": "UTF-8"}) """ - parts = c.split(";", 1) - ts = parts[0].split("/", 1) + parts = c.split(b";", 1) + ts = parts[0].split(b"/", 1) if len(ts) != 2: return None d = {} if len(parts) == 2: - for i in parts[1].split(";"): - clause = i.split("=", 1) + for i in parts[1].split(b";"): + clause = i.split(b"=", 1) if len(clause) == 2: d[clause[0].strip()] = clause[1].strip() return ts[0].lower(), ts[1].lower(), d @@ -289,7 +312,7 @@ def multipartdecode(headers, content): v = parse_content_type(v) if not v: return [] - boundary = v[2].get("boundary") + boundary = v[2].get(b"boundary") if not boundary: return [] @@ -306,3 +329,20 @@ def multipartdecode(headers, content): r.append((key, value)) return r return [] + + +def always_bytes(unicode_or_bytes, encoding): + if isinstance(unicode_or_bytes, six.text_type): + return unicode_or_bytes.encode(encoding) + return unicode_or_bytes + + +def always_byte_args(encoding): + """Decorator that transparently encodes all arguments passed as unicode""" + def decorator(fun): + def _fun(*args, **kwargs): + args = [always_bytes(arg, encoding) for arg in args] + kwargs = {k: always_bytes(v, encoding) for k, v in six.iteritems(kwargs)} + return fun(*args, **kwargs) + return _fun + return decorator diff --git a/netlib/version_check.py b/netlib/version_check.py index 1d7e025c..9cf27eea 100644 --- a/netlib/version_check.py +++ b/netlib/version_check.py @@ -7,6 +7,7 @@ from __future__ import division, absolute_import, print_function import sys import inspect import os.path +import six import OpenSSL from . import version @@ -19,8 +20,8 @@ def check_mitmproxy_version(mitmproxy_version, fp=sys.stderr): # consider major and minor version. if version.IVERSION[:2] != mitmproxy_version[:2]: print( - "You are using mitmproxy %s with netlib %s. " - "Most likely, that won't work - please upgrade!" % ( + u"You are using mitmproxy %s with netlib %s. " + u"Most likely, that won't work - please upgrade!" % ( mitmproxy_version, version.VERSION ), file=fp @@ -29,13 +30,13 @@ def check_mitmproxy_version(mitmproxy_version, fp=sys.stderr): def check_pyopenssl_version(min_version=PYOPENSSL_MIN_VERSION, fp=sys.stderr): - min_version_str = ".".join(str(x) for x in min_version) + min_version_str = u".".join(six.text_type(x) for x in min_version) try: v = tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) except ValueError: print( - "Cannot parse pyOpenSSL version: {}" - "mitmproxy requires pyOpenSSL {} or greater.".format( + u"Cannot parse pyOpenSSL version: {}" + u"mitmproxy requires pyOpenSSL {} or greater.".format( OpenSSL.__version__, min_version_str ), file=fp @@ -43,15 +44,15 @@ def check_pyopenssl_version(min_version=PYOPENSSL_MIN_VERSION, fp=sys.stderr): return if v < min_version: print( - "You are using an outdated version of pyOpenSSL: " - "mitmproxy requires pyOpenSSL {} or greater.".format(min_version_str), + u"You are using an outdated version of pyOpenSSL: " + u"mitmproxy requires pyOpenSSL {} or greater.".format(min_version_str), file=fp ) # Some users apparently have multiple versions of pyOpenSSL installed. # Report which one we got. pyopenssl_path = os.path.dirname(inspect.getfile(OpenSSL)) print( - "Your pyOpenSSL {} installation is located at {}".format( + u"Your pyOpenSSL {} installation is located at {}".format( OpenSSL.__version__, pyopenssl_path ), file=fp diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py index 5acf7696..1c143919 100644 --- a/netlib/websockets/__init__.py +++ b/netlib/websockets/__init__.py @@ -1,2 +1,2 @@ -from frame import * -from protocol import * +from .frame import * +from .protocol import * diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index f7c615bd..bdcba5cb 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -1,9 +1,12 @@ -import cStringIO +from io import BytesIO import textwrap +from http.http1.protocol import _parse_authority_form +from netlib.exceptions import HttpSyntaxException, HttpReadDisconnect, HttpException -from netlib import http, odict, tcp, tutils +from netlib import http, tcp, tutils from netlib.http import semantics, Headers -from netlib.http.http1 import HTTP1Protocol +from netlib.http.http1 import HTTP1Protocol, read_message_body, read_request, \ + read_message_body_chunked, expected_http_body_size from ... import tservers @@ -14,8 +17,8 @@ class NoContentLengthHTTPHandler(tcp.BaseHandler): def mock_protocol(data=''): - rfile = cStringIO.StringIO(data) - wfile = cStringIO.StringIO() + rfile = BytesIO(data) + wfile = BytesIO() return HTTP1Protocol(rfile=rfile, wfile=wfile) @@ -37,53 +40,35 @@ def test_stripped_chunked_encoding_no_content(): assert "Content-Length" in mock_protocol()._assemble_response_headers(r) -def test_has_chunked_encoding(): - headers = http.Headers() - assert not HTTP1Protocol.has_chunked_encoding(headers) - headers["transfer-encoding"] = "chunked" - assert HTTP1Protocol.has_chunked_encoding(headers) - - def test_read_chunked(): - headers = http.Headers() - headers["transfer-encoding"] = "chunked" + req = tutils.treq(None) + req.headers["Transfer-Encoding"] = "chunked" - data = "1\r\na\r\n0\r\n" - tutils.raises( - "malformed chunked body", - mock_protocol(data).read_http_body, - headers, None, "GET", None, True - ) + data = b"1\r\na\r\n0\r\n" + with tutils.raises(HttpSyntaxException): + read_message_body(BytesIO(data), req) - data = "1\r\na\r\n0\r\n\r\n" - assert mock_protocol(data).read_http_body(headers, None, "GET", None, True) == "a" + data = b"1\r\na\r\n0\r\n\r\n" + assert read_message_body(BytesIO(data), req) == b"a" - data = "\r\n\r\n1\r\na\r\n0\r\n\r\n" - assert mock_protocol(data).read_http_body(headers, None, "GET", None, True) == "a" + data = b"\r\n\r\n1\r\na\r\n1\r\nb\r\n0\r\n\r\n" + assert read_message_body(BytesIO(data), req) == b"ab" - data = "\r\n" - tutils.raises( - "closed prematurely", - mock_protocol(data).read_http_body, - headers, None, "GET", None, True - ) + data = b"\r\n" + with tutils.raises("closed prematurely"): + read_message_body(BytesIO(data), req) - data = "1\r\nfoo" - tutils.raises( - "malformed chunked body", - mock_protocol(data).read_http_body, - headers, None, "GET", None, True - ) + data = b"1\r\nfoo" + with tutils.raises("malformed chunked body"): + read_message_body(BytesIO(data), req) - data = "foo\r\nfoo" - tutils.raises( - http.HttpError, - mock_protocol(data).read_http_body, - headers, None, "GET", None, True - ) + data = b"foo\r\nfoo" + with tutils.raises(HttpSyntaxException): + read_message_body(BytesIO(data), req) - data = "5\r\naaaaa\r\n0\r\n\r\n" - tutils.raises("too large", mock_protocol(data).read_http_body, headers, 2, "GET", None, True) + data = b"5\r\naaaaa\r\n0\r\n\r\n" + with tutils.raises("too large"): + read_message_body(BytesIO(data), req, limit=2) def test_connection_close(): @@ -171,52 +156,37 @@ def test_read_http_body(): def test_expected_http_body_size(): # gibber in the content-length field headers = Headers(content_length="foo") - assert HTTP1Protocol.expected_http_body_size(headers, False, "GET", 200) is None + with tutils.raises(HttpSyntaxException): + expected_http_body_size(headers, False, "GET", 200) is None # negative number in the content-length field headers = Headers(content_length="-7") - assert HTTP1Protocol.expected_http_body_size(headers, False, "GET", 200) is None + with tutils.raises(HttpSyntaxException): + expected_http_body_size(headers, False, "GET", 200) is None # explicit length headers = Headers(content_length="5") - assert HTTP1Protocol.expected_http_body_size(headers, False, "GET", 200) == 5 + assert expected_http_body_size(headers, False, "GET", 200) == 5 # no length headers = Headers() - assert HTTP1Protocol.expected_http_body_size(headers, False, "GET", 200) == -1 + assert expected_http_body_size(headers, False, "GET", 200) == -1 # no length request headers = Headers() - assert HTTP1Protocol.expected_http_body_size(headers, True, "GET", None) == 0 - - -def test_get_request_line(): - data = "\nfoo" - p = mock_protocol(data) - assert p._get_request_line() == "foo" - assert not p._get_request_line() - - -def test_parse_http_protocol(): - assert HTTP1Protocol._parse_http_protocol("HTTP/1.1") == (1, 1) - assert HTTP1Protocol._parse_http_protocol("HTTP/0.0") == (0, 0) - assert not HTTP1Protocol._parse_http_protocol("HTTP/a.1") - assert not HTTP1Protocol._parse_http_protocol("HTTP/1.a") - assert not HTTP1Protocol._parse_http_protocol("foo/0.0") - assert not HTTP1Protocol._parse_http_protocol("HTTP/x") + assert expected_http_body_size(headers, True, "GET", None) == 0 + # expect header + headers = Headers(content_length="5", expect="100-continue") + assert expected_http_body_size(headers, True, "GET", None) == 0 def test_parse_init_connect(): - assert HTTP1Protocol._parse_init_connect("CONNECT host.com:443 HTTP/1.0") - assert not HTTP1Protocol._parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") - assert not HTTP1Protocol._parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") - assert not HTTP1Protocol._parse_init_connect("CONNECT host.com:444444 HTTP/1.0") - assert not HTTP1Protocol._parse_init_connect("bogus") - assert not HTTP1Protocol._parse_init_connect("GET host.com:443 HTTP/1.0") - assert not HTTP1Protocol._parse_init_connect("CONNECT host.com443 HTTP/1.0") - assert not HTTP1Protocol._parse_init_connect("CONNECT host.com:443 foo/1.0") - assert not HTTP1Protocol._parse_init_connect("CONNECT host.com:foo HTTP/1.0") + assert _parse_authority_form(b"CONNECT host.com:443 HTTP/1.0") + tutils.raises(ValueError,_parse_authority_form, b"\0host.com:443") + tutils.raises(ValueError,_parse_authority_form, b"host.com:444444") + tutils.raises(ValueError,_parse_authority_form, b"CONNECT host.com443 HTTP/1.0") + tutils.raises(ValueError,_parse_authority_form, b"CONNECT host.com:foo HTTP/1.0") def test_parse_init_proxy(): - u = "GET http://foo.com:8888/test HTTP/1.1" - m, s, h, po, pa, httpversion = HTTP1Protocol._parse_init_proxy(u) + u = b"GET http://foo.com:8888/test HTTP/1.1" + m, s, h, po, pa, httpversion = HTTP1Protocol._parse_absolute_form(u) assert m == "GET" assert s == "http" assert h == "foo.com" @@ -225,11 +195,14 @@ def test_parse_init_proxy(): assert httpversion == (1, 1) u = "G\xfeET http://foo.com:8888/test HTTP/1.1" - assert not HTTP1Protocol._parse_init_proxy(u) + assert not HTTP1Protocol._parse_absolute_form(u) - assert not HTTP1Protocol._parse_init_proxy("invalid") - assert not HTTP1Protocol._parse_init_proxy("GET invalid HTTP/1.1") - assert not HTTP1Protocol._parse_init_proxy("GET http://foo.com:8888/test foo/1.1") + with tutils.raises(ValueError): + assert not HTTP1Protocol._parse_absolute_form("invalid") + with tutils.raises(ValueError): + assert not HTTP1Protocol._parse_absolute_form("GET invalid HTTP/1.1") + with tutils.raises(ValueError): + assert not HTTP1Protocol._parse_absolute_form("GET http://foo.com:8888/test foo/1.1") def test_parse_init_http(): @@ -317,15 +290,11 @@ class TestReadRequest(object): "get / HTTP/1.1\r\nfoo" ) tutils.raises( - tcp.NetLibDisconnect, + HttpReadDisconnect, self.tst, "\r\n" ) - def test_empty(self): - v = self.tst("", allow_empty=True) - assert isinstance(v, semantics.EmptyRequest) - def test_asterisk_form_in(self): v = self.tst("OPTIONS * HTTP/1.1") assert v.form_in == "relative" @@ -356,18 +325,18 @@ class TestReadRequest(object): assert v.host == "foo.com" def test_expect(self): - data = "".join( - "GET / HTTP/1.1\r\n" - "Content-Length: 3\r\n" - "Expect: 100-continue\r\n\r\n" - "foobar" + data = ( + b"GET / HTTP/1.1\r\n" + b"Content-Length: 3\r\n" + b"Expect: 100-continue\r\n" + b"\r\n" + b"foobar" ) - p = mock_protocol(data) - v = p.read_request() - assert p.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" - assert v.body == "foo" - assert p.tcp_handler.rfile.read(3) == "bar" + rfile = BytesIO(data) + r = read_request(rfile) + assert r.body == b"" + assert rfile.read(-1) == b"foobar" class TestReadResponse(object): diff --git a/test/http/http2/test_frames.py b/test/http/http2/test_frames.py index 5d5cb0ba..efdb55e2 100644 --- a/test/http/http2/test_frames.py +++ b/test/http/http2/test_frames.py @@ -1,4 +1,4 @@ -import cStringIO +from io import BytesIO from nose.tools import assert_equal from netlib import tcp, tutils @@ -7,7 +7,7 @@ from netlib.http.http2.frame import * def hex_to_file(data): data = data.decode('hex') - return tcp.Reader(cStringIO.StringIO(data)) + return tcp.Reader(BytesIO(data)) def test_invalid_flags(): diff --git a/test/http/test_authentication.py b/test/http/test_authentication.py index 17c91fe5..ee192dd7 100644 --- a/test/http/test_authentication.py +++ b/test/http/test_authentication.py @@ -5,7 +5,7 @@ from netlib.http import authentication, Headers def test_parse_http_basic_auth(): - vals = ("basic", "foo", "bar") + vals = (b"basic", b"foo", b"bar") assert authentication.parse_http_basic_auth( authentication.assemble_http_basic_auth(*vals) ) == vals diff --git a/test/http/test_semantics.py b/test/http/test_semantics.py index 6dcbbe07..44d3c85e 100644 --- a/test/http/test_semantics.py +++ b/test/http/test_semantics.py @@ -475,7 +475,7 @@ class TestHeaders(object): def test_str(self): headers = semantics.Headers(Host="example.com") - assert str(headers) == "Host: example.com\r\n" + assert bytes(headers) == "Host: example.com\r\n" headers = semantics.Headers([ ["Host", "example.com"], diff --git a/test/test_encoding.py b/test/test_encoding.py index 612aea89..9da3a38d 100644 --- a/test/test_encoding.py +++ b/test/test_encoding.py @@ -9,25 +9,29 @@ def test_identity(): def test_gzip(): - assert "string" == encoding.decode( + assert b"string" == encoding.decode( "gzip", encoding.encode( "gzip", - "string")) - assert None == encoding.decode("gzip", "bogus") + b"string" + ) + ) + assert encoding.decode("gzip", b"bogus") is None def test_deflate(): - assert "string" == encoding.decode( + assert b"string" == encoding.decode( "deflate", encoding.encode( "deflate", - "string")) - assert "string" == encoding.decode( + b"string" + ) + ) + assert b"string" == encoding.decode( "deflate", encoding.encode( "deflate", - "string")[ - 2:- - 4]) - assert None == encoding.decode("deflate", "bogus") + b"string" + )[2:-4] + ) + assert encoding.decode("deflate", b"bogus") is None diff --git a/test/test_utils.py b/test/test_utils.py index 9dba5d35..8b2ddae4 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -36,46 +36,51 @@ def test_pretty_size(): def test_parse_url(): - assert not utils.parse_url("") + with tutils.raises(ValueError): + utils.parse_url("") - u = "http://foo.com:8888/test" - s, h, po, pa = utils.parse_url(u) - assert s == "http" - assert h == "foo.com" + s, h, po, pa = utils.parse_url(b"http://foo.com:8888/test") + assert s == b"http" + assert h == b"foo.com" assert po == 8888 - assert pa == "/test" + assert pa == b"/test" s, h, po, pa = utils.parse_url("http://foo/bar") - assert s == "http" - assert h == "foo" + assert s == b"http" + assert h == b"foo" assert po == 80 - assert pa == "/bar" + assert pa == b"/bar" - s, h, po, pa = utils.parse_url("http://user:pass@foo/bar") - assert s == "http" - assert h == "foo" + s, h, po, pa = utils.parse_url(b"http://user:pass@foo/bar") + assert s == b"http" + assert h == b"foo" assert po == 80 - assert pa == "/bar" + assert pa == b"/bar" - s, h, po, pa = utils.parse_url("http://foo") - assert pa == "/" + s, h, po, pa = utils.parse_url(b"http://foo") + assert pa == b"/" - s, h, po, pa = utils.parse_url("https://foo") + s, h, po, pa = utils.parse_url(b"https://foo") assert po == 443 - assert not utils.parse_url("https://foo:bar") - assert not utils.parse_url("https://foo:") + with tutils.raises(ValueError): + utils.parse_url(b"https://foo:bar") # Invalid IDNA - assert not utils.parse_url("http://\xfafoo") + with tutils.raises(ValueError): + utils.parse_url("http://\xfafoo") # Invalid PATH - assert not utils.parse_url("http:/\xc6/localhost:56121") + with tutils.raises(ValueError): + utils.parse_url("http:/\xc6/localhost:56121") # Null byte in host - assert not utils.parse_url("http://foo\0") + with tutils.raises(ValueError): + utils.parse_url("http://foo\0") # Port out of range - assert not utils.parse_url("http://foo:999999") + _, _, port, _ = utils.parse_url("http://foo:999999") + assert port == 80 # Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt - assert not utils.parse_url('http://lo[calhost') + with tutils.raises(ValueError): + utils.parse_url('http://lo[calhost') def test_unparse_url(): @@ -106,23 +111,25 @@ def test_get_header_tokens(): def test_multipartdecode(): - boundary = 'somefancyboundary' + boundary = b'somefancyboundary' headers = Headers( - content_type='multipart/form-data; boundary=%s' % boundary + content_type=b'multipart/form-data; boundary=' + boundary + ) + content = ( + "--{0}\n" + "Content-Disposition: form-data; name=\"field1\"\n\n" + "value1\n" + "--{0}\n" + "Content-Disposition: form-data; name=\"field2\"\n\n" + "value2\n" + "--{0}--".format(boundary).encode("ascii") ) - content = "--{0}\n" \ - "Content-Disposition: form-data; name=\"field1\"\n\n" \ - "value1\n" \ - "--{0}\n" \ - "Content-Disposition: form-data; name=\"field2\"\n\n" \ - "value2\n" \ - "--{0}--".format(boundary) form = utils.multipartdecode(headers, content) assert len(form) == 2 - assert form[0] == ('field1', 'value1') - assert form[1] == ('field2', 'value2') + assert form[0] == (b"field1", b"value1") + assert form[1] == (b"field2", b"value2") def test_parse_content_type(): diff --git a/test/test_version_check.py b/test/test_version_check.py index 9a127814..ec2396fe 100644 --- a/test/test_version_check.py +++ b/test/test_version_check.py @@ -1,11 +1,11 @@ -import cStringIO +from io import StringIO import mock from netlib import version_check, version @mock.patch("sys.exit") def test_check_mitmproxy_version(sexit): - fp = cStringIO.StringIO() + fp = StringIO() version_check.check_mitmproxy_version(version.IVERSION, fp=fp) assert not fp.getvalue() assert not sexit.called @@ -18,7 +18,7 @@ def test_check_mitmproxy_version(sexit): @mock.patch("sys.exit") def test_check_pyopenssl_version(sexit): - fp = cStringIO.StringIO() + fp = StringIO() version_check.check_pyopenssl_version(fp=fp) assert not fp.getvalue() assert not sexit.called @@ -32,7 +32,7 @@ def test_check_pyopenssl_version(sexit): @mock.patch("OpenSSL.__version__") def test_unparseable_pyopenssl_version(version, sexit): version.split.return_value = ["foo", "bar"] - fp = cStringIO.StringIO() + fp = StringIO() version_check.check_pyopenssl_version(fp=fp) assert "Cannot parse" in fp.getvalue() assert not sexit.called diff --git a/test/tservers.py b/test/tservers.py index 682a9144..1f4ce725 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -1,7 +1,7 @@ from __future__ import (absolute_import, print_function, division) import threading -import Queue -import cStringIO +from six.moves import queue +from io import StringIO import OpenSSL from netlib import tcp from netlib import tutils @@ -27,7 +27,7 @@ class ServerTestBase(object): @classmethod def setupAll(cls): - cls.q = Queue.Queue() + cls.q = queue.Queue() s = cls.makeserver() cls.port = s.address.port cls.server = ServerThread(s) @@ -102,6 +102,6 @@ class TServer(tcp.TCPServer): h.finish() def handle_error(self, connection, client_address, fp=None): - s = cStringIO.StringIO() + s = StringIO() tcp.TCPServer.handle_error(self, connection, client_address, s) self.q.put(s.getvalue()) -- cgit v1.2.3 From a077d8877d210562f703c23e9625e8467c81222d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 16 Sep 2015 00:04:23 +0200 Subject: finish netlib.http.http1 refactor --- netlib/http/__init__.py | 6 +- netlib/http/http1/__init__.py | 4 +- netlib/http/http1/assemble.py | 8 +- netlib/http/http1/read.py | 152 +++++---- netlib/http/http2/connections.py | 4 +- netlib/http/http2/frame.py | 654 +++++++++++++++++++++++++++++++++++++ netlib/http/http2/frames.py | 633 ----------------------------------- netlib/http/models.py | 2 - netlib/tutils.py | 74 ++--- netlib/utils.py | 6 +- test/http/http1/test_assemble.py | 91 ++++++ test/http/http1/test_protocol.py | 466 -------------------------- test/http/http1/test_read.py | 313 ++++++++++++++++++ test/http/http2/test_frames.py | 2 +- test/http/http2/test_protocol.py | 16 +- test/http/test_exceptions.py | 6 - test/http/test_models.py | 540 ++++++++++++++++++++++++++++++ test/http/test_semantics.py | 573 -------------------------------- test/websockets/test_websockets.py | 16 +- 19 files changed, 1739 insertions(+), 1827 deletions(-) create mode 100644 netlib/http/http2/frame.py delete mode 100644 netlib/http/http2/frames.py create mode 100644 test/http/http1/test_assemble.py create mode 100644 test/http/http1/test_read.py delete mode 100644 test/http/test_exceptions.py create mode 100644 test/http/test_models.py delete mode 100644 test/http/test_semantics.py diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 0b1a0bc5..9303de09 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,7 +1,9 @@ -from .models import Request, Response, Headers, CONTENT_MISSING +from .models import Request, Response, Headers +from .models import HDR_FORM_MULTIPART, HDR_FORM_URLENCODED, CONTENT_MISSING from . import http1, http2 __all__ = [ - "Request", "Response", "Headers", "CONTENT_MISSING" + "Request", "Response", "Headers", + "HDR_FORM_MULTIPART", "HDR_FORM_URLENCODED", "CONTENT_MISSING", "http1", "http2" ] diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py index 4d223f97..a72c2e05 100644 --- a/netlib/http/http1/__init__.py +++ b/netlib/http/http1/__init__.py @@ -1,7 +1,7 @@ from .read import ( read_request, read_request_head, read_response, read_response_head, - read_message_body, read_message_body_chunked, + read_body, connection_close, expected_http_body_size, ) @@ -14,7 +14,7 @@ from .assemble import ( __all__ = [ "read_request", "read_request_head", "read_response", "read_response_head", - "read_message_body", "read_message_body_chunked", + "read_body", "connection_close", "expected_http_body_size", "assemble_request", "assemble_request_head", diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index a3269eed..47c7e95a 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -31,8 +31,6 @@ def assemble_response_head(response): return b"%s\r\n%s\r\n" % (first_line, headers) - - def _assemble_request_line(request, form=None): if form is None: form = request.form_out @@ -50,7 +48,7 @@ def _assemble_request_line(request, form=None): request.httpversion ) elif form == "absolute": - return b"%s %s://%s:%s%s %s" % ( + return b"%s %s://%s:%d%s %s" % ( request.method, request.scheme, request.host, @@ -78,11 +76,11 @@ def _assemble_request_headers(request): if request.body or request.body == b"": headers[b"Content-Length"] = str(len(request.body)).encode("ascii") - return str(headers) + return bytes(headers) def _assemble_response_line(response): - return b"%s %s %s" % ( + return b"%s %d %s" % ( response.httpversion, response.status_code, response.msg, diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 573bc739..4c423c4c 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -7,12 +7,13 @@ from ... import utils from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException from .. import Request, Response, Headers -ALPN_PROTO_HTTP1 = 'http/1.1' +ALPN_PROTO_HTTP1 = b'http/1.1' def read_request(rfile, body_size_limit=None): request = read_request_head(rfile) - request.body = read_message_body(rfile, request, limit=body_size_limit) + expected_body_size = expected_http_body_size(request) + request.body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) request.timestamp_end = time.time() return request @@ -23,15 +24,14 @@ def read_request_head(rfile): Args: rfile: The input stream - body_size_limit (bool): Maximum body size Returns: - The HTTP request object + The HTTP request object (without body) Raises: - HttpReadDisconnect: If no bytes can be read from rfile. - HttpSyntaxException: If the input is invalid. - HttpException: A different error occured. + HttpReadDisconnect: No bytes can be read from rfile. + HttpSyntaxException: The input is malformed HTTP. + HttpException: Any other error occured. """ timestamp_start = time.time() if hasattr(rfile, "reset_timestamps"): @@ -51,12 +51,28 @@ def read_request_head(rfile): def read_response(rfile, request, body_size_limit=None): response = read_response_head(rfile) - response.body = read_message_body(rfile, request, response, body_size_limit) + expected_body_size = expected_http_body_size(request, response) + response.body = b"".join(read_body(rfile, expected_body_size, body_size_limit)) response.timestamp_end = time.time() return response def read_response_head(rfile): + """ + Parse an HTTP response head (response line + headers) from an input stream + + Args: + rfile: The input stream + + Returns: + The HTTP request object (without body) + + Raises: + HttpReadDisconnect: No bytes can be read from rfile. + HttpSyntaxException: The input is malformed HTTP. + HttpException: Any other error occured. + """ + timestamp_start = time.time() if hasattr(rfile, "reset_timestamps"): rfile.reset_timestamps() @@ -68,50 +84,33 @@ def read_response_head(rfile): # more accurate timestamp_start timestamp_start = rfile.first_byte_timestamp - return Response( - http_version, - status_code, - message, - headers, - None, - timestamp_start - ) - - -def read_message_body(*args, **kwargs): - chunks = read_message_body_chunked(*args, **kwargs) - return b"".join(chunks) + return Response(http_version, status_code, message, headers, None, timestamp_start) -def read_message_body_chunked(rfile, request, response=None, limit=None, max_chunk_size=None): +def read_body(rfile, expected_size, limit=None, max_chunk_size=4096): """ - Read an HTTP message body: + Read an HTTP message body Args: - If a request body should be read, only request should be passed. - If a response body should be read, both request and response should be passed. + rfile: The input stream + expected_size: The expected body size (see :py:meth:`expected_body_size`) + limit: Maximum body size + max_chunk_size: Maximium chunk size that gets yielded + + Returns: + A generator that yields byte chunks of the content. Raises: - HttpException - """ - if not response: - headers = request.headers - response_code = None - is_request = True - else: - headers = response.headers - response_code = response.status_code - is_request = False + HttpException, if an error occurs + Caveats: + max_chunk_size is not considered if the transfer encoding is chunked. + """ if not limit or limit < 0: limit = sys.maxsize if not max_chunk_size: max_chunk_size = limit - expected_size = expected_http_body_size( - headers, is_request, request.method, response_code - ) - if expected_size is None: for x in _read_chunked(rfile, limit): yield x @@ -125,6 +124,8 @@ def read_message_body_chunked(rfile, request, response=None, limit=None, max_chu while bytes_left: chunk_size = min(bytes_left, max_chunk_size) content = rfile.read(chunk_size) + if len(content) < chunk_size: + raise HttpException("Unexpected EOF") yield content bytes_left -= chunk_size else: @@ -148,10 +149,10 @@ def connection_close(http_version, headers): """ # At first, check if we have an explicit Connection header. if b"connection" in headers: - toks = utils.get_header_tokens(headers, "connection") - if b"close" in toks: + tokens = utils.get_header_tokens(headers, "connection") + if b"close" in tokens: return True - elif b"keep-alive" in toks: + elif b"keep-alive" in tokens: return False # If we don't have a Connection header, HTTP 1.1 connections are assumed to @@ -159,37 +160,41 @@ def connection_close(http_version, headers): return http_version != (1, 1) -def expected_http_body_size( - headers, - is_request, - request_method, - response_code, -): +def expected_http_body_size(request, response=False): """ - Returns the expected body length: - - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding) - - -1, if all data should be read until end of stream. + Returns: + The expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding) + - -1, if all data should be read until end of stream. Raises: HttpSyntaxException, if the content length header is invalid """ # Determine response size according to # http://tools.ietf.org/html/rfc7230#section-3.3 - if request_method: - request_method = request_method.upper() + if not response: + headers = request.headers + response_code = None + is_request = True + else: + headers = response.headers + response_code = response.status_code + is_request = False - is_empty_response = (not is_request and ( - request_method == b"HEAD" or - 100 <= response_code <= 199 or - (response_code == 200 and request_method == b"CONNECT") or - response_code in (204, 304) - )) + if is_request: + if headers.get(b"expect", b"").lower() == b"100-continue": + return 0 + else: + if request.method.upper() == b"HEAD": + return 0 + if 100 <= response_code <= 199: + return 0 + if response_code == 200 and request.method.upper() == b"CONNECT": + return 0 + if response_code in (204, 304): + return 0 - if is_empty_response: - return 0 - if is_request and headers.get(b"expect", b"").lower() == b"100-continue": - return 0 if b"chunked" in headers.get(b"transfer-encoding", b"").lower(): return None if b"content-length" in headers: @@ -212,18 +217,22 @@ def _get_first_line(rfile): line = rfile.readline() if not line: raise HttpReadDisconnect() - return line + line = line.strip() + try: + line.decode("ascii") + except ValueError: + raise HttpSyntaxException("Non-ascii characters in first line: {}".format(line)) + return line.strip() def _read_request_line(rfile): line = _get_first_line(rfile) try: - method, path, http_version = line.strip().split(b" ") + method, path, http_version = line.split(b" ") if path == b"*" or path.startswith(b"/"): form = "relative" - path.decode("ascii") # should not raise a ValueError scheme, host, port = None, None, None elif method == b"CONNECT": form = "authority" @@ -233,6 +242,7 @@ def _read_request_line(rfile): form = "absolute" scheme, host, port, path = utils.parse_url(path) + _check_http_version(http_version) except ValueError: raise HttpSyntaxException("Bad HTTP request line: {}".format(line)) @@ -253,7 +263,7 @@ def _parse_authority_form(hostport): if not utils.is_valid_host(host) or not utils.is_valid_port(port): raise ValueError() except ValueError: - raise ValueError("Invalid host specification: {}".format(hostport)) + raise HttpSyntaxException("Invalid host specification: {}".format(hostport)) return host, port @@ -263,7 +273,7 @@ def _read_response_line(rfile): try: - parts = line.strip().split(b" ") + parts = line.split(b" ", 2) if len(parts) == 2: # handle missing message gracefully parts.append(b"") @@ -278,7 +288,7 @@ def _read_response_line(rfile): def _check_http_version(http_version): - if not re.match(rb"^HTTP/\d\.\d$", http_version): + if not re.match(br"^HTTP/\d\.\d$", http_version): raise HttpSyntaxException("Unknown HTTP version: {}".format(http_version)) @@ -313,7 +323,7 @@ def _read_headers(rfile): return Headers(ret) -def _read_chunked(rfile, limit): +def _read_chunked(rfile, limit=sys.maxsize): """ Read a HTTP body with chunked transfer encoding. diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index b6d376d3..036bf68f 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -4,7 +4,7 @@ import time from hpack.hpack import Encoder, Decoder from netlib import http, utils -from netlib.http import semantics +from netlib.http import models as semantics from . import frame @@ -15,7 +15,7 @@ class TCPHandler(object): self.wfile = wfile -class HTTP2Protocol(semantics.ProtocolMixin): +class HTTP2Protocol(object): ERROR_CODES = utils.BiDi( NO_ERROR=0x0, diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py new file mode 100644 index 00000000..cb2cde99 --- /dev/null +++ b/netlib/http/http2/frame.py @@ -0,0 +1,654 @@ +from __future__ import absolute_import, print_function, division +import struct +from hpack.hpack import Encoder, Decoder + +from ...utils import BiDi +from ...exceptions import HttpSyntaxException + + +ERROR_CODES = BiDi( + NO_ERROR=0x0, + PROTOCOL_ERROR=0x1, + INTERNAL_ERROR=0x2, + FLOW_CONTROL_ERROR=0x3, + SETTINGS_TIMEOUT=0x4, + STREAM_CLOSED=0x5, + FRAME_SIZE_ERROR=0x6, + REFUSED_STREAM=0x7, + CANCEL=0x8, + COMPRESSION_ERROR=0x9, + CONNECT_ERROR=0xa, + ENHANCE_YOUR_CALM=0xb, + INADEQUATE_SECURITY=0xc, + HTTP_1_1_REQUIRED=0xd +) + +CLIENT_CONNECTION_PREFACE = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + +ALPN_PROTO_H2 = b'h2' + + +class Frame(object): + + """ + Baseclass Frame + contains header + payload is defined in subclasses + """ + + FLAG_NO_FLAGS = 0x0 + FLAG_ACK = 0x1 + FLAG_END_STREAM = 0x1 + FLAG_END_HEADERS = 0x4 + FLAG_PADDED = 0x8 + FLAG_PRIORITY = 0x20 + + def __init__( + self, + state=None, + length=0, + flags=FLAG_NO_FLAGS, + stream_id=0x0): + valid_flags = 0 + for flag in self.VALID_FLAGS: + valid_flags |= flag + if flags | valid_flags != valid_flags: + raise ValueError('invalid flags detected.') + + if state is None: + class State(object): + pass + + state = State() + state.http2_settings = HTTP2_DEFAULT_SETTINGS.copy() + state.encoder = Encoder() + state.decoder = Decoder() + + self.state = state + + self.length = length + self.type = self.TYPE + self.flags = flags + self.stream_id = stream_id + + @classmethod + def _check_frame_size(cls, length, state): + if state: + settings = state.http2_settings + else: + settings = HTTP2_DEFAULT_SETTINGS.copy() + + max_frame_size = settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + + if length > max_frame_size: + raise HttpSyntaxException( + "Frame size exceeded: %d, but only %d allowed." % ( + length, max_frame_size)) + + @classmethod + def from_file(cls, fp, state=None): + """ + read a HTTP/2 frame sent by a server or client + fp is a "file like" object that could be backed by a network + stream or a disk or an in memory stream reader + """ + raw_header = fp.safe_read(9) + + fields = struct.unpack("!HBBBL", raw_header) + length = (fields[0] << 8) + fields[1] + flags = fields[3] + stream_id = fields[4] + + if raw_header[:4] == b'HTTP': # pragma no cover + raise HttpSyntaxException("Expected HTTP2 Frame, got HTTP/1 connection") + + cls._check_frame_size(length, state) + + payload = fp.safe_read(length) + return FRAMES[fields[2]].from_bytes( + state, + length, + flags, + stream_id, + payload) + + def to_bytes(self): + payload = self.payload_bytes() + self.length = len(payload) + + self._check_frame_size(self.length, self.state) + + b = struct.pack('!HB', (self.length & 0xFFFF00) >> 8, self.length & 0x0000FF) + b += struct.pack('!B', self.TYPE) + b += struct.pack('!B', self.flags) + b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) + b += payload + + return b + + def payload_bytes(self): # pragma: no cover + raise NotImplementedError() + + def payload_human_readable(self): # pragma: no cover + raise NotImplementedError() + + def human_readable(self, direction="-"): + self.length = len(self.payload_bytes()) + + return "\n".join([ + "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( + direction, self.__class__.__name__, self.length, self.flags, self.stream_id), + self.payload_human_readable(), + "===============================================================", + ]) + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() + + +class DataFrame(Frame): + TYPE = 0x0 + VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b'', + pad_length=0): + super(DataFrame, self).__init__(state, length, flags, stream_id) + self.payload = payload + self.pad_length = pad_length + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.payload = payload[1:-f.pad_length] + else: + f.payload = payload + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('DATA frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += bytes(self.payload) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + return "payload: %s" % str(self.payload) + + +class HeadersFrame(Frame): + TYPE = 0x1 + VALID_FLAGS = [ + Frame.FLAG_END_STREAM, + Frame.FLAG_END_HEADERS, + Frame.FLAG_PADDED, + Frame.FLAG_PRIORITY] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b'', + pad_length=0, + exclusive=False, + stream_dependency=0x0, + weight=0): + super(HeadersFrame, self).__init__(state, length, flags, stream_id) + + self.header_block_fragment = header_block_fragment + self.pad_length = pad_length + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.header_block_fragment = payload[1:-f.pad_length] + else: + f.header_block_fragment = payload[0:] + + if f.flags & Frame.FLAG_PRIORITY: + f.stream_dependency, f.weight = struct.unpack( + '!LB', f.header_block_fragment[:5]) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + f.header_block_fragment = f.header_block_fragment[5:] + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('HEADERS frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + if self.flags & self.FLAG_PRIORITY: + b += struct.pack('!LB', + (int(self.exclusive) << 31) | self.stream_dependency, + self.weight) + + b += self.header_block_fragment + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PRIORITY: + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + + return "\n".join(s) + + +class PriorityFrame(Frame): + TYPE = 0x2 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + exclusive=False, + stream_dependency=0x0, + weight=0): + super(PriorityFrame, self).__init__(state, length, flags, stream_id) + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.stream_dependency, f.weight = struct.unpack('!LB', payload) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'PRIORITY frames MUST be associated with a stream.') + + return struct.pack( + '!LB', + (int( + self.exclusive) << 31) | self.stream_dependency, + self.weight) + + def payload_human_readable(self): + s = [] + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + return "\n".join(s) + + +class RstStreamFrame(Frame): + TYPE = 0x3 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + error_code=0x0): + super(RstStreamFrame, self).__init__(state, length, flags, stream_id) + self.error_code = error_code + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.error_code = struct.unpack('!L', payload)[0] + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'RST_STREAM frames MUST be associated with a stream.') + + return struct.pack('!L', self.error_code) + + def payload_human_readable(self): + return "error code: %#x" % self.error_code + + +class SettingsFrame(Frame): + TYPE = 0x4 + VALID_FLAGS = [Frame.FLAG_ACK] + + SETTINGS = BiDi( + SETTINGS_HEADER_TABLE_SIZE=0x1, + SETTINGS_ENABLE_PUSH=0x2, + SETTINGS_MAX_CONCURRENT_STREAMS=0x3, + SETTINGS_INITIAL_WINDOW_SIZE=0x4, + SETTINGS_MAX_FRAME_SIZE=0x5, + SETTINGS_MAX_HEADER_LIST_SIZE=0x6, + ) + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings=None): + super(SettingsFrame, self).__init__(state, length, flags, stream_id) + + if settings is None: + settings = {} + + self.settings = settings + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + for i in range(0, len(payload), 6): + identifier, value = struct.unpack("!HL", payload[i:i + 6]) + f.settings[identifier] = value + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'SETTINGS frames MUST NOT be associated with a stream.') + + b = b'' + for identifier, value in self.settings.items(): + b += struct.pack("!HL", identifier & 0xFF, value) + + return b + + def payload_human_readable(self): + s = [] + + for identifier, value in self.settings.items(): + s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) + + if not s: + return "settings: None" + else: + return "\n".join(s) + + +class PushPromiseFrame(Frame): + TYPE = 0x5 + VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + promised_stream=0x0, + header_block_fragment=b'', + pad_length=0): + super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) + self.pad_length = pad_length + self.promised_stream = promised_stream + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) + f.header_block_fragment = payload[5:-f.pad_length] + else: + f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) + f.header_block_fragment = payload[4:] + + f.promised_stream &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'PUSH_PROMISE frames MUST be associated with a stream.') + + if self.promised_stream == 0x0: + raise ValueError('Promised stream id not valid.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) + b += bytes(self.header_block_fragment) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append("promised stream: %#x" % self.promised_stream) + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + + return "\n".join(s) + + +class PingFrame(Frame): + TYPE = 0x6 + VALID_FLAGS = [Frame.FLAG_ACK] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b''): + super(PingFrame, self).__init__(state, length, flags, stream_id) + self.payload = payload + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.payload = payload + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'PING frames MUST NOT be associated with a stream.') + + b = self.payload[0:8] + b += b'\0' * (8 - len(b)) + return b + + def payload_human_readable(self): + return "opaque data: %s" % str(self.payload) + + +class GoAwayFrame(Frame): + TYPE = 0x7 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x0, + error_code=0x0, + data=b''): + super(GoAwayFrame, self).__init__(state, length, flags, stream_id) + self.last_stream = last_stream + self.error_code = error_code + self.data = data + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) + f.last_stream &= 0x7FFFFFFF + f.data = payload[8:] + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'GOAWAY frames MUST NOT be associated with a stream.') + + b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) + b += bytes(self.data) + return b + + def payload_human_readable(self): + s = [] + s.append("last stream: %#x" % self.last_stream) + s.append("error code: %d" % self.error_code) + s.append("debug data: %s" % str(self.data)) + return "\n".join(s) + + +class WindowUpdateFrame(Frame): + TYPE = 0x8 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + window_size_increment=0x0): + super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) + self.window_size_increment = window_size_increment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.window_size_increment = struct.unpack("!L", payload)[0] + f.window_size_increment &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: + raise ValueError( + 'Window Size Increment MUST be greater than 0 and less than 2^31.') + + return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) + + def payload_human_readable(self): + return "window size increment: %#x" % self.window_size_increment + + +class ContinuationFrame(Frame): + TYPE = 0x9 + VALID_FLAGS = [Frame.FLAG_END_HEADERS] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b''): + super(ContinuationFrame, self).__init__(state, length, flags, stream_id) + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.header_block_fragment = payload + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'CONTINUATION frames MUST be associated with a stream.') + + return self.header_block_fragment + + def payload_human_readable(self): + s = [] + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + return "\n".join(s) + +_FRAME_CLASSES = [ + DataFrame, + HeadersFrame, + PriorityFrame, + RstStreamFrame, + SettingsFrame, + PushPromiseFrame, + PingFrame, + GoAwayFrame, + WindowUpdateFrame, + ContinuationFrame +] +FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} + + +HTTP2_DEFAULT_SETTINGS = { + SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, + SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, +} diff --git a/netlib/http/http2/frames.py b/netlib/http/http2/frames.py deleted file mode 100644 index b36b3adf..00000000 --- a/netlib/http/http2/frames.py +++ /dev/null @@ -1,633 +0,0 @@ -import sys -import struct -from hpack.hpack import Encoder, Decoder - -from .. import utils - - -class FrameSizeError(Exception): - pass - - -class Frame(object): - - """ - Baseclass Frame - contains header - payload is defined in subclasses - """ - - FLAG_NO_FLAGS = 0x0 - FLAG_ACK = 0x1 - FLAG_END_STREAM = 0x1 - FLAG_END_HEADERS = 0x4 - FLAG_PADDED = 0x8 - FLAG_PRIORITY = 0x20 - - def __init__( - self, - state=None, - length=0, - flags=FLAG_NO_FLAGS, - stream_id=0x0): - valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) - if flags | valid_flags != valid_flags: - raise ValueError('invalid flags detected.') - - if state is None: - class State(object): - pass - - state = State() - state.http2_settings = HTTP2_DEFAULT_SETTINGS.copy() - state.encoder = Encoder() - state.decoder = Decoder() - - self.state = state - - self.length = length - self.type = self.TYPE - self.flags = flags - self.stream_id = stream_id - - @classmethod - def _check_frame_size(cls, length, state): - if state: - settings = state.http2_settings - else: - settings = HTTP2_DEFAULT_SETTINGS.copy() - - max_frame_size = settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] - - if length > max_frame_size: - raise FrameSizeError( - "Frame size exceeded: %d, but only %d allowed." % ( - length, max_frame_size)) - - @classmethod - def from_file(cls, fp, state=None): - """ - read a HTTP/2 frame sent by a server or client - fp is a "file like" object that could be backed by a network - stream or a disk or an in memory stream reader - """ - raw_header = fp.safe_read(9) - - fields = struct.unpack("!HBBBL", raw_header) - length = (fields[0] << 8) + fields[1] - flags = fields[3] - stream_id = fields[4] - - if raw_header[:4] == b'HTTP': # pragma no cover - print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" - - cls._check_frame_size(length, state) - - payload = fp.safe_read(length) - return FRAMES[fields[2]].from_bytes( - state, - length, - flags, - stream_id, - payload) - - def to_bytes(self): - payload = self.payload_bytes() - self.length = len(payload) - - self._check_frame_size(self.length, self.state) - - b = struct.pack('!HB', (self.length & 0xFFFF00) >> 8, self.length & 0x0000FF) - b += struct.pack('!B', self.TYPE) - b += struct.pack('!B', self.flags) - b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) - b += payload - - return b - - def payload_bytes(self): # pragma: no cover - raise NotImplementedError() - - def payload_human_readable(self): # pragma: no cover - raise NotImplementedError() - - def human_readable(self, direction="-"): - self.length = len(self.payload_bytes()) - - return "\n".join([ - "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( - direction, self.__class__.__name__, self.length, self.flags, self.stream_id), - self.payload_human_readable(), - "===============================================================", - ]) - - def __eq__(self, other): - return self.to_bytes() == other.to_bytes() - - -class DataFrame(Frame): - TYPE = 0x0 - VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b'', - pad_length=0): - super(DataFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - self.pad_length = pad_length - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.payload = payload[1:-f.pad_length] - else: - f.payload = payload - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('DATA frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += bytes(self.payload) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - return "payload: %s" % str(self.payload) - - -class HeadersFrame(Frame): - TYPE = 0x1 - VALID_FLAGS = [ - Frame.FLAG_END_STREAM, - Frame.FLAG_END_HEADERS, - Frame.FLAG_PADDED, - Frame.FLAG_PRIORITY] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b'', - pad_length=0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(HeadersFrame, self).__init__(state, length, flags, stream_id) - - self.header_block_fragment = header_block_fragment - self.pad_length = pad_length - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.header_block_fragment = payload[1:-f.pad_length] - else: - f.header_block_fragment = payload[0:] - - if f.flags & Frame.FLAG_PRIORITY: - f.stream_dependency, f.weight = struct.unpack( - '!LB', f.header_block_fragment[:5]) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - f.header_block_fragment = f.header_block_fragment[5:] - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('HEADERS frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - if self.flags & self.FLAG_PRIORITY: - b += struct.pack('!LB', - (int(self.exclusive) << 31) | self.stream_dependency, - self.weight) - - b += self.header_block_fragment - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PRIORITY: - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PriorityFrame(Frame): - TYPE = 0x2 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(PriorityFrame, self).__init__(state, length, flags, stream_id) - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.stream_dependency, f.weight = struct.unpack('!LB', payload) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PRIORITY frames MUST be associated with a stream.') - - return struct.pack( - '!LB', - (int( - self.exclusive) << 31) | self.stream_dependency, - self.weight) - - def payload_human_readable(self): - s = [] - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - return "\n".join(s) - - -class RstStreamFrame(Frame): - TYPE = 0x3 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - error_code=0x0): - super(RstStreamFrame, self).__init__(state, length, flags, stream_id) - self.error_code = error_code - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.error_code = struct.unpack('!L', payload)[0] - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'RST_STREAM frames MUST be associated with a stream.') - - return struct.pack('!L', self.error_code) - - def payload_human_readable(self): - return "error code: %#x" % self.error_code - - -class SettingsFrame(Frame): - TYPE = 0x4 - VALID_FLAGS = [Frame.FLAG_ACK] - - SETTINGS = utils.BiDi( - SETTINGS_HEADER_TABLE_SIZE=0x1, - SETTINGS_ENABLE_PUSH=0x2, - SETTINGS_MAX_CONCURRENT_STREAMS=0x3, - SETTINGS_INITIAL_WINDOW_SIZE=0x4, - SETTINGS_MAX_FRAME_SIZE=0x5, - SETTINGS_MAX_HEADER_LIST_SIZE=0x6, - ) - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings=None): - super(SettingsFrame, self).__init__(state, length, flags, stream_id) - - if settings is None: - settings = {} - - self.settings = settings - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - for i in xrange(0, len(payload), 6): - identifier, value = struct.unpack("!HL", payload[i:i + 6]) - f.settings[identifier] = value - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'SETTINGS frames MUST NOT be associated with a stream.') - - b = b'' - for identifier, value in self.settings.items(): - b += struct.pack("!HL", identifier & 0xFF, value) - - return b - - def payload_human_readable(self): - s = [] - - for identifier, value in self.settings.items(): - s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) - - if not s: - return "settings: None" - else: - return "\n".join(s) - - -class PushPromiseFrame(Frame): - TYPE = 0x5 - VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - promised_stream=0x0, - header_block_fragment=b'', - pad_length=0): - super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) - self.pad_length = pad_length - self.promised_stream = promised_stream - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) - f.header_block_fragment = payload[5:-f.pad_length] - else: - f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) - f.header_block_fragment = payload[4:] - - f.promised_stream &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PUSH_PROMISE frames MUST be associated with a stream.') - - if self.promised_stream == 0x0: - raise ValueError('Promised stream id not valid.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) - b += bytes(self.header_block_fragment) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append("promised stream: %#x" % self.promised_stream) - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PingFrame(Frame): - TYPE = 0x6 - VALID_FLAGS = [Frame.FLAG_ACK] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b''): - super(PingFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.payload = payload - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'PING frames MUST NOT be associated with a stream.') - - b = self.payload[0:8] - b += b'\0' * (8 - len(b)) - return b - - def payload_human_readable(self): - return "opaque data: %s" % str(self.payload) - - -class GoAwayFrame(Frame): - TYPE = 0x7 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x0, - error_code=0x0, - data=b''): - super(GoAwayFrame, self).__init__(state, length, flags, stream_id) - self.last_stream = last_stream - self.error_code = error_code - self.data = data - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) - f.last_stream &= 0x7FFFFFFF - f.data = payload[8:] - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'GOAWAY frames MUST NOT be associated with a stream.') - - b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) - b += bytes(self.data) - return b - - def payload_human_readable(self): - s = [] - s.append("last stream: %#x" % self.last_stream) - s.append("error code: %d" % self.error_code) - s.append("debug data: %s" % str(self.data)) - return "\n".join(s) - - -class WindowUpdateFrame(Frame): - TYPE = 0x8 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0x0): - super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) - self.window_size_increment = window_size_increment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.window_size_increment = struct.unpack("!L", payload)[0] - f.window_size_increment &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: - raise ValueError( - 'Window Size Increment MUST be greater than 0 and less than 2^31.') - - return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) - - def payload_human_readable(self): - return "window size increment: %#x" % self.window_size_increment - - -class ContinuationFrame(Frame): - TYPE = 0x9 - VALID_FLAGS = [Frame.FLAG_END_HEADERS] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b''): - super(ContinuationFrame, self).__init__(state, length, flags, stream_id) - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.header_block_fragment = payload - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'CONTINUATION frames MUST be associated with a stream.') - - return self.header_block_fragment - - def payload_human_readable(self): - s = [] - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - return "\n".join(s) - -_FRAME_CLASSES = [ - DataFrame, - HeadersFrame, - PriorityFrame, - RstStreamFrame, - SettingsFrame, - PushPromiseFrame, - PingFrame, - GoAwayFrame, - WindowUpdateFrame, - ContinuationFrame -] -FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} - - -HTTP2_DEFAULT_SETTINGS = { - SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, - SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, -} diff --git a/netlib/http/models.py b/netlib/http/models.py index bd5863b1..572d66c9 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -474,7 +474,6 @@ class Response(object): msg=None, headers=None, body=None, - sslinfo=None, timestamp_start=None, timestamp_end=None, ): @@ -487,7 +486,6 @@ class Response(object): self.msg = msg self.headers = headers self.body = body - self.sslinfo = sslinfo self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end diff --git a/netlib/tutils.py b/netlib/tutils.py index 65c4a313..758f8410 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -7,13 +7,15 @@ from contextlib import contextmanager import six import sys -from netlib import tcp, utils, http +from . import utils +from .http import Request, Response, Headers def treader(bytes): """ Construct a tcp.Read object from bytes. """ + from . import tcp # TODO: move to top once cryptography is on Python 3.5 fp = BytesIO(bytes) return tcp.Reader(fp) @@ -91,55 +93,39 @@ class RaisesContext(object): test_data = utils.Data(__name__) -def treq(content="content", scheme="http", host="address", port=22): +def treq(**kwargs): """ - @return: libmproxy.protocol.http.HTTPRequest + Returns: + netlib.http.Request """ - headers = http.Headers() - headers["header"] = "qvalue" - req = http.Request( - "relative", - "GET", - scheme, - host, - port, - "/path", - (1, 1), - headers, - content, - None, - None, + default = dict( + form_in="relative", + method=b"GET", + scheme=b"http", + host=b"address", + port=22, + path=b"/path", + httpversion=b"HTTP/1.1", + headers=Headers(header=b"qvalue"), + body=b"content" ) - return req + default.update(kwargs) + return Request(**default) -def treq_absolute(content="content"): +def tresp(**kwargs): """ - @return: libmproxy.protocol.http.HTTPRequest + Returns: + netlib.http.Response """ - r = treq(content) - r.form_in = r.form_out = "absolute" - r.host = "address" - r.port = 22 - r.scheme = "http" - return r - - -def tresp(content="message"): - """ - @return: libmproxy.protocol.http.HTTPResponse - """ - - headers = http.Headers() - headers["header_response"] = "svalue" - - resp = http.semantics.Response( - (1, 1), - 200, - "OK", - headers, - content, + default = dict( + httpversion=b"HTTP/1.1", + status_code=200, + msg=b"OK", + headers=Headers(header_response=b"svalue"), + body=b"message", timestamp_start=time.time(), - timestamp_end=time.time(), + timestamp_end=time.time() ) - return resp + default.update(kwargs) + return Response(**default) diff --git a/netlib/utils.py b/netlib/utils.py index fb579cac..a86b8019 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -40,9 +40,9 @@ def clean_bin(s, keep_spacing=True): ) else: if keep_spacing: - keep = b"\n\r\t" + keep = (9, 10, 13) # \t, \n, \r, else: - keep = b"" + keep = () return b"".join( six.int2byte(ch) if (31 < ch < 127 or ch in keep) else b"." for ch in six.iterbytes(s) @@ -251,7 +251,7 @@ def hostport(scheme, host, port): if (port, scheme) in [(80, "http"), (443, "https")]: return host else: - return b"%s:%s" % (host, port) + return b"%s:%d" % (host, port) def unparse_url(scheme, host, port, path=""): diff --git a/test/http/http1/test_assemble.py b/test/http/http1/test_assemble.py new file mode 100644 index 00000000..8a0a54f1 --- /dev/null +++ b/test/http/http1/test_assemble.py @@ -0,0 +1,91 @@ +from __future__ import absolute_import, print_function, division +from netlib.exceptions import HttpException +from netlib.http import CONTENT_MISSING, Headers +from netlib.http.http1.assemble import ( + assemble_request, assemble_request_head, assemble_response, + assemble_response_head, _assemble_request_line, _assemble_request_headers, + _assemble_response_headers +) +from netlib.tutils import treq, raises, tresp + + +def test_assemble_request(): + c = assemble_request(treq()) == ( + b"GET /path HTTP/1.1\r\n" + b"header: qvalue\r\n" + b"Host: address:22\r\n" + b"Content-Length: 7\r\n" + b"\r\n" + b"content" + ) + + with raises(HttpException): + assemble_request(treq(body=CONTENT_MISSING)) + + +def test_assemble_request_head(): + c = assemble_request_head(treq()) + assert b"GET" in c + assert b"qvalue" in c + assert b"content" not in c + + +def test_assemble_response(): + c = assemble_response(tresp()) == ( + b"HTTP/1.1 200 OK\r\n" + b"header-response: svalue\r\n" + b"Content-Length: 7\r\n" + b"\r\n" + b"message" + ) + + with raises(HttpException): + assemble_response(tresp(body=CONTENT_MISSING)) + + +def test_assemble_response_head(): + c = assemble_response_head(tresp()) + assert b"200" in c + assert b"svalue" in c + assert b"message" not in c + + +def test_assemble_request_line(): + assert _assemble_request_line(treq()) == b"GET /path HTTP/1.1" + + authority_request = treq(method=b"CONNECT", form_in="authority") + assert _assemble_request_line(authority_request) == b"CONNECT address:22 HTTP/1.1" + + absolute_request = treq(form_in="absolute") + assert _assemble_request_line(absolute_request) == b"GET http://address:22/path HTTP/1.1" + + with raises(RuntimeError): + _assemble_request_line(treq(), "invalid_form") + + +def test_assemble_request_headers(): + # https://github.com/mitmproxy/mitmproxy/issues/186 + r = treq(body=b"") + r.headers[b"Transfer-Encoding"] = b"chunked" + c = _assemble_request_headers(r) + assert b"Content-Length" in c + assert b"Transfer-Encoding" not in c + + assert b"Host" in _assemble_request_headers(treq(headers=Headers())) + + assert b"Proxy-Connection" not in _assemble_request_headers( + treq(headers=Headers(Proxy_Connection="42")) + ) + + +def test_assemble_response_headers(): + # https://github.com/mitmproxy/mitmproxy/issues/186 + r = tresp(body=b"") + r.headers["Transfer-Encoding"] = b"chunked" + c = _assemble_response_headers(r) + assert b"Content-Length" in c + assert b"Transfer-Encoding" not in c + + assert b"Proxy-Connection" not in _assemble_response_headers( + tresp(headers=Headers(Proxy_Connection=b"42")) + ) diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index bdcba5cb..e69de29b 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -1,466 +0,0 @@ -from io import BytesIO -import textwrap -from http.http1.protocol import _parse_authority_form -from netlib.exceptions import HttpSyntaxException, HttpReadDisconnect, HttpException - -from netlib import http, tcp, tutils -from netlib.http import semantics, Headers -from netlib.http.http1 import HTTP1Protocol, read_message_body, read_request, \ - read_message_body_chunked, expected_http_body_size -from ... import tservers - - -class NoContentLengthHTTPHandler(tcp.BaseHandler): - def handle(self): - self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n") - self.wfile.flush() - - -def mock_protocol(data=''): - rfile = BytesIO(data) - wfile = BytesIO() - return HTTP1Protocol(rfile=rfile, wfile=wfile) - - -def match_http_string(data): - return textwrap.dedent(data).strip().replace('\n', '\r\n') - - -def test_stripped_chunked_encoding_no_content(): - """ - https://github.com/mitmproxy/mitmproxy/issues/186 - """ - - r = tutils.treq(content="") - r.headers["Transfer-Encoding"] = "chunked" - assert "Content-Length" in mock_protocol()._assemble_request_headers(r) - - r = tutils.tresp(content="") - r.headers["Transfer-Encoding"] = "chunked" - assert "Content-Length" in mock_protocol()._assemble_response_headers(r) - - -def test_read_chunked(): - req = tutils.treq(None) - req.headers["Transfer-Encoding"] = "chunked" - - data = b"1\r\na\r\n0\r\n" - with tutils.raises(HttpSyntaxException): - read_message_body(BytesIO(data), req) - - data = b"1\r\na\r\n0\r\n\r\n" - assert read_message_body(BytesIO(data), req) == b"a" - - data = b"\r\n\r\n1\r\na\r\n1\r\nb\r\n0\r\n\r\n" - assert read_message_body(BytesIO(data), req) == b"ab" - - data = b"\r\n" - with tutils.raises("closed prematurely"): - read_message_body(BytesIO(data), req) - - data = b"1\r\nfoo" - with tutils.raises("malformed chunked body"): - read_message_body(BytesIO(data), req) - - data = b"foo\r\nfoo" - with tutils.raises(HttpSyntaxException): - read_message_body(BytesIO(data), req) - - data = b"5\r\naaaaa\r\n0\r\n\r\n" - with tutils.raises("too large"): - read_message_body(BytesIO(data), req, limit=2) - - -def test_connection_close(): - headers = Headers() - assert HTTP1Protocol.connection_close((1, 0), headers) - assert not HTTP1Protocol.connection_close((1, 1), headers) - - headers["connection"] = "keep-alive" - assert not HTTP1Protocol.connection_close((1, 1), headers) - - headers["connection"] = "close" - assert HTTP1Protocol.connection_close((1, 1), headers) - - -def test_read_http_body_request(): - headers = Headers() - data = "testing" - assert mock_protocol(data).read_http_body(headers, None, "GET", None, True) == "" - - -def test_read_http_body_response(): - headers = Headers() - data = "testing" - assert mock_protocol(data).read_http_body(headers, None, "GET", 200, False) == "testing" - - -def test_read_http_body(): - # test default case - headers = Headers() - headers["content-length"] = "7" - data = "testing" - assert mock_protocol(data).read_http_body(headers, None, "GET", 200, False) == "testing" - - # test content length: invalid header - headers["content-length"] = "foo" - data = "testing" - tutils.raises( - http.HttpError, - mock_protocol(data).read_http_body, - headers, None, "GET", 200, False - ) - - # test content length: invalid header #2 - headers["content-length"] = "-1" - data = "testing" - tutils.raises( - http.HttpError, - mock_protocol(data).read_http_body, - headers, None, "GET", 200, False - ) - - # test content length: content length > actual content - headers["content-length"] = "5" - data = "testing" - tutils.raises( - http.HttpError, - mock_protocol(data).read_http_body, - headers, 4, "GET", 200, False - ) - - # test content length: content length < actual content - data = "testing" - assert len(mock_protocol(data).read_http_body(headers, None, "GET", 200, False)) == 5 - - # test no content length: limit > actual content - headers = Headers() - data = "testing" - assert len(mock_protocol(data).read_http_body(headers, 100, "GET", 200, False)) == 7 - - # test no content length: limit < actual content - data = "testing" - tutils.raises( - http.HttpError, - mock_protocol(data).read_http_body, - headers, 4, "GET", 200, False - ) - - # test chunked - headers = Headers() - headers["transfer-encoding"] = "chunked" - data = "5\r\naaaaa\r\n0\r\n\r\n" - assert mock_protocol(data).read_http_body(headers, 100, "GET", 200, False) == "aaaaa" - - -def test_expected_http_body_size(): - # gibber in the content-length field - headers = Headers(content_length="foo") - with tutils.raises(HttpSyntaxException): - expected_http_body_size(headers, False, "GET", 200) is None - # negative number in the content-length field - headers = Headers(content_length="-7") - with tutils.raises(HttpSyntaxException): - expected_http_body_size(headers, False, "GET", 200) is None - # explicit length - headers = Headers(content_length="5") - assert expected_http_body_size(headers, False, "GET", 200) == 5 - # no length - headers = Headers() - assert expected_http_body_size(headers, False, "GET", 200) == -1 - # no length request - headers = Headers() - assert expected_http_body_size(headers, True, "GET", None) == 0 - # expect header - headers = Headers(content_length="5", expect="100-continue") - assert expected_http_body_size(headers, True, "GET", None) == 0 - - -def test_parse_init_connect(): - assert _parse_authority_form(b"CONNECT host.com:443 HTTP/1.0") - tutils.raises(ValueError,_parse_authority_form, b"\0host.com:443") - tutils.raises(ValueError,_parse_authority_form, b"host.com:444444") - tutils.raises(ValueError,_parse_authority_form, b"CONNECT host.com443 HTTP/1.0") - tutils.raises(ValueError,_parse_authority_form, b"CONNECT host.com:foo HTTP/1.0") - - -def test_parse_init_proxy(): - u = b"GET http://foo.com:8888/test HTTP/1.1" - m, s, h, po, pa, httpversion = HTTP1Protocol._parse_absolute_form(u) - assert m == "GET" - assert s == "http" - assert h == "foo.com" - assert po == 8888 - assert pa == "/test" - assert httpversion == (1, 1) - - u = "G\xfeET http://foo.com:8888/test HTTP/1.1" - assert not HTTP1Protocol._parse_absolute_form(u) - - with tutils.raises(ValueError): - assert not HTTP1Protocol._parse_absolute_form("invalid") - with tutils.raises(ValueError): - assert not HTTP1Protocol._parse_absolute_form("GET invalid HTTP/1.1") - with tutils.raises(ValueError): - assert not HTTP1Protocol._parse_absolute_form("GET http://foo.com:8888/test foo/1.1") - - -def test_parse_init_http(): - u = "GET /test HTTP/1.1" - m, u, httpversion = HTTP1Protocol._parse_init_http(u) - assert m == "GET" - assert u == "/test" - assert httpversion == (1, 1) - - u = "G\xfeET /test HTTP/1.1" - assert not HTTP1Protocol._parse_init_http(u) - - assert not HTTP1Protocol._parse_init_http("invalid") - assert not HTTP1Protocol._parse_init_http("GET invalid HTTP/1.1") - assert not HTTP1Protocol._parse_init_http("GET /test foo/1.1") - assert not HTTP1Protocol._parse_init_http("GET /test\xc0 HTTP/1.1") - - -class TestReadHeaders: - - def _read(self, data, verbatim=False): - if not verbatim: - data = textwrap.dedent(data) - data = data.strip() - return mock_protocol(data).read_headers() - - def test_read_simple(self): - data = """ - Header: one - Header2: two - \r\n - """ - headers = self._read(data) - assert headers.fields == [["Header", "one"], ["Header2", "two"]] - - def test_read_multi(self): - data = """ - Header: one - Header: two - \r\n - """ - headers = self._read(data) - assert headers.fields == [["Header", "one"], ["Header", "two"]] - - def test_read_continued(self): - data = """ - Header: one - \ttwo - Header2: three - \r\n - """ - headers = self._read(data) - assert headers.fields == [["Header", "one\r\n two"], ["Header2", "three"]] - - def test_read_continued_err(self): - data = "\tfoo: bar\r\n" - assert self._read(data, True) is None - - def test_read_err(self): - data = """ - foo - """ - assert self._read(data) is None - - -class TestReadRequest(object): - - def tst(self, data, **kwargs): - return mock_protocol(data).read_request(**kwargs) - - def test_invalid(self): - tutils.raises( - "bad http request", - self.tst, - "xxx" - ) - tutils.raises( - "bad http request line", - self.tst, - "get /\xff HTTP/1.1" - ) - tutils.raises( - "invalid headers", - self.tst, - "get / HTTP/1.1\r\nfoo" - ) - tutils.raises( - HttpReadDisconnect, - self.tst, - "\r\n" - ) - - def test_asterisk_form_in(self): - v = self.tst("OPTIONS * HTTP/1.1") - assert v.form_in == "relative" - assert v.method == "OPTIONS" - - def test_absolute_form_in(self): - tutils.raises( - "Bad HTTP request line", - self.tst, - "GET oops-no-protocol.com HTTP/1.1" - ) - v = self.tst("GET http://address:22/ HTTP/1.1") - assert v.form_in == "absolute" - assert v.port == 22 - assert v.host == "address" - assert v.scheme == "http" - - def test_connect(self): - tutils.raises( - "Bad HTTP request line", - self.tst, - "CONNECT oops-no-port.com HTTP/1.1" - ) - v = self.tst("CONNECT foo.com:443 HTTP/1.1") - assert v.form_in == "authority" - assert v.method == "CONNECT" - assert v.port == 443 - assert v.host == "foo.com" - - def test_expect(self): - data = ( - b"GET / HTTP/1.1\r\n" - b"Content-Length: 3\r\n" - b"Expect: 100-continue\r\n" - b"\r\n" - b"foobar" - ) - - rfile = BytesIO(data) - r = read_request(rfile) - assert r.body == b"" - assert rfile.read(-1) == b"foobar" - - -class TestReadResponse(object): - def tst(self, data, method, body_size_limit, include_body=True): - data = textwrap.dedent(data) - return mock_protocol(data).read_response( - method, body_size_limit, include_body=include_body - ) - - def test_errors(self): - tutils.raises("server disconnect", self.tst, "", "GET", None) - tutils.raises("invalid server response", self.tst, "foo", "GET", None) - - def test_simple(self): - data = """ - HTTP/1.1 200 - """ - assert self.tst(data, "GET", None) == http.Response( - (1, 1), 200, '', Headers(), '' - ) - - def test_simple_message(self): - data = """ - HTTP/1.1 200 OK - """ - assert self.tst(data, "GET", None) == http.Response( - (1, 1), 200, 'OK', Headers(), '' - ) - - def test_invalid_http_version(self): - data = """ - HTTP/x 200 OK - """ - tutils.raises("invalid http version", self.tst, data, "GET", None) - - def test_invalid_status_code(self): - data = """ - HTTP/1.1 xx OK - """ - tutils.raises("invalid server response", self.tst, data, "GET", None) - - def test_valid_with_continue(self): - data = """ - HTTP/1.1 100 CONTINUE - - HTTP/1.1 200 OK - """ - assert self.tst(data, "GET", None) == http.Response( - (1, 1), 100, 'CONTINUE', Headers(), '' - ) - - def test_simple_body(self): - data = """ - HTTP/1.1 200 OK - Content-Length: 3 - - foo - """ - assert self.tst(data, "GET", None).body == 'foo' - assert self.tst(data, "HEAD", None).body == '' - - def test_invalid_headers(self): - data = """ - HTTP/1.1 200 OK - \tContent-Length: 3 - - foo - """ - tutils.raises("invalid headers", self.tst, data, "GET", None) - - def test_without_body(self): - data = """ - HTTP/1.1 200 OK - Content-Length: 3 - - foo - """ - assert self.tst(data, "GET", None, include_body=False).body is None - - -class TestReadResponseNoContentLength(tservers.ServerTestBase): - handler = NoContentLengthHTTPHandler - - def test_no_content_length(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - resp = HTTP1Protocol(c).read_response("GET", None) - assert resp.body == "bar\r\n\r\n" - - -class TestAssembleRequest(object): - def test_simple(self): - req = tutils.treq() - b = HTTP1Protocol().assemble_request(req) - assert b == match_http_string(""" - GET /path HTTP/1.1 - header: qvalue - Host: address:22 - Content-Length: 7 - - content""") - - def test_body_missing(self): - req = tutils.treq(content=semantics.CONTENT_MISSING) - tutils.raises(http.HttpError, HTTP1Protocol().assemble_request, req) - - def test_not_a_request(self): - tutils.raises(AssertionError, HTTP1Protocol().assemble_request, 'foo') - - -class TestAssembleResponse(object): - def test_simple(self): - resp = tutils.tresp() - b = HTTP1Protocol().assemble_response(resp) - assert b == match_http_string(""" - HTTP/1.1 200 OK - header_response: svalue - Content-Length: 7 - - message""") - - def test_body_missing(self): - resp = tutils.tresp(content=semantics.CONTENT_MISSING) - tutils.raises(http.HttpError, HTTP1Protocol().assemble_response, resp) - - def test_not_a_request(self): - tutils.raises(AssertionError, HTTP1Protocol().assemble_response, 'foo') diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py new file mode 100644 index 00000000..5e6680af --- /dev/null +++ b/test/http/http1/test_read.py @@ -0,0 +1,313 @@ +from __future__ import absolute_import, print_function, division +from io import BytesIO +import textwrap + +from mock import Mock + +from netlib.exceptions import HttpException, HttpSyntaxException, HttpReadDisconnect +from netlib.http import Headers +from netlib.http.http1.read import ( + read_request, read_response, read_request_head, + read_response_head, read_body, connection_close, expected_http_body_size, _get_first_line, + _read_request_line, _parse_authority_form, _read_response_line, _check_http_version, + _read_headers, _read_chunked +) +from netlib.tutils import treq, tresp, raises + + +def test_read_request(): + rfile = BytesIO(b"GET / HTTP/1.1\r\n\r\nskip") + r = read_request(rfile) + assert r.method == b"GET" + assert r.body == b"" + assert r.timestamp_end + assert rfile.read() == b"skip" + + +def test_read_request_head(): + rfile = BytesIO( + b"GET / HTTP/1.1\r\n" + b"Content-Length: 4\r\n" + b"\r\n" + b"skip" + ) + rfile.reset_timestamps = Mock() + rfile.first_byte_timestamp = 42 + r = read_request_head(rfile) + assert r.method == b"GET" + assert r.headers["Content-Length"] == b"4" + assert r.body is None + assert rfile.reset_timestamps.called + assert r.timestamp_start == 42 + assert rfile.read() == b"skip" + + +def test_read_response(): + req = treq() + rfile = BytesIO(b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody") + r = read_response(rfile, req) + assert r.status_code == 418 + assert r.body == b"body" + assert r.timestamp_end + + +def test_read_response_head(): + rfile = BytesIO( + b"HTTP/1.1 418 I'm a teapot\r\n" + b"Content-Length: 4\r\n" + b"\r\n" + b"skip" + ) + rfile.reset_timestamps = Mock() + rfile.first_byte_timestamp = 42 + r = read_response_head(rfile) + assert r.status_code == 418 + assert r.headers["Content-Length"] == b"4" + assert r.body is None + assert rfile.reset_timestamps.called + assert r.timestamp_start == 42 + assert rfile.read() == b"skip" + + +class TestReadBody(object): + def test_chunked(self): + rfile = BytesIO(b"3\r\nfoo\r\n0\r\n\r\nbar") + body = b"".join(read_body(rfile, None)) + assert body == b"foo" + assert rfile.read() == b"bar" + + + def test_known_size(self): + rfile = BytesIO(b"foobar") + body = b"".join(read_body(rfile, 3)) + assert body == b"foo" + assert rfile.read() == b"bar" + + + def test_known_size_limit(self): + rfile = BytesIO(b"foobar") + with raises(HttpException): + b"".join(read_body(rfile, 3, 2)) + + def test_known_size_too_short(self): + rfile = BytesIO(b"foo") + with raises(HttpException): + b"".join(read_body(rfile, 6)) + + def test_unknown_size(self): + rfile = BytesIO(b"foobar") + body = b"".join(read_body(rfile, -1)) + assert body == b"foobar" + + + def test_unknown_size_limit(self): + rfile = BytesIO(b"foobar") + with raises(HttpException): + b"".join(read_body(rfile, -1, 3)) + + +def test_connection_close(): + headers = Headers() + assert connection_close((1, 0), headers) + assert not connection_close((1, 1), headers) + + headers["connection"] = "keep-alive" + assert not connection_close((1, 1), headers) + + headers["connection"] = "close" + assert connection_close((1, 1), headers) + + +def test_expected_http_body_size(): + # Expect: 100-continue + assert expected_http_body_size( + treq(headers=Headers(expect=b"100-continue", content_length=b"42")) + ) == 0 + + # http://tools.ietf.org/html/rfc7230#section-3.3 + assert expected_http_body_size( + treq(method=b"HEAD"), + tresp(headers=Headers(content_length=b"42")) + ) == 0 + assert expected_http_body_size( + treq(method=b"CONNECT"), + tresp() + ) == 0 + for code in (100, 204, 304): + assert expected_http_body_size( + treq(), + tresp(status_code=code) + ) == 0 + + # chunked + assert expected_http_body_size( + treq(headers=Headers(transfer_encoding=b"chunked")), + ) is None + + # explicit length + for l in (b"foo", b"-7"): + with raises(HttpSyntaxException): + expected_http_body_size( + treq(headers=Headers(content_length=l)) + ) + assert expected_http_body_size( + treq(headers=Headers(content_length=b"42")) + ) == 42 + + # no length + assert expected_http_body_size( + treq() + ) == 0 + assert expected_http_body_size( + treq(), tresp() + ) == -1 + + +def test_get_first_line(): + rfile = BytesIO(b"foo\r\nbar") + assert _get_first_line(rfile) == b"foo" + + rfile = BytesIO(b"\r\nfoo\r\nbar") + assert _get_first_line(rfile) == b"foo" + + with raises(HttpReadDisconnect): + rfile = BytesIO(b"") + _get_first_line(rfile) + + with raises(HttpSyntaxException): + rfile = BytesIO(b"GET /\xff HTTP/1.1") + _get_first_line(rfile) + + +def test_read_request_line(): + def t(b): + return _read_request_line(BytesIO(b)) + + assert (t(b"GET / HTTP/1.1") == + ("relative", b"GET", None, None, None, b"/", b"HTTP/1.1")) + assert (t(b"OPTIONS * HTTP/1.1") == + ("relative", b"OPTIONS", None, None, None, b"*", b"HTTP/1.1")) + assert (t(b"CONNECT foo:42 HTTP/1.1") == + ("authority", b"CONNECT", None, b"foo", 42, None, b"HTTP/1.1")) + assert (t(b"GET http://foo:42/bar HTTP/1.1") == + ("absolute", b"GET", b"http", b"foo", 42, b"/bar", b"HTTP/1.1")) + + with raises(HttpSyntaxException): + t(b"GET / WTF/1.1") + with raises(HttpSyntaxException): + t(b"this is not http") + + +def test_parse_authority_form(): + assert _parse_authority_form(b"foo:42") == (b"foo", 42) + with raises(HttpSyntaxException): + _parse_authority_form(b"foo") + with raises(HttpSyntaxException): + _parse_authority_form(b"foo:bar") + with raises(HttpSyntaxException): + _parse_authority_form(b"foo:99999999") + with raises(HttpSyntaxException): + _parse_authority_form(b"f\x00oo:80") + + +def test_read_response_line(): + def t(b): + return _read_response_line(BytesIO(b)) + + assert t(b"HTTP/1.1 200 OK") == (b"HTTP/1.1", 200, b"OK") + assert t(b"HTTP/1.1 200") == (b"HTTP/1.1", 200, b"") + with raises(HttpSyntaxException): + assert t(b"HTTP/1.1") + + with raises(HttpSyntaxException): + t(b"HTTP/1.1 OK OK") + with raises(HttpSyntaxException): + t(b"WTF/1.1 200 OK") + + +def test_check_http_version(): + _check_http_version(b"HTTP/0.9") + _check_http_version(b"HTTP/1.0") + _check_http_version(b"HTTP/1.1") + _check_http_version(b"HTTP/2.0") + with raises(HttpSyntaxException): + _check_http_version(b"WTF/1.0") + with raises(HttpSyntaxException): + _check_http_version(b"HTTP/1.10") + with raises(HttpSyntaxException): + _check_http_version(b"HTTP/1.b") + + +class TestReadHeaders(object): + @staticmethod + def _read(data): + return _read_headers(BytesIO(data)) + + def test_read_simple(self): + data = ( + b"Header: one\r\n" + b"Header2: two\r\n" + b"\r\n" + ) + headers = self._read(data) + assert headers.fields == [[b"Header", b"one"], [b"Header2", b"two"]] + + def test_read_multi(self): + data = ( + b"Header: one\r\n" + b"Header: two\r\n" + b"\r\n" + ) + headers = self._read(data) + assert headers.fields == [[b"Header", b"one"], [b"Header", b"two"]] + + def test_read_continued(self): + data = ( + b"Header: one\r\n" + b"\ttwo\r\n" + b"Header2: three\r\n" + b"\r\n" + ) + headers = self._read(data) + assert headers.fields == [[b"Header", b"one\r\n two"], [b"Header2", b"three"]] + + def test_read_continued_err(self): + data = b"\tfoo: bar\r\n" + with raises(HttpSyntaxException): + self._read(data) + + def test_read_err(self): + data = b"foo" + with raises(HttpSyntaxException): + self._read(data) + + +def test_read_chunked(): + req = treq(body=None) + req.headers["Transfer-Encoding"] = "chunked" + + data = b"1\r\na\r\n0\r\n" + with raises(HttpSyntaxException): + b"".join(_read_chunked(BytesIO(data))) + + data = b"1\r\na\r\n0\r\n\r\n" + assert b"".join(_read_chunked(BytesIO(data))) == b"a" + + data = b"\r\n\r\n1\r\na\r\n1\r\nb\r\n0\r\n\r\n" + assert b"".join(_read_chunked(BytesIO(data))) == b"ab" + + data = b"\r\n" + with raises("closed prematurely"): + b"".join(_read_chunked(BytesIO(data))) + + data = b"1\r\nfoo" + with raises("malformed chunked body"): + b"".join(_read_chunked(BytesIO(data))) + + data = b"foo\r\nfoo" + with raises(HttpSyntaxException): + b"".join(_read_chunked(BytesIO(data))) + + data = b"5\r\naaaaa\r\n0\r\n\r\n" + with raises("too large"): + b"".join(_read_chunked(BytesIO(data), limit=2)) diff --git a/test/http/http2/test_frames.py b/test/http/http2/test_frames.py index efdb55e2..4c89b023 100644 --- a/test/http/http2/test_frames.py +++ b/test/http/http2/test_frames.py @@ -39,7 +39,7 @@ def test_too_large_frames(): flags=Frame.FLAG_END_STREAM, stream_id=0x1234567, payload='foobar' * 3000) - tutils.raises(FrameSizeError, f.to_bytes) + tutils.raises(HttpSyntaxException, f.to_bytes) def test_data_frame_to_bytes(): diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 2b7d7958..789b6e63 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -2,21 +2,21 @@ import OpenSSL import mock from netlib import tcp, http, tutils -from netlib.http import http2, Headers -from netlib.http.http2 import HTTP2Protocol +from netlib.http import Headers +from netlib.http.http2.connections import HTTP2Protocol, TCPHandler from netlib.http.http2.frame import * from ... import tservers class TestTCPHandlerWrapper: def test_wrapped(self): - h = http2.TCPHandler(rfile='foo', wfile='bar') + h = TCPHandler(rfile='foo', wfile='bar') p = HTTP2Protocol(h) assert p.tcp_handler.rfile == 'foo' assert p.tcp_handler.wfile == 'bar' def test_direct(self): p = HTTP2Protocol(rfile='foo', wfile='bar') - assert isinstance(p.tcp_handler, http2.TCPHandler) + assert isinstance(p.tcp_handler, TCPHandler) assert p.tcp_handler.rfile == 'foo' assert p.tcp_handler.wfile == 'bar' @@ -32,8 +32,8 @@ class EchoHandler(tcp.BaseHandler): class TestProtocol: - @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface") - @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface") + @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface") + @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface") def test_perform_connection_preface(self, mock_client_method, mock_server_method): protocol = HTTP2Protocol(is_server=False) protocol.connection_preface_performed = True @@ -46,8 +46,8 @@ class TestProtocol: assert mock_client_method.called assert not mock_server_method.called - @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface") - @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface") + @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface") + @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface") def test_perform_connection_preface_server(self, mock_client_method, mock_server_method): protocol = HTTP2Protocol(is_server=True) protocol.connection_preface_performed = True diff --git a/test/http/test_exceptions.py b/test/http/test_exceptions.py deleted file mode 100644 index 49588d0a..00000000 --- a/test/http/test_exceptions.py +++ /dev/null @@ -1,6 +0,0 @@ -from netlib.http.exceptions import * - -class TestHttpError: - def test_simple(self): - e = HttpError(404, "Not found") - assert str(e) diff --git a/test/http/test_models.py b/test/http/test_models.py new file mode 100644 index 00000000..0f4dcc3b --- /dev/null +++ b/test/http/test_models.py @@ -0,0 +1,540 @@ +import mock + +from netlib import tutils +from netlib import utils +from netlib.odict import ODict, ODictCaseless +from netlib.http import Request, Response, Headers, CONTENT_MISSING, HDR_FORM_URLENCODED, \ + HDR_FORM_MULTIPART + + +class TestRequest(object): + def test_repr(self): + r = tutils.treq() + assert repr(r) + + def test_headers(self): + tutils.raises(AssertionError, Request, + 'form_in', + 'method', + 'scheme', + 'host', + 'port', + 'path', + (1, 1), + 'foobar', + ) + + req = Request( + 'form_in', + 'method', + 'scheme', + 'host', + 'port', + 'path', + (1, 1), + ) + assert isinstance(req.headers, Headers) + + def test_equal(self): + a = tutils.treq() + b = tutils.treq() + assert a == b + + assert not a == 'foo' + assert not b == 'foo' + assert not 'foo' == a + assert not 'foo' == b + + + def test_anticache(self): + req = tutils.treq() + req.headers["If-Modified-Since"] = "foo" + req.headers["If-None-Match"] = "bar" + req.anticache() + assert "If-Modified-Since" not in req.headers + assert "If-None-Match" not in req.headers + + def test_anticomp(self): + req = tutils.treq() + req.headers["Accept-Encoding"] = "foobar" + req.anticomp() + assert req.headers["Accept-Encoding"] == "identity" + + def test_constrain_encoding(self): + req = tutils.treq() + req.headers["Accept-Encoding"] = "identity, gzip, foo" + req.constrain_encoding() + assert "foo" not in req.headers["Accept-Encoding"] + + def test_update_host(self): + req = tutils.treq() + req.headers["Host"] = "" + req.host = "foobar" + req.update_host_header() + assert req.headers["Host"] == "foobar" + + def test_get_form(self): + req = tutils.treq() + assert req.get_form() == ODict() + + @mock.patch("netlib.http.Request.get_form_multipart") + @mock.patch("netlib.http.Request.get_form_urlencoded") + def test_get_form_with_url_encoded(self, mock_method_urlencoded, mock_method_multipart): + req = tutils.treq() + assert req.get_form() == ODict() + + req = tutils.treq() + req.body = "foobar" + req.headers["Content-Type"] = HDR_FORM_URLENCODED + req.get_form() + assert req.get_form_urlencoded.called + assert not req.get_form_multipart.called + + @mock.patch("netlib.http.Request.get_form_multipart") + @mock.patch("netlib.http.Request.get_form_urlencoded") + def test_get_form_with_multipart(self, mock_method_urlencoded, mock_method_multipart): + req = tutils.treq() + req.body = "foobar" + req.headers["Content-Type"] = HDR_FORM_MULTIPART + req.get_form() + assert not req.get_form_urlencoded.called + assert req.get_form_multipart.called + + def test_get_form_urlencoded(self): + req = tutils.treq(body="foobar") + assert req.get_form_urlencoded() == ODict() + + req.headers["Content-Type"] = HDR_FORM_URLENCODED + assert req.get_form_urlencoded() == ODict(utils.urldecode(req.body)) + + def test_get_form_multipart(self): + req = tutils.treq(body="foobar") + assert req.get_form_multipart() == ODict() + + req.headers["Content-Type"] = HDR_FORM_MULTIPART + assert req.get_form_multipart() == ODict( + utils.multipartdecode( + req.headers, + req.body + ) + ) + + def test_set_form_urlencoded(self): + req = tutils.treq() + req.set_form_urlencoded(ODict([('foo', 'bar'), ('rab', 'oof')])) + assert req.headers["Content-Type"] == HDR_FORM_URLENCODED + assert req.body + + def test_get_path_components(self): + req = tutils.treq() + assert req.get_path_components() + # TODO: add meaningful assertions + + def test_set_path_components(self): + req = tutils.treq() + req.set_path_components(["foo", "bar"]) + # TODO: add meaningful assertions + + def test_get_query(self): + req = tutils.treq() + assert req.get_query().lst == [] + + req.url = "http://localhost:80/foo?bar=42" + assert req.get_query().lst == [("bar", "42")] + + def test_set_query(self): + req = tutils.treq() + req.set_query(ODict([])) + + def test_pretty_host(self): + r = tutils.treq() + assert r.pretty_host(True) == "address" + assert r.pretty_host(False) == "address" + r.headers["host"] = "other" + assert r.pretty_host(True) == "other" + assert r.pretty_host(False) == "address" + r.host = None + assert r.pretty_host(True) == "other" + assert r.pretty_host(False) is None + del r.headers["host"] + assert r.pretty_host(True) is None + assert r.pretty_host(False) is None + + # Invalid IDNA + r.headers["host"] = ".disqus.com" + assert r.pretty_host(True) == ".disqus.com" + + def test_pretty_url(self): + req = tutils.treq() + req.form_out = "authority" + assert req.pretty_url(True) == "address:22" + assert req.pretty_url(False) == "address:22" + + req.form_out = "relative" + assert req.pretty_url(True) == "http://address:22/path" + assert req.pretty_url(False) == "http://address:22/path" + + def test_get_cookies_none(self): + headers = Headers() + r = tutils.treq() + r.headers = headers + assert len(r.get_cookies()) == 0 + + def test_get_cookies_single(self): + r = tutils.treq() + r.headers = Headers(cookie="cookiename=cookievalue") + result = r.get_cookies() + assert len(result) == 1 + assert result['cookiename'] == ['cookievalue'] + + def test_get_cookies_double(self): + r = tutils.treq() + r.headers = Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue") + result = r.get_cookies() + assert len(result) == 2 + assert result['cookiename'] == ['cookievalue'] + assert result['othercookiename'] == ['othercookievalue'] + + def test_get_cookies_withequalsign(self): + r = tutils.treq() + r.headers = Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue") + result = r.get_cookies() + assert len(result) == 2 + assert result['cookiename'] == ['coo=kievalue'] + assert result['othercookiename'] == ['othercookievalue'] + + def test_set_cookies(self): + r = tutils.treq() + r.headers = Headers(cookie="cookiename=cookievalue") + result = r.get_cookies() + result["cookiename"] = ["foo"] + r.set_cookies(result) + assert r.get_cookies()["cookiename"] == ["foo"] + + def test_set_url(self): + r = tutils.treq(form_in="absolute") + r.url = "https://otheraddress:42/ORLY" + assert r.scheme == "https" + assert r.host == "otheraddress" + assert r.port == 42 + assert r.path == "/ORLY" + + try: + r.url = "//localhost:80/foo@bar" + assert False + except: + assert True + + # def test_asterisk_form_in(self): + # f = tutils.tflow(req=None) + # protocol = mock_protocol("OPTIONS * HTTP/1.1") + # f.request = HTTPRequest.from_protocol(protocol) + # + # assert f.request.form_in == "relative" + # f.request.host = f.server_conn.address.host + # f.request.port = f.server_conn.address.port + # f.request.scheme = "http" + # assert protocol.assemble(f.request) == ( + # "OPTIONS * HTTP/1.1\r\n" + # "Host: address:22\r\n" + # "Content-Length: 0\r\n\r\n") + # + # def test_relative_form_in(self): + # protocol = mock_protocol("GET /foo\xff HTTP/1.1") + # tutils.raises("Bad HTTP request line", HTTPRequest.from_protocol, protocol) + # + # protocol = mock_protocol("GET /foo HTTP/1.1\r\nConnection: Upgrade\r\nUpgrade: h2c") + # r = HTTPRequest.from_protocol(protocol) + # assert r.headers["Upgrade"] == ["h2c"] + # + # def test_expect_header(self): + # protocol = mock_protocol( + # "GET / HTTP/1.1\r\nContent-Length: 3\r\nExpect: 100-continue\r\n\r\nfoobar") + # r = HTTPRequest.from_protocol(protocol) + # assert protocol.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" + # assert r.content == "foo" + # assert protocol.tcp_handler.rfile.read(3) == "bar" + # + # def test_authority_form_in(self): + # protocol = mock_protocol("CONNECT oops-no-port.com HTTP/1.1") + # tutils.raises("Bad HTTP request line", HTTPRequest.from_protocol, protocol) + # + # protocol = mock_protocol("CONNECT address:22 HTTP/1.1") + # r = HTTPRequest.from_protocol(protocol) + # r.scheme, r.host, r.port = "http", "address", 22 + # assert protocol.assemble(r) == ( + # "CONNECT address:22 HTTP/1.1\r\n" + # "Host: address:22\r\n" + # "Content-Length: 0\r\n\r\n") + # assert r.pretty_url(False) == "address:22" + # + # def test_absolute_form_in(self): + # protocol = mock_protocol("GET oops-no-protocol.com HTTP/1.1") + # tutils.raises("Bad HTTP request line", HTTPRequest.from_protocol, protocol) + # + # protocol = mock_protocol("GET http://address:22/ HTTP/1.1") + # r = HTTPRequest.from_protocol(protocol) + # assert protocol.assemble(r) == ( + # "GET http://address:22/ HTTP/1.1\r\n" + # "Host: address:22\r\n" + # "Content-Length: 0\r\n\r\n") + # + # def test_http_options_relative_form_in(self): + # """ + # Exercises fix for Issue #392. + # """ + # protocol = mock_protocol("OPTIONS /secret/resource HTTP/1.1") + # r = HTTPRequest.from_protocol(protocol) + # r.host = 'address' + # r.port = 80 + # r.scheme = "http" + # assert protocol.assemble(r) == ( + # "OPTIONS /secret/resource HTTP/1.1\r\n" + # "Host: address\r\n" + # "Content-Length: 0\r\n\r\n") + # + # def test_http_options_absolute_form_in(self): + # protocol = mock_protocol("OPTIONS http://address/secret/resource HTTP/1.1") + # r = HTTPRequest.from_protocol(protocol) + # r.host = 'address' + # r.port = 80 + # r.scheme = "http" + # assert protocol.assemble(r) == ( + # "OPTIONS http://address:80/secret/resource HTTP/1.1\r\n" + # "Host: address\r\n" + # "Content-Length: 0\r\n\r\n") + +class TestResponse(object): + def test_headers(self): + tutils.raises(AssertionError, Response, + (1, 1), + 200, + headers='foobar', + ) + + resp = Response( + (1, 1), + 200, + ) + assert isinstance(resp.headers, Headers) + + def test_equal(self): + a = tutils.tresp() + b = tutils.tresp() + assert a == b + + assert not a == 'foo' + assert not b == 'foo' + assert not 'foo' == a + assert not 'foo' == b + + def test_repr(self): + r = tutils.tresp() + assert "unknown content type" in repr(r) + r.headers["content-type"] = "foo" + assert "foo" in repr(r) + assert repr(tutils.tresp(body=CONTENT_MISSING)) + + def test_get_cookies_none(self): + resp = tutils.tresp() + resp.headers = Headers() + assert not resp.get_cookies() + + def test_get_cookies_simple(self): + resp = tutils.tresp() + resp.headers = Headers(set_cookie="cookiename=cookievalue") + result = resp.get_cookies() + assert len(result) == 1 + assert "cookiename" in result + assert result["cookiename"][0] == ["cookievalue", ODict()] + + def test_get_cookies_with_parameters(self): + resp = tutils.tresp() + resp.headers = Headers(set_cookie="cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly") + result = resp.get_cookies() + assert len(result) == 1 + assert "cookiename" in result + assert result["cookiename"][0][0] == "cookievalue" + attrs = result["cookiename"][0][1] + assert len(attrs) == 4 + assert attrs["domain"] == ["example.com"] + assert attrs["expires"] == ["Wed Oct 21 16:29:41 2015"] + assert attrs["path"] == ["/"] + assert attrs["httponly"] == [None] + + def test_get_cookies_no_value(self): + resp = tutils.tresp() + resp.headers = Headers(set_cookie="cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/") + result = resp.get_cookies() + assert len(result) == 1 + assert "cookiename" in result + assert result["cookiename"][0][0] == "" + assert len(result["cookiename"][0][1]) == 2 + + def test_get_cookies_twocookies(self): + resp = tutils.tresp() + resp.headers = Headers([ + ["Set-Cookie", "cookiename=cookievalue"], + ["Set-Cookie", "othercookie=othervalue"] + ]) + result = resp.get_cookies() + assert len(result) == 2 + assert "cookiename" in result + assert result["cookiename"][0] == ["cookievalue", ODict()] + assert "othercookie" in result + assert result["othercookie"][0] == ["othervalue", ODict()] + + def test_set_cookies(self): + resp = tutils.tresp() + v = resp.get_cookies() + v.add("foo", ["bar", ODictCaseless()]) + resp.set_cookies(v) + + v = resp.get_cookies() + assert len(v) == 1 + assert v["foo"] == [["bar", ODictCaseless()]] + + +class TestHeaders(object): + def _2host(self): + return Headers( + [ + ["Host", "example.com"], + ["host", "example.org"] + ] + ) + + def test_init(self): + headers = Headers() + assert len(headers) == 0 + + headers = Headers([["Host", "example.com"]]) + assert len(headers) == 1 + assert headers["Host"] == "example.com" + + headers = Headers(Host="example.com") + assert len(headers) == 1 + assert headers["Host"] == "example.com" + + headers = Headers( + [["Host", "invalid"]], + Host="example.com" + ) + assert len(headers) == 1 + assert headers["Host"] == "example.com" + + headers = Headers( + [["Host", "invalid"], ["Accept", "text/plain"]], + Host="example.com" + ) + assert len(headers) == 2 + assert headers["Host"] == "example.com" + assert headers["Accept"] == "text/plain" + + def test_getitem(self): + headers = Headers(Host="example.com") + assert headers["Host"] == "example.com" + assert headers["host"] == "example.com" + tutils.raises(KeyError, headers.__getitem__, "Accept") + + headers = self._2host() + assert headers["Host"] == "example.com, example.org" + + def test_str(self): + headers = Headers(Host="example.com") + assert bytes(headers) == "Host: example.com\r\n" + + headers = Headers([ + ["Host", "example.com"], + ["Accept", "text/plain"] + ]) + assert str(headers) == "Host: example.com\r\nAccept: text/plain\r\n" + + def test_setitem(self): + headers = Headers() + headers["Host"] = "example.com" + assert "Host" in headers + assert "host" in headers + assert headers["Host"] == "example.com" + + headers["host"] = "example.org" + assert "Host" in headers + assert "host" in headers + assert headers["Host"] == "example.org" + + headers["accept"] = "text/plain" + assert len(headers) == 2 + assert "Accept" in headers + assert "Host" in headers + + headers = self._2host() + assert len(headers.fields) == 2 + headers["Host"] = "example.com" + assert len(headers.fields) == 1 + assert "Host" in headers + + def test_delitem(self): + headers = Headers(Host="example.com") + assert len(headers) == 1 + del headers["host"] + assert len(headers) == 0 + try: + del headers["host"] + except KeyError: + assert True + else: + assert False + + headers = self._2host() + del headers["Host"] + assert len(headers) == 0 + + def test_keys(self): + headers = Headers(Host="example.com") + assert len(headers.keys()) == 1 + assert headers.keys()[0] == "Host" + + headers = self._2host() + assert len(headers.keys()) == 1 + assert headers.keys()[0] == "Host" + + def test_eq_ne(self): + headers1 = Headers(Host="example.com") + headers2 = Headers(host="example.com") + assert not (headers1 == headers2) + assert headers1 != headers2 + + headers1 = Headers(Host="example.com") + headers2 = Headers(Host="example.com") + assert headers1 == headers2 + assert not (headers1 != headers2) + + assert headers1 != 42 + + def test_get_all(self): + headers = self._2host() + assert headers.get_all("host") == ["example.com", "example.org"] + assert headers.get_all("accept") == [] + + def test_set_all(self): + headers = Headers(Host="example.com") + headers.set_all("Accept", ["text/plain"]) + assert len(headers) == 2 + assert "accept" in headers + + headers = self._2host() + headers.set_all("Host", ["example.org"]) + assert headers["host"] == "example.org" + + headers.set_all("Host", ["example.org", "example.net"]) + assert headers["host"] == "example.org, example.net" + + def test_state(self): + headers = self._2host() + assert len(headers.get_state()) == 2 + assert headers == Headers.from_state(headers.get_state()) + + headers2 = Headers() + assert headers != headers2 + headers2.load_state(headers.get_state()) + assert headers == headers2 diff --git a/test/http/test_semantics.py b/test/http/test_semantics.py deleted file mode 100644 index 44d3c85e..00000000 --- a/test/http/test_semantics.py +++ /dev/null @@ -1,573 +0,0 @@ -import mock - -from netlib import http -from netlib import odict -from netlib import tutils -from netlib import utils -from netlib.http import semantics -from netlib.http.semantics import CONTENT_MISSING - -class TestProtocolMixin(object): - @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response") - @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request") - def test_assemble_request(self, mock_request_method, mock_response_method): - p = semantics.ProtocolMixin() - p.assemble(tutils.treq()) - assert mock_request_method.called - assert not mock_response_method.called - - @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response") - @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request") - def test_assemble_response(self, mock_request_method, mock_response_method): - p = semantics.ProtocolMixin() - p.assemble(tutils.tresp()) - assert not mock_request_method.called - assert mock_response_method.called - - def test_assemble_foo(self): - p = semantics.ProtocolMixin() - tutils.raises(ValueError, p.assemble, 'foo') - -class TestRequest(object): - def test_repr(self): - r = tutils.treq() - assert repr(r) - - def test_headers(self): - tutils.raises(AssertionError, semantics.Request, - 'form_in', - 'method', - 'scheme', - 'host', - 'port', - 'path', - (1, 1), - 'foobar', - ) - - req = semantics.Request( - 'form_in', - 'method', - 'scheme', - 'host', - 'port', - 'path', - (1, 1), - ) - assert isinstance(req.headers, http.Headers) - - def test_equal(self): - a = tutils.treq() - b = tutils.treq() - assert a == b - - assert not a == 'foo' - assert not b == 'foo' - assert not 'foo' == a - assert not 'foo' == b - - def test_legacy_first_line(self): - req = tutils.treq() - - assert req.legacy_first_line('relative') == "GET /path HTTP/1.1" - assert req.legacy_first_line('authority') == "GET address:22 HTTP/1.1" - assert req.legacy_first_line('absolute') == "GET http://address:22/path HTTP/1.1" - tutils.raises(http.HttpError, req.legacy_first_line, 'foobar') - - def test_anticache(self): - req = tutils.treq() - req.headers["If-Modified-Since"] = "foo" - req.headers["If-None-Match"] = "bar" - req.anticache() - assert "If-Modified-Since" not in req.headers - assert "If-None-Match" not in req.headers - - def test_anticomp(self): - req = tutils.treq() - req.headers["Accept-Encoding"] = "foobar" - req.anticomp() - assert req.headers["Accept-Encoding"] == "identity" - - def test_constrain_encoding(self): - req = tutils.treq() - req.headers["Accept-Encoding"] = "identity, gzip, foo" - req.constrain_encoding() - assert "foo" not in req.headers["Accept-Encoding"] - - def test_update_host(self): - req = tutils.treq() - req.headers["Host"] = "" - req.host = "foobar" - req.update_host_header() - assert req.headers["Host"] == "foobar" - - def test_get_form(self): - req = tutils.treq() - assert req.get_form() == odict.ODict() - - @mock.patch("netlib.http.semantics.Request.get_form_multipart") - @mock.patch("netlib.http.semantics.Request.get_form_urlencoded") - def test_get_form_with_url_encoded(self, mock_method_urlencoded, mock_method_multipart): - req = tutils.treq() - assert req.get_form() == odict.ODict() - - req = tutils.treq() - req.body = "foobar" - req.headers["Content-Type"] = semantics.HDR_FORM_URLENCODED - req.get_form() - assert req.get_form_urlencoded.called - assert not req.get_form_multipart.called - - @mock.patch("netlib.http.semantics.Request.get_form_multipart") - @mock.patch("netlib.http.semantics.Request.get_form_urlencoded") - def test_get_form_with_multipart(self, mock_method_urlencoded, mock_method_multipart): - req = tutils.treq() - req.body = "foobar" - req.headers["Content-Type"] = semantics.HDR_FORM_MULTIPART - req.get_form() - assert not req.get_form_urlencoded.called - assert req.get_form_multipart.called - - def test_get_form_urlencoded(self): - req = tutils.treq("foobar") - assert req.get_form_urlencoded() == odict.ODict() - - req.headers["Content-Type"] = semantics.HDR_FORM_URLENCODED - assert req.get_form_urlencoded() == odict.ODict(utils.urldecode(req.body)) - - def test_get_form_multipart(self): - req = tutils.treq("foobar") - assert req.get_form_multipart() == odict.ODict() - - req.headers["Content-Type"] = semantics.HDR_FORM_MULTIPART - assert req.get_form_multipart() == odict.ODict( - utils.multipartdecode( - req.headers, - req.body - ) - ) - - def test_set_form_urlencoded(self): - req = tutils.treq() - req.set_form_urlencoded(odict.ODict([('foo', 'bar'), ('rab', 'oof')])) - assert req.headers["Content-Type"] == semantics.HDR_FORM_URLENCODED - assert req.body - - def test_get_path_components(self): - req = tutils.treq() - assert req.get_path_components() - # TODO: add meaningful assertions - - def test_set_path_components(self): - req = tutils.treq() - req.set_path_components(["foo", "bar"]) - # TODO: add meaningful assertions - - def test_get_query(self): - req = tutils.treq() - assert req.get_query().lst == [] - - req.url = "http://localhost:80/foo?bar=42" - assert req.get_query().lst == [("bar", "42")] - - def test_set_query(self): - req = tutils.treq() - req.set_query(odict.ODict([])) - - def test_pretty_host(self): - r = tutils.treq() - assert r.pretty_host(True) == "address" - assert r.pretty_host(False) == "address" - r.headers["host"] = "other" - assert r.pretty_host(True) == "other" - assert r.pretty_host(False) == "address" - r.host = None - assert r.pretty_host(True) == "other" - assert r.pretty_host(False) is None - del r.headers["host"] - assert r.pretty_host(True) is None - assert r.pretty_host(False) is None - - # Invalid IDNA - r.headers["host"] = ".disqus.com" - assert r.pretty_host(True) == ".disqus.com" - - def test_pretty_url(self): - req = tutils.treq() - req.form_out = "authority" - assert req.pretty_url(True) == "address:22" - assert req.pretty_url(False) == "address:22" - - req.form_out = "relative" - assert req.pretty_url(True) == "http://address:22/path" - assert req.pretty_url(False) == "http://address:22/path" - - def test_get_cookies_none(self): - headers = http.Headers() - r = tutils.treq() - r.headers = headers - assert len(r.get_cookies()) == 0 - - def test_get_cookies_single(self): - r = tutils.treq() - r.headers = http.Headers(cookie="cookiename=cookievalue") - result = r.get_cookies() - assert len(result) == 1 - assert result['cookiename'] == ['cookievalue'] - - def test_get_cookies_double(self): - r = tutils.treq() - r.headers = http.Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue") - result = r.get_cookies() - assert len(result) == 2 - assert result['cookiename'] == ['cookievalue'] - assert result['othercookiename'] == ['othercookievalue'] - - def test_get_cookies_withequalsign(self): - r = tutils.treq() - r.headers = http.Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue") - result = r.get_cookies() - assert len(result) == 2 - assert result['cookiename'] == ['coo=kievalue'] - assert result['othercookiename'] == ['othercookievalue'] - - def test_set_cookies(self): - r = tutils.treq() - r.headers = http.Headers(cookie="cookiename=cookievalue") - result = r.get_cookies() - result["cookiename"] = ["foo"] - r.set_cookies(result) - assert r.get_cookies()["cookiename"] == ["foo"] - - def test_set_url(self): - r = tutils.treq_absolute() - r.url = "https://otheraddress:42/ORLY" - assert r.scheme == "https" - assert r.host == "otheraddress" - assert r.port == 42 - assert r.path == "/ORLY" - - try: - r.url = "//localhost:80/foo@bar" - assert False - except: - assert True - - # def test_asterisk_form_in(self): - # f = tutils.tflow(req=None) - # protocol = mock_protocol("OPTIONS * HTTP/1.1") - # f.request = HTTPRequest.from_protocol(protocol) - # - # assert f.request.form_in == "relative" - # f.request.host = f.server_conn.address.host - # f.request.port = f.server_conn.address.port - # f.request.scheme = "http" - # assert protocol.assemble(f.request) == ( - # "OPTIONS * HTTP/1.1\r\n" - # "Host: address:22\r\n" - # "Content-Length: 0\r\n\r\n") - # - # def test_relative_form_in(self): - # protocol = mock_protocol("GET /foo\xff HTTP/1.1") - # tutils.raises("Bad HTTP request line", HTTPRequest.from_protocol, protocol) - # - # protocol = mock_protocol("GET /foo HTTP/1.1\r\nConnection: Upgrade\r\nUpgrade: h2c") - # r = HTTPRequest.from_protocol(protocol) - # assert r.headers["Upgrade"] == ["h2c"] - # - # def test_expect_header(self): - # protocol = mock_protocol( - # "GET / HTTP/1.1\r\nContent-Length: 3\r\nExpect: 100-continue\r\n\r\nfoobar") - # r = HTTPRequest.from_protocol(protocol) - # assert protocol.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" - # assert r.content == "foo" - # assert protocol.tcp_handler.rfile.read(3) == "bar" - # - # def test_authority_form_in(self): - # protocol = mock_protocol("CONNECT oops-no-port.com HTTP/1.1") - # tutils.raises("Bad HTTP request line", HTTPRequest.from_protocol, protocol) - # - # protocol = mock_protocol("CONNECT address:22 HTTP/1.1") - # r = HTTPRequest.from_protocol(protocol) - # r.scheme, r.host, r.port = "http", "address", 22 - # assert protocol.assemble(r) == ( - # "CONNECT address:22 HTTP/1.1\r\n" - # "Host: address:22\r\n" - # "Content-Length: 0\r\n\r\n") - # assert r.pretty_url(False) == "address:22" - # - # def test_absolute_form_in(self): - # protocol = mock_protocol("GET oops-no-protocol.com HTTP/1.1") - # tutils.raises("Bad HTTP request line", HTTPRequest.from_protocol, protocol) - # - # protocol = mock_protocol("GET http://address:22/ HTTP/1.1") - # r = HTTPRequest.from_protocol(protocol) - # assert protocol.assemble(r) == ( - # "GET http://address:22/ HTTP/1.1\r\n" - # "Host: address:22\r\n" - # "Content-Length: 0\r\n\r\n") - # - # def test_http_options_relative_form_in(self): - # """ - # Exercises fix for Issue #392. - # """ - # protocol = mock_protocol("OPTIONS /secret/resource HTTP/1.1") - # r = HTTPRequest.from_protocol(protocol) - # r.host = 'address' - # r.port = 80 - # r.scheme = "http" - # assert protocol.assemble(r) == ( - # "OPTIONS /secret/resource HTTP/1.1\r\n" - # "Host: address\r\n" - # "Content-Length: 0\r\n\r\n") - # - # def test_http_options_absolute_form_in(self): - # protocol = mock_protocol("OPTIONS http://address/secret/resource HTTP/1.1") - # r = HTTPRequest.from_protocol(protocol) - # r.host = 'address' - # r.port = 80 - # r.scheme = "http" - # assert protocol.assemble(r) == ( - # "OPTIONS http://address:80/secret/resource HTTP/1.1\r\n" - # "Host: address\r\n" - # "Content-Length: 0\r\n\r\n") - -class TestEmptyRequest(object): - def test_init(self): - req = semantics.EmptyRequest() - assert req - -class TestResponse(object): - def test_headers(self): - tutils.raises(AssertionError, semantics.Response, - (1, 1), - 200, - headers='foobar', - ) - - resp = semantics.Response( - (1, 1), - 200, - ) - assert isinstance(resp.headers, http.Headers) - - def test_equal(self): - a = tutils.tresp() - b = tutils.tresp() - assert a == b - - assert not a == 'foo' - assert not b == 'foo' - assert not 'foo' == a - assert not 'foo' == b - - def test_repr(self): - r = tutils.tresp() - assert "unknown content type" in repr(r) - r.headers["content-type"] = "foo" - assert "foo" in repr(r) - assert repr(tutils.tresp(content=CONTENT_MISSING)) - - def test_get_cookies_none(self): - resp = tutils.tresp() - resp.headers = http.Headers() - assert not resp.get_cookies() - - def test_get_cookies_simple(self): - resp = tutils.tresp() - resp.headers = http.Headers(set_cookie="cookiename=cookievalue") - result = resp.get_cookies() - assert len(result) == 1 - assert "cookiename" in result - assert result["cookiename"][0] == ["cookievalue", odict.ODict()] - - def test_get_cookies_with_parameters(self): - resp = tutils.tresp() - resp.headers = http.Headers(set_cookie="cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly") - result = resp.get_cookies() - assert len(result) == 1 - assert "cookiename" in result - assert result["cookiename"][0][0] == "cookievalue" - attrs = result["cookiename"][0][1] - assert len(attrs) == 4 - assert attrs["domain"] == ["example.com"] - assert attrs["expires"] == ["Wed Oct 21 16:29:41 2015"] - assert attrs["path"] == ["/"] - assert attrs["httponly"] == [None] - - def test_get_cookies_no_value(self): - resp = tutils.tresp() - resp.headers = http.Headers(set_cookie="cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/") - result = resp.get_cookies() - assert len(result) == 1 - assert "cookiename" in result - assert result["cookiename"][0][0] == "" - assert len(result["cookiename"][0][1]) == 2 - - def test_get_cookies_twocookies(self): - resp = tutils.tresp() - resp.headers = http.Headers([ - ["Set-Cookie", "cookiename=cookievalue"], - ["Set-Cookie", "othercookie=othervalue"] - ]) - result = resp.get_cookies() - assert len(result) == 2 - assert "cookiename" in result - assert result["cookiename"][0] == ["cookievalue", odict.ODict()] - assert "othercookie" in result - assert result["othercookie"][0] == ["othervalue", odict.ODict()] - - def test_set_cookies(self): - resp = tutils.tresp() - v = resp.get_cookies() - v.add("foo", ["bar", odict.ODictCaseless()]) - resp.set_cookies(v) - - v = resp.get_cookies() - assert len(v) == 1 - assert v["foo"] == [["bar", odict.ODictCaseless()]] - - -class TestHeaders(object): - def _2host(self): - return semantics.Headers( - [ - ["Host", "example.com"], - ["host", "example.org"] - ] - ) - - def test_init(self): - headers = semantics.Headers() - assert len(headers) == 0 - - headers = semantics.Headers([["Host", "example.com"]]) - assert len(headers) == 1 - assert headers["Host"] == "example.com" - - headers = semantics.Headers(Host="example.com") - assert len(headers) == 1 - assert headers["Host"] == "example.com" - - headers = semantics.Headers( - [["Host", "invalid"]], - Host="example.com" - ) - assert len(headers) == 1 - assert headers["Host"] == "example.com" - - headers = semantics.Headers( - [["Host", "invalid"], ["Accept", "text/plain"]], - Host="example.com" - ) - assert len(headers) == 2 - assert headers["Host"] == "example.com" - assert headers["Accept"] == "text/plain" - - def test_getitem(self): - headers = semantics.Headers(Host="example.com") - assert headers["Host"] == "example.com" - assert headers["host"] == "example.com" - tutils.raises(KeyError, headers.__getitem__, "Accept") - - headers = self._2host() - assert headers["Host"] == "example.com, example.org" - - def test_str(self): - headers = semantics.Headers(Host="example.com") - assert bytes(headers) == "Host: example.com\r\n" - - headers = semantics.Headers([ - ["Host", "example.com"], - ["Accept", "text/plain"] - ]) - assert str(headers) == "Host: example.com\r\nAccept: text/plain\r\n" - - def test_setitem(self): - headers = semantics.Headers() - headers["Host"] = "example.com" - assert "Host" in headers - assert "host" in headers - assert headers["Host"] == "example.com" - - headers["host"] = "example.org" - assert "Host" in headers - assert "host" in headers - assert headers["Host"] == "example.org" - - headers["accept"] = "text/plain" - assert len(headers) == 2 - assert "Accept" in headers - assert "Host" in headers - - headers = self._2host() - assert len(headers.fields) == 2 - headers["Host"] = "example.com" - assert len(headers.fields) == 1 - assert "Host" in headers - - def test_delitem(self): - headers = semantics.Headers(Host="example.com") - assert len(headers) == 1 - del headers["host"] - assert len(headers) == 0 - try: - del headers["host"] - except KeyError: - assert True - else: - assert False - - headers = self._2host() - del headers["Host"] - assert len(headers) == 0 - - def test_keys(self): - headers = semantics.Headers(Host="example.com") - assert len(headers.keys()) == 1 - assert headers.keys()[0] == "Host" - - headers = self._2host() - assert len(headers.keys()) == 1 - assert headers.keys()[0] == "Host" - - def test_eq_ne(self): - headers1 = semantics.Headers(Host="example.com") - headers2 = semantics.Headers(host="example.com") - assert not (headers1 == headers2) - assert headers1 != headers2 - - headers1 = semantics.Headers(Host="example.com") - headers2 = semantics.Headers(Host="example.com") - assert headers1 == headers2 - assert not (headers1 != headers2) - - assert headers1 != 42 - - def test_get_all(self): - headers = self._2host() - assert headers.get_all("host") == ["example.com", "example.org"] - assert headers.get_all("accept") == [] - - def test_set_all(self): - headers = semantics.Headers(Host="example.com") - headers.set_all("Accept", ["text/plain"]) - assert len(headers) == 2 - assert "accept" in headers - - headers = self._2host() - headers.set_all("Host", ["example.org"]) - assert headers["host"] == "example.org" - - headers.set_all("Host", ["example.org", "example.net"]) - assert headers["host"] == "example.org, example.net" - - def test_state(self): - headers = self._2host() - assert len(headers.get_state()) == 2 - assert headers == semantics.Headers.from_state(headers.get_state()) - - headers2 = semantics.Headers() - assert headers != headers2 - headers2.load_state(headers.get_state()) - assert headers == headers2 diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 57cfd166..3fdeb683 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -1,11 +1,13 @@ import os from nose.tools import raises +from netlib.http.http1 import read_response, read_request from netlib import tcp, tutils, websockets, http from netlib.http import status_codes -from netlib.http.exceptions import * -from netlib.http.http1 import HTTP1Protocol +from netlib.tutils import treq + +from netlib.exceptions import * from .. import tservers @@ -34,9 +36,8 @@ class WebSocketsEchoHandler(tcp.BaseHandler): frame.to_file(self.wfile) def handshake(self): - http1_protocol = HTTP1Protocol(self) - req = http1_protocol.read_request() + req = read_request(self.rfile) key = self.protocol.check_client_handshake(req.headers) preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) @@ -61,8 +62,6 @@ class WebSocketsClient(tcp.TCPClient): def connect(self): super(WebSocketsClient, self).connect() - http1_protocol = HTTP1Protocol(self) - preamble = 'GET / HTTP/1.1' self.wfile.write(preamble + "\r\n") headers = self.protocol.client_handshake_headers() @@ -70,7 +69,7 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.write(str(headers) + "\r\n") self.wfile.flush() - resp = http1_protocol.read_response("GET", None) + resp = read_response(self.rfile, treq(method="GET")) server_nonce = self.protocol.check_server_handshake(resp.headers) if not server_nonce == self.protocol.create_server_nonce( @@ -158,9 +157,8 @@ class TestWebSockets(tservers.ServerTestBase): class BadHandshakeHandler(WebSocketsEchoHandler): def handshake(self): - http1_protocol = HTTP1Protocol(self) - client_hs = http1_protocol.read_request() + client_hs = read_request(self.rfile) self.protocol.check_client_handshake(client_hs.headers) preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) -- cgit v1.2.3 From 9b882d245052feec44fc77e102dc597d24de2b80 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 16 Sep 2015 00:09:43 +0200 Subject: test parts on python 3.5 --- .travis.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index fd2fba3d..fa997542 100644 --- a/.travis.yml +++ b/.travis.yml @@ -15,6 +15,8 @@ matrix: - debian-sid packages: - libssl-dev + - python: 3.5 + script: "nosetests --with-cov --cov-report term-missing test/http/http1" - python: pypy - python: pypy env: OPENSSL=1.0.2 @@ -67,4 +69,4 @@ cache: - /home/travis/virtualenv/python2.7.9/lib/python2.7/site-packages - /home/travis/virtualenv/python2.7.9/bin - /home/travis/virtualenv/pypy-2.5.0/site-packages - - /home/travis/virtualenv/pypy-2.5.0/bin \ No newline at end of file + - /home/travis/virtualenv/pypy-2.5.0/bin -- cgit v1.2.3 From 265f31e8782ee9da511ce4b63aa2da00221cbf66 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 16 Sep 2015 18:43:24 +0200 Subject: adjust http1-related code --- netlib/exceptions.py | 1 + netlib/http/__init__.py | 5 ++++- netlib/http/http1/__init__.py | 1 + netlib/http/http1/assemble.py | 4 ++-- netlib/http/http1/read.py | 18 +++++++++++------- netlib/http/http2/__init__.py | 6 ++++++ netlib/http/http2/connections.py | 28 ++++++++++++++++++---------- netlib/http/models.py | 3 +++ netlib/tutils.py | 4 ++-- test/http/http1/test_read.py | 12 ++++++++---- test/http/http2/test_protocol.py | 20 ++++++++++---------- test/http/test_models.py | 8 ++++---- 12 files changed, 70 insertions(+), 40 deletions(-) diff --git a/netlib/exceptions.py b/netlib/exceptions.py index 637be3df..e13af473 100644 --- a/netlib/exceptions.py +++ b/netlib/exceptions.py @@ -27,5 +27,6 @@ class HttpException(NetlibException): class HttpReadDisconnect(HttpException, ReadDisconnect): pass + class HttpSyntaxException(HttpException): pass diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 9303de09..d72884b3 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,9 +1,12 @@ +from __future__ import absolute_import, print_function, division from .models import Request, Response, Headers +from .models import ALPN_PROTO_HTTP1, ALPN_PROTO_H2 from .models import HDR_FORM_MULTIPART, HDR_FORM_URLENCODED, CONTENT_MISSING from . import http1, http2 __all__ = [ "Request", "Response", "Headers", + "ALPN_PROTO_HTTP1", "ALPN_PROTO_H2", "HDR_FORM_MULTIPART", "HDR_FORM_URLENCODED", "CONTENT_MISSING", - "http1", "http2" + "http1", "http2", ] diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py index a72c2e05..2d33ff8a 100644 --- a/netlib/http/http1/__init__.py +++ b/netlib/http/http1/__init__.py @@ -1,3 +1,4 @@ +from __future__ import absolute_import, print_function, division from .read import ( read_request, read_request_head, read_response, read_response_head, diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index 47c7e95a..ace25d79 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -25,9 +25,9 @@ def assemble_response(response): return head + response.body -def assemble_response_head(response): +def assemble_response_head(response, preserve_transfer_encoding=False): first_line = _assemble_response_line(response) - headers = _assemble_response_headers(response) + headers = _assemble_response_headers(response, preserve_transfer_encoding) return b"%s\r\n%s\r\n" % (first_line, headers) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 4c423c4c..62025d15 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -6,8 +6,7 @@ import re from ... import utils from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException from .. import Request, Response, Headers - -ALPN_PROTO_HTTP1 = b'http/1.1' +from netlib.tcp import NetLibDisconnect def read_request(rfile, body_size_limit=None): @@ -157,10 +156,10 @@ def connection_close(http_version, headers): # If we don't have a Connection header, HTTP 1.1 connections are assumed to # be persistent - return http_version != (1, 1) + return http_version != b"HTTP/1.1" -def expected_http_body_size(request, response=False): +def expected_http_body_size(request, response=None): """ Returns: The expected body length: @@ -211,10 +210,13 @@ def expected_http_body_size(request, response=False): def _get_first_line(rfile): - line = rfile.readline() - if line == b"\r\n" or line == b"\n": - # Possible leftover from previous message + try: line = rfile.readline() + if line == b"\r\n" or line == b"\n": + # Possible leftover from previous message + line = rfile.readline() + except NetLibDisconnect: + raise HttpReadDisconnect() if not line: raise HttpReadDisconnect() line = line.strip() @@ -317,6 +319,8 @@ def _read_headers(rfile): try: name, value = line.split(b":", 1) value = value.strip() + if not name or not value: + raise ValueError() ret.append([name, value]) except ValueError: raise HttpSyntaxException("Invalid headers") diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py index e69de29b..7043d36f 100644 --- a/netlib/http/http2/__init__.py +++ b/netlib/http/http2/__init__.py @@ -0,0 +1,6 @@ +from __future__ import absolute_import, print_function, division +from .connections import HTTP2Protocol + +__all__ = [ + "HTTP2Protocol" +] diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index 036bf68f..5220d5d2 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -3,8 +3,8 @@ import itertools import time from hpack.hpack import Encoder, Decoder -from netlib import http, utils -from netlib.http import models as semantics +from ... import utils +from .. import Headers, Response, Request, ALPN_PROTO_H2 from . import frame @@ -36,8 +36,6 @@ class HTTP2Protocol(object): CLIENT_CONNECTION_PREFACE = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - ALPN_PROTO_H2 = 'h2' - def __init__( self, tcp_handler=None, @@ -62,6 +60,7 @@ class HTTP2Protocol(object): def read_request( self, + __rfile, include_body=True, body_size_limit=None, allow_empty=False, @@ -111,7 +110,7 @@ class HTTP2Protocol(object): port = 80 if scheme == 'http' else 443 port = int(port) - request = http.Request( + request = Request( form_in, method, scheme, @@ -131,6 +130,7 @@ class HTTP2Protocol(object): def read_response( self, + __rfile, request_method='', body_size_limit=None, include_body=True, @@ -159,7 +159,7 @@ class HTTP2Protocol(object): else: timestamp_end = None - response = http.Response( + response = Response( (2, 0), int(headers.get(':status', 502)), "", @@ -172,8 +172,16 @@ class HTTP2Protocol(object): return response + def assemble(self, message): + if isinstance(message, Request): + return self.assemble_request(message) + elif isinstance(message, Response): + return self.assemble_response(message) + else: + raise ValueError("HTTP message not supported.") + def assemble_request(self, request): - assert isinstance(request, semantics.Request) + assert isinstance(request, Request) authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host if self.tcp_handler.address.port != 443: @@ -200,7 +208,7 @@ class HTTP2Protocol(object): self._create_body(request.body, stream_id))) def assemble_response(self, response): - assert isinstance(response, semantics.Response) + assert isinstance(response, Response) headers = response.headers.copy() @@ -275,7 +283,7 @@ class HTTP2Protocol(object): def check_alpn(self): alp = self.tcp_handler.get_alpn_proto_negotiated() - if alp != self.ALPN_PROTO_H2: + if alp != ALPN_PROTO_H2: raise NotImplementedError( "HTTP2Protocol can not handle unknown ALP: %s" % alp) return True @@ -405,7 +413,7 @@ class HTTP2Protocol(object): else: self._handle_unexpected_frame(frm) - headers = http.Headers( + headers = Headers( [[str(k), str(v)] for k, v in self.decoder.decode(header_block_fragment)] ) diff --git a/netlib/http/models.py b/netlib/http/models.py index 572d66c9..2d09535c 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -13,6 +13,9 @@ try: except ImportError: from collections.abc import MutableMapping +# TODO: Move somewhere else? +ALPN_PROTO_HTTP1 = b'http/1.1' +ALPN_PROTO_H2 = b'h2' HDR_FORM_URLENCODED = b"application/x-www-form-urlencoded" HDR_FORM_MULTIPART = b"multipart/form-data" diff --git a/netlib/tutils.py b/netlib/tutils.py index 758f8410..05791c49 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -37,14 +37,14 @@ def _check_exception(expected, actual, exc_tb): if expected.lower() not in str(actual).lower(): six.reraise(AssertionError, AssertionError( "Expected %s, but caught %s" % ( - repr(str(expected)), actual + repr(expected), repr(actual) ) ), exc_tb) else: if not isinstance(actual, expected): six.reraise(AssertionError, AssertionError( "Expected %s, but caught %s %s" % ( - expected.__name__, actual.__class__.__name__, str(actual) + expected.__name__, actual.__class__.__name__, repr(actual) ) ), exc_tb) diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py index 5e6680af..55def2a5 100644 --- a/test/http/http1/test_read.py +++ b/test/http/http1/test_read.py @@ -108,14 +108,14 @@ class TestReadBody(object): def test_connection_close(): headers = Headers() - assert connection_close((1, 0), headers) - assert not connection_close((1, 1), headers) + assert connection_close(b"HTTP/1.0", headers) + assert not connection_close(b"HTTP/1.1", headers) headers["connection"] = "keep-alive" - assert not connection_close((1, 1), headers) + assert not connection_close(b"HTTP/1.1", headers) headers["connection"] = "close" - assert connection_close((1, 1), headers) + assert connection_close(b"HTTP/1.1", headers) def test_expected_http_body_size(): @@ -281,6 +281,10 @@ class TestReadHeaders(object): with raises(HttpSyntaxException): self._read(data) + def test_read_empty_name(self): + data = b":foo" + with raises(HttpSyntaxException): + self._read(data) def test_read_chunked(): req = treq(body=None) diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 789b6e63..a369eb49 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -64,7 +64,7 @@ class TestProtocol: class TestCheckALPNMatch(tservers.ServerTestBase): handler = EchoHandler ssl = dict( - alpn_select=HTTP2Protocol.ALPN_PROTO_H2, + alpn_select=ALPN_PROTO_H2, ) if OpenSSL._util.lib.Cryptography_HAS_ALPN: @@ -72,7 +72,7 @@ class TestCheckALPNMatch(tservers.ServerTestBase): def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2]) + c.convert_to_ssl(alpn_protos=[ALPN_PROTO_H2]) protocol = HTTP2Protocol(c) assert protocol.check_alpn() @@ -88,7 +88,7 @@ class TestCheckALPNMismatch(tservers.ServerTestBase): def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2]) + c.convert_to_ssl(alpn_protos=[ALPN_PROTO_H2]) protocol = HTTP2Protocol(c) tutils.raises(NotImplementedError, protocol.check_alpn) @@ -306,7 +306,7 @@ class TestReadRequest(tservers.ServerTestBase): protocol = HTTP2Protocol(c, is_server=True) protocol.connection_preface_performed = True - req = protocol.read_request() + req = protocol.read_request(NotImplemented) assert req.stream_id assert req.headers.fields == [[':method', 'GET'], [':path', '/'], [':scheme', 'https']] @@ -329,7 +329,7 @@ class TestReadRequestRelative(tservers.ServerTestBase): protocol = HTTP2Protocol(c, is_server=True) protocol.connection_preface_performed = True - req = protocol.read_request() + req = protocol.read_request(NotImplemented) assert req.form_in == "relative" assert req.method == "OPTIONS" @@ -352,7 +352,7 @@ class TestReadRequestAbsolute(tservers.ServerTestBase): protocol = HTTP2Protocol(c, is_server=True) protocol.connection_preface_performed = True - req = protocol.read_request() + req = protocol.read_request(NotImplemented) assert req.form_in == "absolute" assert req.scheme == "http" @@ -378,13 +378,13 @@ class TestReadRequestConnect(tservers.ServerTestBase): protocol = HTTP2Protocol(c, is_server=True) protocol.connection_preface_performed = True - req = protocol.read_request() + req = protocol.read_request(NotImplemented) assert req.form_in == "authority" assert req.method == "CONNECT" assert req.host == "address" assert req.port == 22 - req = protocol.read_request() + req = protocol.read_request(NotImplemented) assert req.form_in == "authority" assert req.method == "CONNECT" assert req.host == "example.com" @@ -410,7 +410,7 @@ class TestReadResponse(tservers.ServerTestBase): protocol = HTTP2Protocol(c) protocol.connection_preface_performed = True - resp = protocol.read_response(stream_id=42) + resp = protocol.read_response(NotImplemented, stream_id=42) assert resp.httpversion == (2, 0) assert resp.status_code == 200 @@ -436,7 +436,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): protocol = HTTP2Protocol(c) protocol.connection_preface_performed = True - resp = protocol.read_response(stream_id=42) + resp = protocol.read_response(NotImplemented, stream_id=42) assert resp.stream_id == 42 assert resp.httpversion == (2, 0) diff --git a/test/http/test_models.py b/test/http/test_models.py index 0f4dcc3b..8fce2e9d 100644 --- a/test/http/test_models.py +++ b/test/http/test_models.py @@ -20,7 +20,7 @@ class TestRequest(object): 'host', 'port', 'path', - (1, 1), + b"HTTP/1.1", 'foobar', ) @@ -31,7 +31,7 @@ class TestRequest(object): 'host', 'port', 'path', - (1, 1), + b"HTTP/1.1", ) assert isinstance(req.headers, Headers) @@ -307,13 +307,13 @@ class TestRequest(object): class TestResponse(object): def test_headers(self): tutils.raises(AssertionError, Response, - (1, 1), + b"HTTP/1.1", 200, headers='foobar', ) resp = Response( - (1, 1), + b"HTTP/1.1", 200, ) assert isinstance(resp.headers, Headers) -- cgit v1.2.3 From dad9f06cb9403ac88d31d0ba8422034df2bc5078 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 17 Sep 2015 02:14:14 +0200 Subject: organize exceptions, improve content-length handling --- netlib/exceptions.py | 30 ++++++++++- netlib/http/http1/assemble.py | 8 +-- netlib/http/http1/read.py | 9 ++-- netlib/http/models.py | 24 ++++++++- netlib/tcp.py | 108 +++++++++++++++++-------------------- test/http/http2/test_protocol.py | 3 +- test/test_tcp.py | 36 +++++++------ test/test_utils.py | 2 +- test/websockets/test_websockets.py | 2 +- 9 files changed, 130 insertions(+), 92 deletions(-) diff --git a/netlib/exceptions.py b/netlib/exceptions.py index e13af473..e30235af 100644 --- a/netlib/exceptions.py +++ b/netlib/exceptions.py @@ -16,7 +16,7 @@ class NetlibException(Exception): super(NetlibException, self).__init__(message) -class ReadDisconnect(object): +class Disconnect(object): """Immediate EOF""" @@ -24,9 +24,35 @@ class HttpException(NetlibException): pass -class HttpReadDisconnect(HttpException, ReadDisconnect): +class HttpReadDisconnect(HttpException, Disconnect): pass class HttpSyntaxException(HttpException): pass + + +class TcpException(NetlibException): + pass + + +class TcpDisconnect(TcpException, Disconnect): + pass + + + + +class TcpReadIncomplete(TcpException): + pass + + +class TcpTimeout(TcpException): + pass + + +class TlsException(NetlibException): + pass + + +class InvalidCertificateException(TlsException): + pass diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index ace25d79..33b9ef25 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -95,9 +95,9 @@ def _assemble_response_headers(response, preserve_transfer_encoding=False): if not preserve_transfer_encoding: headers.pop(b"Transfer-Encoding", None) - # If body is defined (i.e. not None or CONTENT_MISSING), we always - # add a content-length header. - if response.body or response.body == b"": - headers[b"Content-Length"] = str(len(response.body)).encode("ascii") + # If body is defined (i.e. not None or CONTENT_MISSING), + # we now need to set a content-length header. + if response.body or response.body == b"": + headers[b"Content-Length"] = str(len(response.body)).encode("ascii") return bytes(headers) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 62025d15..7f2b7bab 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -4,15 +4,14 @@ import sys import re from ... import utils -from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException +from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException, TcpDisconnect from .. import Request, Response, Headers -from netlib.tcp import NetLibDisconnect def read_request(rfile, body_size_limit=None): request = read_request_head(rfile) expected_body_size = expected_http_body_size(request) - request.body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) + request._body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) request.timestamp_end = time.time() return request @@ -51,7 +50,7 @@ def read_request_head(rfile): def read_response(rfile, request, body_size_limit=None): response = read_response_head(rfile) expected_body_size = expected_http_body_size(request, response) - response.body = b"".join(read_body(rfile, expected_body_size, body_size_limit)) + response._body = b"".join(read_body(rfile, expected_body_size, body_size_limit)) response.timestamp_end = time.time() return response @@ -215,7 +214,7 @@ def _get_first_line(rfile): if line == b"\r\n" or line == b"\n": # Possible leftover from previous message line = rfile.readline() - except NetLibDisconnect: + except TcpDisconnect: raise HttpReadDisconnect() if not line: raise HttpReadDisconnect() diff --git a/netlib/http/models.py b/netlib/http/models.py index 2d09535c..b4446ecb 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -231,7 +231,7 @@ class Request(object): self.path = path self.httpversion = httpversion self.headers = headers - self.body = body + self._body = body self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end self.form_out = form_out or form_in @@ -452,6 +452,16 @@ class Request(object): raise ValueError("Invalid URL: %s" % url) self.scheme, self.host, self.port, self.path = parts + @property + def body(self): + return self._body + + @body.setter + def body(self, body): + self._body = body + if isinstance(body, bytes): + self.headers["Content-Length"] = str(len(body)).encode() + @property def content(self): # pragma: no cover # TODO: remove deprecated getter @@ -488,7 +498,7 @@ class Response(object): self.status_code = status_code self.msg = msg self.headers = headers - self.body = body + self._body = body self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end @@ -551,6 +561,16 @@ class Response(object): ) self.headers.set_all("Set-Cookie", values) + @property + def body(self): + return self._body + + @body.setter + def body(self, body): + self._body = body + if isinstance(body, bytes): + self.headers["Content-Length"] = str(len(body)).encode() + @property def content(self): # pragma: no cover # TODO: remove deprecated getter diff --git a/netlib/tcp.py b/netlib/tcp.py index 1eb417b4..707e11e0 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -16,6 +16,9 @@ from . import certutils, version_check # This is a rather hackish way to make sure that # the latest version of pyOpenSSL is actually installed. +from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \ + TcpTimeout, TcpDisconnect, TcpException + version_check.check_pyopenssl_version() @@ -24,11 +27,17 @@ EINTR = 4 # To enable all SSL methods use: SSLv23 # then add options to disable certain methods # https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 +SSL_BASIC_OPTIONS = ( + SSL.OP_CIPHER_SERVER_PREFERENCE +) +if hasattr(SSL, "OP_NO_COMPRESSION"): + SSL_BASIC_OPTIONS |= SSL.OP_NO_COMPRESSION + SSL_DEFAULT_METHOD = SSL.SSLv23_METHOD SSL_DEFAULT_OPTIONS = ( SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | - SSL.OP_CIPHER_SERVER_PREFERENCE + SSL_BASIC_OPTIONS ) if hasattr(SSL, "OP_NO_COMPRESSION"): SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION @@ -39,42 +48,17 @@ Don't ask... https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 """ sslversion_choices = { - "all": (SSL.SSLv23_METHOD, 0), + "all": (SSL.SSLv23_METHOD, SSL_BASIC_OPTIONS), # SSLv23_METHOD + NO_SSLv2 + NO_SSLv3 == TLS 1.0+ # TLSv1_METHOD would be TLS 1.0 only - "secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)), - "SSLv2": (SSL.SSLv2_METHOD, 0), - "SSLv3": (SSL.SSLv3_METHOD, 0), - "TLSv1": (SSL.TLSv1_METHOD, 0), - "TLSv1_1": (SSL.TLSv1_1_METHOD, 0), - "TLSv1_2": (SSL.TLSv1_2_METHOD, 0), + "secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL_BASIC_OPTIONS)), + "SSLv2": (SSL.SSLv2_METHOD, SSL_BASIC_OPTIONS), + "SSLv3": (SSL.SSLv3_METHOD, SSL_BASIC_OPTIONS), + "TLSv1": (SSL.TLSv1_METHOD, SSL_BASIC_OPTIONS), + "TLSv1_1": (SSL.TLSv1_1_METHOD, SSL_BASIC_OPTIONS), + "TLSv1_2": (SSL.TLSv1_2_METHOD, SSL_BASIC_OPTIONS), } - -class NetLibError(Exception): - pass - - -class NetLibDisconnect(NetLibError): - pass - - -class NetLibIncomplete(NetLibError): - pass - - -class NetLibTimeout(NetLibError): - pass - - -class NetLibSSLError(NetLibError): - pass - - -class NetLibInvalidCertificateError(NetLibSSLError): - pass - - class SSLKeyLogger(object): def __init__(self, filename): @@ -168,17 +152,17 @@ class Writer(_FileLike): def flush(self): """ - May raise NetLibDisconnect + May raise TcpDisconnect """ if hasattr(self.o, "flush"): try: self.o.flush() except (socket.error, IOError) as v: - raise NetLibDisconnect(str(v)) + raise TcpDisconnect(str(v)) def write(self, v): """ - May raise NetLibDisconnect + May raise TcpDisconnect """ if v: self.first_byte_timestamp = self.first_byte_timestamp or time.time() @@ -191,7 +175,7 @@ class Writer(_FileLike): self.add_log(v[:r]) return r except (SSL.Error, socket.error) as e: - raise NetLibDisconnect(str(e)) + raise TcpDisconnect(str(e)) class Reader(_FileLike): @@ -210,23 +194,29 @@ class Reader(_FileLike): try: data = self.o.read(rlen) except SSL.ZeroReturnError: + # TLS connection was shut down cleanly break - except SSL.WantReadError: + except (SSL.WantWriteError, SSL.WantReadError): + # From the OpenSSL docs: + # If the underlying BIO is non-blocking, SSL_read() will also return when the + # underlying BIO could not satisfy the needs of SSL_read() to continue the + # operation. In this case a call to SSL_get_error with the return value of + # SSL_read() will yield SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE. if (time.time() - start) < self.o.gettimeout(): time.sleep(0.1) continue else: - raise NetLibTimeout + raise TcpTimeout() except socket.timeout: - raise NetLibTimeout - except socket.error: - raise NetLibDisconnect + raise TcpTimeout() + except socket.error as e: + raise TcpDisconnect(str(e)) except SSL.SysCallError as e: if e.args == (-1, 'Unexpected EOF'): break - raise NetLibSSLError(e.message) + raise TlsException(e.message) except SSL.Error as e: - raise NetLibSSLError(e.message) + raise TlsException(e.message) self.first_byte_timestamp = self.first_byte_timestamp or time.time() if not data: break @@ -260,9 +250,9 @@ class Reader(_FileLike): result = self.read(length) if length != -1 and len(result) != length: if not result: - raise NetLibDisconnect() + raise TcpDisconnect() else: - raise NetLibIncomplete( + raise TcpReadIncomplete( "Expected %s bytes, got %s" % (length, len(result)) ) return result @@ -275,15 +265,15 @@ class Reader(_FileLike): Up to the next N bytes if peeking is successful. Raises: - NetLibError if there was an error with the socket - NetLibSSLError if there was an error with pyOpenSSL. + TcpException if there was an error with the socket + TlsException if there was an error with pyOpenSSL. NotImplementedError if the underlying file object is not a (pyOpenSSL) socket """ if isinstance(self.o, socket._fileobject): try: return self.o._sock.recv(length, socket.MSG_PEEK) except socket.error as e: - raise NetLibError(repr(e)) + raise TcpException(repr(e)) elif isinstance(self.o, SSL.Connection): try: if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15): @@ -296,7 +286,7 @@ class Reader(_FileLike): self.o._raise_ssl_error(self.o._ssl, result) return SSL._ffi.buffer(buf, result)[:] except SSL.Error as e: - six.reraise(NetLibSSLError, NetLibSSLError(str(e)), sys.exc_info()[2]) + six.reraise(TlsException, TlsException(str(e)), sys.exc_info()[2]) else: raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") @@ -461,7 +451,7 @@ class _Connection(object): try: self.wfile.flush() self.wfile.close() - except NetLibDisconnect: + except TcpDisconnect: pass self.rfile.close() @@ -525,7 +515,7 @@ class _Connection(object): # TODO: maybe change this to with newer pyOpenSSL APIs context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1')) except SSL.Error as v: - raise NetLibError("SSL cipher specification error: %s" % str(v)) + raise TlsException("SSL cipher specification error: %s" % str(v)) # SSLKEYLOGFILE if log_ssl_key: @@ -546,7 +536,7 @@ class _Connection(object): elif alpn_select_callback is not None and alpn_select is None: context.set_alpn_select_callback(alpn_select_callback) elif alpn_select_callback is not None and alpn_select is not None: - raise NetLibError("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).") + raise TlsException("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).") return context @@ -594,7 +584,7 @@ class TCPClient(_Connection): context.use_privatekey_file(cert) context.use_certificate_file(cert) except SSL.Error as v: - raise NetLibError("SSL client certificate error: %s" % str(v)) + raise TlsException("SSL client certificate error: %s" % str(v)) return context def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs): @@ -618,15 +608,15 @@ class TCPClient(_Connection): self.connection.do_handshake() except SSL.Error as v: if self.ssl_verification_error: - raise NetLibInvalidCertificateError("SSL handshake error: %s" % repr(v)) + raise InvalidCertificateException("SSL handshake error: %s" % repr(v)) else: - raise NetLibError("SSL handshake error: %s" % repr(v)) + raise TlsException("SSL handshake error: %s" % repr(v)) # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on # certificate validation failure verification_mode = sslctx_kwargs.get('verify_options', None) if self.ssl_verification_error is not None and verification_mode == SSL.VERIFY_PEER: - raise NetLibInvalidCertificateError("SSL handshake error: certificate verify failed") + raise InvalidCertificateException("SSL handshake error: certificate verify failed") self.ssl_established = True self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) @@ -644,7 +634,7 @@ class TCPClient(_Connection): self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError) as err: - raise NetLibError( + raise TcpException( 'Error connecting to "%s": %s' % (self.address.host, err)) self.connection = connection @@ -750,7 +740,7 @@ class BaseHandler(_Connection): try: self.connection.do_handshake() except SSL.Error as v: - raise NetLibError("SSL handshake error: %s" % repr(v)) + raise TlsException("SSL handshake error: %s" % repr(v)) self.ssl_established = True self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index a369eb49..598b5cd7 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -2,6 +2,7 @@ import OpenSSL import mock from netlib import tcp, http, tutils +from netlib.exceptions import TcpDisconnect from netlib.http import Headers from netlib.http.http2.connections import HTTP2Protocol, TCPHandler from netlib.http.http2.frame import * @@ -127,7 +128,7 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase): protocol.perform_server_connection_preface() assert protocol.connection_preface_performed - tutils.raises(tcp.NetLibDisconnect, protocol.perform_server_connection_preface, force=True) + tutils.raises(TcpDisconnect, protocol.perform_server_connection_preface, force=True) class TestPerformClientConnectionPreface(tservers.ServerTestBase): diff --git a/test/test_tcp.py b/test/test_tcp.py index 2a5deb2b..615900ce 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -12,6 +12,8 @@ import OpenSSL from netlib import tcp, certutils, tutils from . import tservers +from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \ + TcpTimeout, TcpDisconnect, TcpException class EchoHandler(tcp.BaseHandler): @@ -93,7 +95,7 @@ class TestServerBind(tservers.ServerTestBase): c.connect() assert c.rfile.readline() == str(("127.0.0.1", random_port)) return - except tcp.NetLibError: # port probably already in use + except TcpException: # port probably already in use pass @@ -140,7 +142,7 @@ class TestFinishFail(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.wfile.write("foo\n") - c.wfile.flush = mock.Mock(side_effect=tcp.NetLibDisconnect) + c.wfile.flush = mock.Mock(side_effect=TcpDisconnect) c.finish() @@ -180,7 +182,7 @@ class TestSSLv3Only(tservers.ServerTestBase): def test_failure(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com") + tutils.raises(TlsException, c.convert_to_ssl, sni="foo.com") class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): @@ -224,7 +226,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): c.connect() tutils.raises( - tcp.NetLibInvalidCertificateError, + InvalidCertificateException, c.convert_to_ssl, verify_options=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted.pem")) @@ -327,7 +329,7 @@ class TestSSLClientCert(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() tutils.raises( - tcp.NetLibError, + TlsException, c.convert_to_ssl, cert=tutils.test_data.path("data/clientcert/make") ) @@ -432,7 +434,7 @@ class TestSSLDisconnect(tservers.ServerTestBase): # Excercise SSL.ZeroReturnError c.rfile.read(10) c.close() - tutils.raises(tcp.NetLibDisconnect, c.wfile.write, "foo") + tutils.raises(TcpDisconnect, c.wfile.write, "foo") tutils.raises(Queue.Empty, self.q.get_nowait) @@ -447,7 +449,7 @@ class TestSSLHardDisconnect(tservers.ServerTestBase): # Exercise SSL.SysCallError c.rfile.read(10) c.close() - tutils.raises(tcp.NetLibDisconnect, c.wfile.write, "foo") + tutils.raises(TcpDisconnect, c.wfile.write, "foo") class TestDisconnect(tservers.ServerTestBase): @@ -470,7 +472,7 @@ class TestServerTimeOut(tservers.ServerTestBase): self.settimeout(0.01) try: self.rfile.read(10) - except tcp.NetLibTimeout: + except TcpTimeout: self.timeout = True def test_timeout(self): @@ -488,7 +490,7 @@ class TestTimeOut(tservers.ServerTestBase): c.connect() c.settimeout(0.1) assert c.gettimeout() == 0.1 - tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) + tutils.raises(TcpTimeout, c.rfile.read, 10) class TestALPNClient(tservers.ServerTestBase): @@ -540,7 +542,7 @@ class TestSSLTimeOut(tservers.ServerTestBase): c.connect() c.convert_to_ssl() c.settimeout(0.1) - tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) + tutils.raises(TcpTimeout, c.rfile.read, 10) class TestDHParams(tservers.ServerTestBase): @@ -570,7 +572,7 @@ class TestTCPClient: def test_conerr(self): c = tcp.TCPClient(("127.0.0.1", 0)) - tutils.raises(tcp.NetLibError, c.connect) + tutils.raises(TcpException, c.connect) class TestFileLike: @@ -639,7 +641,7 @@ class TestFileLike: o = mock.MagicMock() o.flush = mock.MagicMock(side_effect=socket.error) s.o = o - tutils.raises(tcp.NetLibDisconnect, s.flush) + tutils.raises(TcpDisconnect, s.flush) def test_reader_read_error(self): s = cStringIO.StringIO("foobar\nfoobar") @@ -647,7 +649,7 @@ class TestFileLike: o = mock.MagicMock() o.read = mock.MagicMock(side_effect=socket.error) s.o = o - tutils.raises(tcp.NetLibDisconnect, s.read, 10) + tutils.raises(TcpDisconnect, s.read, 10) def test_reset_timestamps(self): s = cStringIO.StringIO("foobar\nfoobar") @@ -678,24 +680,24 @@ class TestFileLike: s = mock.MagicMock() s.read = mock.MagicMock(side_effect=SSL.Error()) s = tcp.Reader(s) - tutils.raises(tcp.NetLibSSLError, s.read, 1) + tutils.raises(TlsException, s.read, 1) def test_read_syscall_ssl_error(self): s = mock.MagicMock() s.read = mock.MagicMock(side_effect=SSL.SysCallError()) s = tcp.Reader(s) - tutils.raises(tcp.NetLibSSLError, s.read, 1) + tutils.raises(TlsException, s.read, 1) def test_reader_readline_disconnect(self): o = mock.MagicMock() o.read = mock.MagicMock(side_effect=socket.error) s = tcp.Reader(o) - tutils.raises(tcp.NetLibDisconnect, s.readline, 10) + tutils.raises(TcpDisconnect, s.readline, 10) def test_reader_incomplete_error(self): s = cStringIO.StringIO("foobar") s = tcp.Reader(s) - tutils.raises(tcp.NetLibIncomplete, s.safe_read, 10) + tutils.raises(TcpReadIncomplete, s.safe_read, 10) class TestAddress: diff --git a/test/test_utils.py b/test/test_utils.py index 8b2ddae4..eb7aa31a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -14,7 +14,7 @@ def test_hexdump(): assert utils.hexdump("one\0" * 10) -def test_cleanBin(): +def test_clean_bin(): assert utils.clean_bin(b"one") == b"one" assert utils.clean_bin(b"\00ne") == b".ne" assert utils.clean_bin(b"\nne") == b"\nne" diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 3fdeb683..3af5dc9c 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -176,7 +176,7 @@ class TestBadHandshake(tservers.ServerTestBase): """ handler = BadHandshakeHandler - @raises(tcp.NetLibDisconnect) + @raises(TcpDisconnect) def test(self): client = WebSocketsClient(("127.0.0.1", self.port)) client.connect() -- cgit v1.2.3 From a07e43df8b3988f137b48957f978ad570d9dc782 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 17 Sep 2015 02:39:42 +0200 Subject: http1: add assemble_body function --- netlib/exceptions.py | 2 -- netlib/http/http1/__init__.py | 2 ++ netlib/http/http1/assemble.py | 26 +++++++++++++++----------- test/http/http1/test_assemble.py | 18 ++++++++++++++---- 4 files changed, 31 insertions(+), 17 deletions(-) diff --git a/netlib/exceptions.py b/netlib/exceptions.py index e30235af..05f1054b 100644 --- a/netlib/exceptions.py +++ b/netlib/exceptions.py @@ -40,8 +40,6 @@ class TcpDisconnect(TcpException, Disconnect): pass - - class TcpReadIncomplete(TcpException): pass diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py index 2d33ff8a..2aa7e26a 100644 --- a/netlib/http/http1/__init__.py +++ b/netlib/http/http1/__init__.py @@ -9,6 +9,7 @@ from .read import ( from .assemble import ( assemble_request, assemble_request_head, assemble_response, assemble_response_head, + assemble_body, ) @@ -20,4 +21,5 @@ __all__ = [ "expected_http_body_size", "assemble_request", "assemble_request_head", "assemble_response", "assemble_response_head", + "assemble_body", ] diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index 33b9ef25..7252c446 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, print_function, division from ... import utils +import itertools from ...exceptions import HttpException from .. import CONTENT_MISSING @@ -25,12 +26,23 @@ def assemble_response(response): return head + response.body -def assemble_response_head(response, preserve_transfer_encoding=False): +def assemble_response_head(response): first_line = _assemble_response_line(response) - headers = _assemble_response_headers(response, preserve_transfer_encoding) + headers = _assemble_response_headers(response) return b"%s\r\n%s\r\n" % (first_line, headers) +def assemble_body(headers, body_chunks): + if b"chunked" in headers.get(b"transfer-encoding", b"").lower(): + for chunk in body_chunks: + if chunk: + yield b"%x\r\n%s\r\n" % (len(chunk), chunk) + yield b"0\r\n\r\n" + else: + for chunk in body_chunks: + yield chunk + + def _assemble_request_line(request, form=None): if form is None: form = request.form_out @@ -87,17 +99,9 @@ def _assemble_response_line(response): ) -def _assemble_response_headers(response, preserve_transfer_encoding=False): - # TODO: Remove preserve_transfer_encoding +def _assemble_response_headers(response): headers = response.headers.copy() for k in response._headers_to_strip_off: headers.pop(k, None) - if not preserve_transfer_encoding: - headers.pop(b"Transfer-Encoding", None) - - # If body is defined (i.e. not None or CONTENT_MISSING), - # we now need to set a content-length header. - if response.body or response.body == b"": - headers[b"Content-Length"] = str(len(response.body)).encode("ascii") return bytes(headers) diff --git a/test/http/http1/test_assemble.py b/test/http/http1/test_assemble.py index 8a0a54f1..cdc8bda9 100644 --- a/test/http/http1/test_assemble.py +++ b/test/http/http1/test_assemble.py @@ -4,8 +4,8 @@ from netlib.http import CONTENT_MISSING, Headers from netlib.http.http1.assemble import ( assemble_request, assemble_request_head, assemble_response, assemble_response_head, _assemble_request_line, _assemble_request_headers, - _assemble_response_headers -) + _assemble_response_headers, + assemble_body) from netlib.tutils import treq, raises, tresp @@ -50,6 +50,17 @@ def test_assemble_response_head(): assert b"message" not in c +def test_assemble_body(): + c = list(assemble_body(Headers(), [b"body"])) + assert c == [b"body"] + + c = list(assemble_body(Headers(transfer_encoding="chunked"), [b"123456789a", b""])) + assert c == [b"a\r\n123456789a\r\n", b"0\r\n\r\n"] + + c = list(assemble_body(Headers(transfer_encoding="chunked"), [b"123456789a"])) + assert c == [b"a\r\n123456789a\r\n", b"0\r\n\r\n"] + + def test_assemble_request_line(): assert _assemble_request_line(treq()) == b"GET /path HTTP/1.1" @@ -83,8 +94,7 @@ def test_assemble_response_headers(): r = tresp(body=b"") r.headers["Transfer-Encoding"] = b"chunked" c = _assemble_response_headers(r) - assert b"Content-Length" in c - assert b"Transfer-Encoding" not in c + assert b"Transfer-Encoding" in c assert b"Proxy-Connection" not in _assemble_response_headers( tresp(headers=Headers(Proxy_Connection=b"42")) -- cgit v1.2.3 From 8d71059d77c2dd1d9858d7971dd0b6b4387ed9f4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 17 Sep 2015 15:16:12 +0200 Subject: clean up http message models --- netlib/http/http1/assemble.py | 8 +- netlib/http/models.py | 159 +++++++++++++-------------------------- netlib/tutils.py | 4 +- netlib/utils.py | 30 +++----- netlib/websockets/frame.py | 9 ++- netlib/websockets/protocol.py | 3 +- test/http/http2/test_protocol.py | 4 +- test/test_utils.py | 14 ++-- 8 files changed, 83 insertions(+), 148 deletions(-) diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index 7252c446..b65a6be0 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -50,14 +50,14 @@ def _assemble_request_line(request, form=None): return b"%s %s %s" % ( request.method, request.path, - request.httpversion + request.http_version ) elif form == "authority": return b"%s %s:%d %s" % ( request.method, request.host, request.port, - request.httpversion + request.http_version ) elif form == "absolute": return b"%s %s://%s:%d%s %s" % ( @@ -66,7 +66,7 @@ def _assemble_request_line(request, form=None): request.host, request.port, request.path, - request.httpversion + request.http_version ) else: # pragma: nocover raise RuntimeError("Invalid request form") @@ -93,7 +93,7 @@ def _assemble_request_headers(request): def _assemble_response_line(response): return b"%s %d %s" % ( - response.httpversion, + response.http_version, response.status_code, response.msg, ) diff --git a/netlib/http/models.py b/netlib/http/models.py index b4446ecb..54b8b112 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -193,15 +193,45 @@ class Headers(MutableMapping, object): return cls([list(field) for field in state]) -class Request(object): +class Message(object): + def __init__(self, http_version, headers, body, timestamp_start, timestamp_end): + self.http_version = http_version + if not headers: + headers = Headers() + assert isinstance(headers, Headers) + self.headers = headers + + self._body = body + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + + @property + def body(self): + return self._body + + @body.setter + def body(self, body): + self._body = body + if isinstance(body, bytes): + self.headers[b"Content-Length"] = str(len(body)).encode() + + content = body + + def __eq__(self, other): + if isinstance(other, Message): + return self.__dict__ == other.__dict__ + return False + + +class Request(Message): # This list is adopted legacy code. # We probably don't need to strip off keep-alive. _headers_to_strip_off = [ - 'Proxy-Connection', - 'Keep-Alive', - 'Connection', - 'Transfer-Encoding', - 'Upgrade', + b'Proxy-Connection', + b'Keep-Alive', + b'Connection', + b'Transfer-Encoding', + b'Upgrade', ] def __init__( @@ -212,16 +242,14 @@ class Request(object): host, port, path, - httpversion, + http_version, headers=None, body=None, timestamp_start=None, timestamp_end=None, form_out=None ): - if not headers: - headers = Headers() - assert isinstance(headers, Headers) + super(Request, self).__init__(http_version, headers, body, timestamp_start, timestamp_end) self.form_in = form_in self.method = method @@ -229,23 +257,8 @@ class Request(object): self.host = host self.port = port self.path = path - self.httpversion = httpversion - self.headers = headers - self._body = body - self.timestamp_start = timestamp_start - self.timestamp_end = timestamp_end self.form_out = form_out or form_in - def __eq__(self, other): - try: - self_d = [self.__dict__[k] for k in self.__dict__ if - k not in ('timestamp_start', 'timestamp_end')] - other_d = [other.__dict__[k] for k in other.__dict__ if - k not in ('timestamp_start', 'timestamp_end')] - return self_d == other_d - except: - return False - def __repr__(self): if self.host and self.port: hostport = "{}:{}".format(self.host, self.port) @@ -262,8 +275,8 @@ class Request(object): response. That is, we remove ETags and If-Modified-Since headers. """ delheaders = [ - "if-modified-since", - "if-none-match", + b"if-modified-since", + b"if-none-match", ] for i in delheaders: self.headers.pop(i, None) @@ -273,16 +286,16 @@ class Request(object): Modifies this request to remove headers that will compress the resource's data. """ - self.headers["accept-encoding"] = "identity" + self.headers[b"accept-encoding"] = b"identity" def constrain_encoding(self): """ Limits the permissible Accept-Encoding values, based on what we can decode appropriately. """ - accept_encoding = self.headers.get("accept-encoding") + accept_encoding = self.headers.get(b"accept-encoding") if accept_encoding: - self.headers["accept-encoding"] = ( + self.headers[b"accept-encoding"] = ( ', '.join( e for e in encoding.ENCODINGS @@ -335,7 +348,7 @@ class Request(object): """ # FIXME: If there's an existing content-type header indicating a # url-encoded form, leave it alone. - self.headers["Content-Type"] = HDR_FORM_URLENCODED + self.headers[b"Content-Type"] = HDR_FORM_URLENCODED self.body = utils.urlencode(odict.lst) def get_path_components(self): @@ -452,37 +465,17 @@ class Request(object): raise ValueError("Invalid URL: %s" % url) self.scheme, self.host, self.port, self.path = parts - @property - def body(self): - return self._body - - @body.setter - def body(self, body): - self._body = body - if isinstance(body, bytes): - self.headers["Content-Length"] = str(len(body)).encode() - - @property - def content(self): # pragma: no cover - # TODO: remove deprecated getter - return self.body - - @content.setter - def content(self, content): # pragma: no cover - # TODO: remove deprecated setter - self.body = content - -class Response(object): +class Response(Message): _headers_to_strip_off = [ - 'Proxy-Connection', - 'Alternate-Protocol', - 'Alt-Svc', + b'Proxy-Connection', + b'Alternate-Protocol', + b'Alt-Svc', ] def __init__( self, - httpversion, + http_version, status_code, msg=None, headers=None, @@ -490,27 +483,9 @@ class Response(object): timestamp_start=None, timestamp_end=None, ): - if not headers: - headers = Headers() - assert isinstance(headers, Headers) - - self.httpversion = httpversion + super(Response, self).__init__(http_version, headers, body, timestamp_start, timestamp_end) self.status_code = status_code self.msg = msg - self.headers = headers - self._body = body - self.timestamp_start = timestamp_start - self.timestamp_end = timestamp_end - - def __eq__(self, other): - try: - self_d = [self.__dict__[k] for k in self.__dict__ if - k not in ('timestamp_start', 'timestamp_end')] - other_d = [other.__dict__[k] for k in other.__dict__ if - k not in ('timestamp_start', 'timestamp_end')] - return self_d == other_d - except: - return False def __repr__(self): # return "Response(%s - %s)" % (self.status_code, self.msg) @@ -536,7 +511,7 @@ class Response(object): attributes (e.g. HTTPOnly) are indicated by a Null value. """ ret = [] - for header in self.headers.get_all("set-cookie"): + for header in self.headers.get_all(b"set-cookie"): v = cookies.parse_set_cookie_header(header) if v: name, value, attrs = v @@ -559,34 +534,4 @@ class Response(object): i[1][1] ) ) - self.headers.set_all("Set-Cookie", values) - - @property - def body(self): - return self._body - - @body.setter - def body(self, body): - self._body = body - if isinstance(body, bytes): - self.headers["Content-Length"] = str(len(body)).encode() - - @property - def content(self): # pragma: no cover - # TODO: remove deprecated getter - return self.body - - @content.setter - def content(self, content): # pragma: no cover - # TODO: remove deprecated setter - self.body = content - - @property - def code(self): # pragma: no cover - # TODO: remove deprecated getter - return self.status_code - - @code.setter - def code(self, code): # pragma: no cover - # TODO: remove deprecated setter - self.status_code = code + self.headers.set_all(b"Set-Cookie", values) diff --git a/netlib/tutils.py b/netlib/tutils.py index 05791c49..b69495a3 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -105,7 +105,7 @@ def treq(**kwargs): host=b"address", port=22, path=b"/path", - httpversion=b"HTTP/1.1", + http_version=b"HTTP/1.1", headers=Headers(header=b"qvalue"), body=b"content" ) @@ -119,7 +119,7 @@ def tresp(**kwargs): netlib.http.Response """ default = dict( - httpversion=b"HTTP/1.1", + http_version=b"HTTP/1.1", status_code=200, msg=b"OK", headers=Headers(header_response=b"svalue"), diff --git a/netlib/utils.py b/netlib/utils.py index a86b8019..14b428d7 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -17,11 +17,6 @@ def isascii(bytes): return True -# best way to do it in python 2.x -def bytes_to_int(i): - return int(i.encode('hex'), 16) - - def clean_bin(s, keep_spacing=True): """ Cleans binary data to make it safe to display. @@ -51,21 +46,15 @@ def clean_bin(s, keep_spacing=True): def hexdump(s): """ - Returns a set of tuples: - (offset, hex, str) + Returns: + A generator of (offset, hex, str) tuples """ - parts = [] for i in range(0, len(s), 16): - o = "%.10x" % i + offset = b"%.10x" % i part = s[i:i + 16] - x = " ".join("%.2x" % ord(i) for i in part) - if len(part) < 16: - x += " " - x += " ".join(" " for i in range(16 - len(part))) - parts.append( - (o, x, clean_bin(part, False)) - ) - return parts + x = b" ".join(b"%.2x" % i for i in six.iterbytes(part)) + x = x.ljust(47) # 16*2 + 15 + yield (offset, x, clean_bin(part, False)) def setbit(byte, offset, value): @@ -80,8 +69,7 @@ def setbit(byte, offset, value): def getbit(byte, offset): mask = 1 << offset - if byte & mask: - return True + return bool(byte & mask) class BiDi(object): @@ -159,7 +147,7 @@ def is_valid_host(host): return False if len(host) > 255: return False - if host[-1] == ".": + if host[-1] == b".": host = host[:-1] return all(_label_valid.match(x) for x in host.split(b".")) @@ -248,7 +236,7 @@ def hostport(scheme, host, port): """ Returns the host component, with a port specifcation if needed. """ - if (port, scheme) in [(80, "http"), (443, "https")]: + if (port, scheme) in [(80, b"http"), (443, b"https")]: return host else: return b"%s:%d" % (host, port) diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index e3ff1405..ceddd273 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -2,6 +2,7 @@ from __future__ import absolute_import import os import struct import io +import six from .protocol import Masker from netlib import tcp @@ -127,8 +128,8 @@ class FrameHeader(object): """ read a websockets frame header """ - first_byte = utils.bytes_to_int(fp.safe_read(1)) - second_byte = utils.bytes_to_int(fp.safe_read(1)) + first_byte = six.byte2int(fp.safe_read(1)) + second_byte = six.byte2int(fp.safe_read(1)) fin = utils.getbit(first_byte, 7) rsv1 = utils.getbit(first_byte, 6) @@ -145,9 +146,9 @@ class FrameHeader(object): if length_code <= 125: payload_length = length_code elif length_code == 126: - payload_length = utils.bytes_to_int(fp.safe_read(2)) + payload_length, = struct.unpack("!H", fp.safe_read(2)) elif length_code == 127: - payload_length = utils.bytes_to_int(fp.safe_read(8)) + payload_length, = struct.unpack("!Q", fp.safe_read(8)) # masking key only present if mask bit set if mask_bit == 1: diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index 46c02875..68d827a5 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -17,6 +17,7 @@ from __future__ import absolute_import import base64 import hashlib import os +import six from ..http import Headers from .. import utils @@ -40,7 +41,7 @@ class Masker(object): def __init__(self, key): self.key = key - self.masks = [utils.bytes_to_int(byte) for byte in key] + self.masks = [six.byte2int(byte) for byte in key] self.offset = 0 def mask(self, offset, data): diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 598b5cd7..a55941e0 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -413,7 +413,7 @@ class TestReadResponse(tservers.ServerTestBase): resp = protocol.read_response(NotImplemented, stream_id=42) - assert resp.httpversion == (2, 0) + assert resp.http_version == (2, 0) assert resp.status_code == 200 assert resp.msg == "" assert resp.headers.fields == [[':status', '200'], ['etag', 'foobar']] @@ -440,7 +440,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): resp = protocol.read_response(NotImplemented, stream_id=42) assert resp.stream_id == 42 - assert resp.httpversion == (2, 0) + assert resp.http_version == (2, 0) assert resp.status_code == 200 assert resp.msg == "" assert resp.headers.fields == [[':status', '200'], ['etag', 'foobar']] diff --git a/test/test_utils.py b/test/test_utils.py index eb7aa31a..0db75578 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -103,11 +103,11 @@ def test_get_header_tokens(): headers = Headers() assert utils.get_header_tokens(headers, "foo") == [] headers["foo"] = "bar" - assert utils.get_header_tokens(headers, "foo") == ["bar"] + assert utils.get_header_tokens(headers, "foo") == [b"bar"] headers["foo"] = "bar, voing" - assert utils.get_header_tokens(headers, "foo") == ["bar", "voing"] + assert utils.get_header_tokens(headers, "foo") == [b"bar", b"voing"] headers.set_all("foo", ["bar, voing", "oink"]) - assert utils.get_header_tokens(headers, "foo") == ["bar", "voing", "oink"] + assert utils.get_header_tokens(headers, "foo") == [b"bar", b"voing", b"oink"] def test_multipartdecode(): @@ -134,8 +134,8 @@ def test_multipartdecode(): def test_parse_content_type(): p = utils.parse_content_type - assert p("text/html") == ("text", "html", {}) - assert p("text") is None + assert p(b"text/html") == (b"text", b"html", {}) + assert p(b"text") is None - v = p("text/html; charset=UTF-8") - assert v == ('text', 'html', {'charset': 'UTF-8'}) + v = p(b"text/html; charset=UTF-8") + assert v == (b'text', b'html', {b'charset': b'UTF-8'}) -- cgit v1.2.3 From d798ed955dab4681a5285024b3648b1a3f13c24e Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 17 Sep 2015 16:31:50 +0200 Subject: python3++ --- .travis.yml | 6 +++++- netlib/encoding.py | 20 +++++++++++------- netlib/http/models.py | 48 +++++++++++++++++++++--------------------- netlib/odict.py | 25 +++------------------- netlib/tutils.py | 4 +--- netlib/utils.py | 22 +++++++++---------- test/http/test_models.py | 8 +++---- test/test_encoding.py | 10 +++++---- test/test_odict.py | 40 ++++++----------------------------- test/test_socks.py | 55 ++++++++++++++++++++++++------------------------ test/test_utils.py | 10 ++++----- 11 files changed, 105 insertions(+), 143 deletions(-) diff --git a/.travis.yml b/.travis.yml index fa997542..7e18176c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,7 +16,11 @@ matrix: packages: - libssl-dev - python: 3.5 - script: "nosetests --with-cov --cov-report term-missing test/http/http1" + script: + - nosetests --with-cov --cov-report term-missing test/http/http1 + - nosetests --with-cov --cov-report term-missing test/test_utils.py + - nosetests --with-cov --cov-report term-missing test/test_encoding.py + - nosetests --with-cov --cov-report term-missing test/test_odict.py - python: pypy - python: pypy env: OPENSSL=1.0.2 diff --git a/netlib/encoding.py b/netlib/encoding.py index 06830f2c..8ac59905 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -5,28 +5,30 @@ from __future__ import absolute_import from io import BytesIO import gzip import zlib +from .utils import always_byte_args -__ALL__ = ["ENCODINGS"] -ENCODINGS = {"identity", "gzip", "deflate"} +ENCODINGS = {b"identity", b"gzip", b"deflate"} +@always_byte_args("ascii", "ignore") def decode(e, content): encoding_map = { - "identity": identity, - "gzip": decode_gzip, - "deflate": decode_deflate, + b"identity": identity, + b"gzip": decode_gzip, + b"deflate": decode_deflate, } if e not in encoding_map: return None return encoding_map[e](content) +@always_byte_args("ascii", "ignore") def encode(e, content): encoding_map = { - "identity": identity, - "gzip": encode_gzip, - "deflate": encode_deflate, + b"identity": identity, + b"gzip": encode_gzip, + b"deflate": encode_deflate, } if e not in encoding_map: return None @@ -80,3 +82,5 @@ def encode_deflate(content): Returns compressed content, always including zlib header and checksum. """ return zlib.compress(content) + +__all__ = ["ENCODINGS", "encode", "decode"] diff --git a/netlib/http/models.py b/netlib/http/models.py index 54b8b112..bc681de3 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -136,7 +136,7 @@ class Headers(MutableMapping, object): def __len__(self): return len(set(name.lower() for name, _ in self.fields)) - #__hash__ = object.__hash__ + # __hash__ = object.__hash__ def _index(self, name): name = name.lower() @@ -227,11 +227,11 @@ class Request(Message): # This list is adopted legacy code. # We probably don't need to strip off keep-alive. _headers_to_strip_off = [ - b'Proxy-Connection', - b'Keep-Alive', - b'Connection', - b'Transfer-Encoding', - b'Upgrade', + 'Proxy-Connection', + 'Keep-Alive', + 'Connection', + 'Transfer-Encoding', + 'Upgrade', ] def __init__( @@ -275,8 +275,8 @@ class Request(Message): response. That is, we remove ETags and If-Modified-Since headers. """ delheaders = [ - b"if-modified-since", - b"if-none-match", + b"If-Modified-Since", + b"If-None-Match", ] for i in delheaders: self.headers.pop(i, None) @@ -286,16 +286,16 @@ class Request(Message): Modifies this request to remove headers that will compress the resource's data. """ - self.headers[b"accept-encoding"] = b"identity" + self.headers["Accept-Encoding"] = b"identity" def constrain_encoding(self): """ Limits the permissible Accept-Encoding values, based on what we can decode appropriately. """ - accept_encoding = self.headers.get(b"accept-encoding") + accept_encoding = self.headers.get(b"Accept-Encoding") if accept_encoding: - self.headers[b"accept-encoding"] = ( + self.headers["Accept-Encoding"] = ( ', '.join( e for e in encoding.ENCODINGS @@ -316,9 +316,9 @@ class Request(Message): indicates non-form data. """ if self.body: - if HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): + if HDR_FORM_URLENCODED in self.headers.get("Content-Type", "").lower(): return self.get_form_urlencoded() - elif HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): + elif HDR_FORM_MULTIPART in self.headers.get("Content-Type", "").lower(): return self.get_form_multipart() return ODict([]) @@ -328,12 +328,12 @@ class Request(Message): Returns an empty ODict if there is no data or the content-type indicates non-form data. """ - if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): + if self.body and HDR_FORM_URLENCODED in self.headers.get("Content-Type", "").lower(): return ODict(utils.urldecode(self.body)) return ODict([]) def get_form_multipart(self): - if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): + if self.body and HDR_FORM_MULTIPART in self.headers.get("Content-Type", "").lower(): return ODict( utils.multipartdecode( self.headers, @@ -405,9 +405,9 @@ class Request(Message): but not the resolved name. This is disabled by default, as an attacker may spoof the host header to confuse an analyst. """ - if hostheader and b"Host" in self.headers: + if hostheader and "Host" in self.headers: try: - return self.headers[b"Host"].decode("idna") + return self.headers["Host"].decode("idna") except ValueError: pass if self.host: @@ -426,7 +426,7 @@ class Request(Message): Returns a possibly empty netlib.odict.ODict object. """ ret = ODict() - for i in self.headers.get_all("cookie"): + for i in self.headers.get_all("Cookie"): ret.extend(cookies.parse_cookie_header(i)) return ret @@ -468,9 +468,9 @@ class Request(Message): class Response(Message): _headers_to_strip_off = [ - b'Proxy-Connection', - b'Alternate-Protocol', - b'Alt-Svc', + 'Proxy-Connection', + 'Alternate-Protocol', + 'Alt-Svc', ] def __init__( @@ -498,7 +498,7 @@ class Response(Message): return "".format( status_code=self.status_code, msg=self.msg, - contenttype=self.headers.get("content-type", "unknown content type"), + contenttype=self.headers.get("Content-Type", "unknown content type"), size=size) def get_cookies(self): @@ -511,7 +511,7 @@ class Response(Message): attributes (e.g. HTTPOnly) are indicated by a Null value. """ ret = [] - for header in self.headers.get_all(b"set-cookie"): + for header in self.headers.get_all("Set-Cookie"): v = cookies.parse_set_cookie_header(header) if v: name, value, attrs = v @@ -534,4 +534,4 @@ class Response(Message): i[1][1] ) ) - self.headers.set_all(b"Set-Cookie", values) + self.headers.set_all("Set-Cookie", values) diff --git a/netlib/odict.py b/netlib/odict.py index 11d5d52a..1124b23a 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,6 +1,7 @@ from __future__ import (absolute_import, print_function, division) import re import copy +import six def safe_subn(pattern, repl, target, *args, **kwargs): @@ -67,10 +68,10 @@ class ODict(object): Sets the values for key k. If there are existing values for this key, they are cleared. """ - if isinstance(valuelist, basestring): + if isinstance(valuelist, six.text_type) or isinstance(valuelist, six.binary_type): raise ValueError( "Expected list of values instead of string. " - "Example: odict['Host'] = ['www.example.com']" + "Example: odict[b'Host'] = [b'www.example.com']" ) kc = self._kconv(k) new = [] @@ -134,13 +135,6 @@ class ODict(object): def __repr__(self): return repr(self.lst) - def format(self): - elements = [] - for itm in self.lst: - elements.append(itm[0] + ": " + str(itm[1])) - elements.append("") - return "\r\n".join(elements) - def in_any(self, key, value, caseless=False): """ Do any of the values matching key contain value? @@ -156,19 +150,6 @@ class ODict(object): return True return False - def match_re(self, expr): - """ - Match the regular expression against each (key, value) pair. For - each pair a string of the following format is matched against: - - "key: value" - """ - for k, v in self.lst: - s = "%s: %s" % (k, v) - if re.search(expr, s): - return True - return False - def replace(self, pattern, repl, *args, **kwargs): """ Replaces a regular expression pattern with repl in both keys and diff --git a/netlib/tutils.py b/netlib/tutils.py index b69495a3..746e1488 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -123,9 +123,7 @@ def tresp(**kwargs): status_code=200, msg=b"OK", headers=Headers(header_response=b"svalue"), - body=b"message", - timestamp_start=time.time(), - timestamp_end=time.time() + body=b"message" ) default.update(kwargs) return Response(**default) diff --git a/netlib/utils.py b/netlib/utils.py index 14b428d7..6fed44b6 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -246,7 +246,7 @@ def unparse_url(scheme, host, port, path=""): """ Returns a URL string, constructed from the specified compnents. """ - return "%s://%s%s" % (scheme, hostport(scheme, host, port), path) + return b"%s://%s%s" % (scheme, hostport(scheme, host, port), path) def urlencode(s): @@ -295,7 +295,7 @@ def multipartdecode(headers, content): """ Takes a multipart boundary encoded string and returns list of (key, value) tuples. """ - v = headers.get("content-type") + v = headers.get(b"Content-Type") if v: v = parse_content_type(v) if not v: @@ -304,33 +304,33 @@ def multipartdecode(headers, content): if not boundary: return [] - rx = re.compile(r'\bname="([^"]+)"') + rx = re.compile(br'\bname="([^"]+)"') r = [] - for i in content.split("--" + boundary): + for i in content.split(b"--" + boundary): parts = i.splitlines() - if len(parts) > 1 and parts[0][0:2] != "--": + if len(parts) > 1 and parts[0][0:2] != b"--": match = rx.search(parts[1]) if match: key = match.group(1) - value = "".join(parts[3 + parts[2:].index(""):]) + value = b"".join(parts[3 + parts[2:].index(b""):]) r.append((key, value)) return r return [] -def always_bytes(unicode_or_bytes, encoding): +def always_bytes(unicode_or_bytes, *encode_args): if isinstance(unicode_or_bytes, six.text_type): - return unicode_or_bytes.encode(encoding) + return unicode_or_bytes.encode(*encode_args) return unicode_or_bytes -def always_byte_args(encoding): +def always_byte_args(*encode_args): """Decorator that transparently encodes all arguments passed as unicode""" def decorator(fun): def _fun(*args, **kwargs): - args = [always_bytes(arg, encoding) for arg in args] - kwargs = {k: always_bytes(v, encoding) for k, v in six.iteritems(kwargs)} + args = [always_bytes(arg, *encode_args) for arg in args] + kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} return fun(*args, **kwargs) return _fun return decorator diff --git a/test/http/test_models.py b/test/http/test_models.py index 8fce2e9d..c3ab4d0f 100644 --- a/test/http/test_models.py +++ b/test/http/test_models.py @@ -36,8 +36,8 @@ class TestRequest(object): assert isinstance(req.headers, Headers) def test_equal(self): - a = tutils.treq() - b = tutils.treq() + a = tutils.treq(timestamp_start=42, timestamp_end=43) + b = tutils.treq(timestamp_start=42, timestamp_end=43) assert a == b assert not a == 'foo' @@ -319,8 +319,8 @@ class TestResponse(object): assert isinstance(resp.headers, Headers) def test_equal(self): - a = tutils.tresp() - b = tutils.tresp() + a = tutils.tresp(timestamp_start=42, timestamp_end=43) + b = tutils.tresp(timestamp_start=42, timestamp_end=43) assert a == b assert not a == 'foo' diff --git a/test/test_encoding.py b/test/test_encoding.py index 9da3a38d..90f99338 100644 --- a/test/test_encoding.py +++ b/test/test_encoding.py @@ -2,10 +2,12 @@ from netlib import encoding def test_identity(): - assert "string" == encoding.decode("identity", "string") - assert "string" == encoding.encode("identity", "string") - assert not encoding.encode("nonexistent", "string") - assert None == encoding.decode("nonexistent encoding", "string") + assert b"string" == encoding.decode("identity", b"string") + assert b"string" == encoding.encode("identity", b"string") + assert b"string" == encoding.encode(b"identity", b"string") + assert b"string" == encoding.decode(b"identity", b"string") + assert not encoding.encode("nonexistent", b"string") + assert not encoding.decode("nonexistent encoding", b"string") def test_gzip(): diff --git a/test/test_odict.py b/test/test_odict.py index be3d862d..962c0daa 100644 --- a/test/test_odict.py +++ b/test/test_odict.py @@ -1,7 +1,7 @@ from netlib import odict, tutils -class TestODict: +class TestODict(object): def setUp(self): self.od = odict.ODict() @@ -13,21 +13,10 @@ class TestODict: def test_str_err(self): h = odict.ODict() - tutils.raises(ValueError, h.__setitem__, "key", "foo") - - def test_dictToHeader1(self): - self.od.add("one", "uno") - self.od.add("two", "due") - self.od.add("two", "tre") - expected = [ - "one: uno\r\n", - "two: due\r\n", - "two: tre\r\n", - "\r\n" - ] - out = self.od.format() - for i in expected: - assert out.find(i) >= 0 + with tutils.raises(ValueError): + h["key"] = u"foo" + with tutils.raises(ValueError): + h["key"] = b"foo" def test_getset_state(self): self.od.add("foo", 1) @@ -40,23 +29,6 @@ class TestODict: b.load_state(state) assert b == self.od - def test_dictToHeader2(self): - self.od["one"] = ["uno"] - expected1 = "one: uno\r\n" - expected2 = "\r\n" - out = self.od.format() - assert out.find(expected1) >= 0 - assert out.find(expected2) >= 0 - - def test_match_re(self): - h = odict.ODict() - h.add("one", "uno") - h.add("two", "due") - h.add("two", "tre") - assert h.match_re("uno") - assert h.match_re("two: due") - assert not h.match_re("nonono") - def test_in_any(self): self.od["one"] = ["atwoa", "athreea"] assert self.od.in_any("one", "two") @@ -122,7 +94,7 @@ class TestODict: assert a["a"] == ["b", "b"] -class TestODictCaseless: +class TestODictCaseless(object): def setUp(self): self.od = odict.ODictCaseless() diff --git a/test/test_socks.py b/test/test_socks.py index 3d109f42..65a0f0eb 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -1,12 +1,12 @@ -from cStringIO import StringIO +from io import BytesIO import socket from nose.plugins.skip import SkipTest from netlib import socks, tcp, tutils def test_client_greeting(): - raw = tutils.treader("\x05\x02\x00\xBE\xEF") - out = StringIO() + raw = tutils.treader(b"\x05\x02\x00\xBE\xEF") + out = BytesIO() msg = socks.ClientGreeting.from_file(raw) msg.assert_socks5() msg.to_file(out) @@ -19,11 +19,11 @@ def test_client_greeting(): def test_client_greeting_assert_socks5(): - raw = tutils.treader("\x00\x00") + raw = tutils.treader(b"\x00\x00") msg = socks.ClientGreeting.from_file(raw) tutils.raises(socks.SocksError, msg.assert_socks5) - raw = tutils.treader("HTTP/1.1 200 OK" + " " * 100) + raw = tutils.treader(b"HTTP/1.1 200 OK" + " " * 100) msg = socks.ClientGreeting.from_file(raw) try: msg.assert_socks5() @@ -33,7 +33,7 @@ def test_client_greeting_assert_socks5(): else: assert False - raw = tutils.treader("GET / HTTP/1.1" + " " * 100) + raw = tutils.treader(b"GET / HTTP/1.1" + " " * 100) msg = socks.ClientGreeting.from_file(raw) try: msg.assert_socks5() @@ -43,7 +43,7 @@ def test_client_greeting_assert_socks5(): else: assert False - raw = tutils.treader("XX") + raw = tutils.treader(b"XX") tutils.raises( socks.SocksError, socks.ClientGreeting.from_file, @@ -52,8 +52,8 @@ def test_client_greeting_assert_socks5(): def test_server_greeting(): - raw = tutils.treader("\x05\x02") - out = StringIO() + raw = tutils.treader(b"\x05\x02") + out = BytesIO() msg = socks.ServerGreeting.from_file(raw) msg.assert_socks5() msg.to_file(out) @@ -64,7 +64,7 @@ def test_server_greeting(): def test_server_greeting_assert_socks5(): - raw = tutils.treader("HTTP/1.1 200 OK" + " " * 100) + raw = tutils.treader(b"HTTP/1.1 200 OK" + " " * 100) msg = socks.ServerGreeting.from_file(raw) try: msg.assert_socks5() @@ -74,7 +74,7 @@ def test_server_greeting_assert_socks5(): else: assert False - raw = tutils.treader("GET / HTTP/1.1" + " " * 100) + raw = tutils.treader(b"GET / HTTP/1.1" + " " * 100) msg = socks.ServerGreeting.from_file(raw) try: msg.assert_socks5() @@ -86,36 +86,37 @@ def test_server_greeting_assert_socks5(): def test_message(): - raw = tutils.treader("\x05\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") - out = StringIO() + raw = tutils.treader(b"\x05\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") + out = BytesIO() msg = socks.Message.from_file(raw) msg.assert_socks5() - assert raw.read(2) == "\xBE\xEF" + assert raw.read(2) == b"\xBE\xEF" msg.to_file(out) assert out.getvalue() == raw.getvalue()[:-2] assert msg.ver == 5 assert msg.msg == 0x01 assert msg.atyp == 0x03 - assert msg.addr == ("example.com", 0xDEAD) + assert msg.addr == (b"example.com", 0xDEAD) def test_message_assert_socks5(): - raw = tutils.treader("\xEE\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") + raw = tutils.treader(b"\xEE\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") msg = socks.Message.from_file(raw) tutils.raises(socks.SocksError, msg.assert_socks5) def test_message_ipv4(): # Test ATYP=0x01 (IPV4) - raw = tutils.treader("\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") - out = StringIO() + raw = tutils.treader(b"\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") + out = BytesIO() msg = socks.Message.from_file(raw) - assert raw.read(2) == "\xBE\xEF" + left = raw.read(2) + assert left == b"\xBE\xEF" msg.to_file(out) assert out.getvalue() == raw.getvalue()[:-2] - assert msg.addr == ("127.0.0.1", 0xDEAD) + assert msg.addr == (b"127.0.0.1", 0xDEAD) def test_message_ipv6(): @@ -125,14 +126,14 @@ def test_message_ipv6(): ipv6_addr = "2001:db8:85a3:8d3:1319:8a2e:370:7344" raw = tutils.treader( - "\x05\x01\x00\x04" + + b"\x05\x01\x00\x04" + socket.inet_pton( socket.AF_INET6, ipv6_addr) + - "\xDE\xAD\xBE\xEF") - out = StringIO() + b"\xDE\xAD\xBE\xEF") + out = BytesIO() msg = socks.Message.from_file(raw) - assert raw.read(2) == "\xBE\xEF" + assert raw.read(2) == b"\xBE\xEF" msg.to_file(out) assert out.getvalue() == raw.getvalue()[:-2] @@ -140,13 +141,13 @@ def test_message_ipv6(): def test_message_invalid_rsv(): - raw = tutils.treader("\x05\x01\xFF\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") + raw = tutils.treader(b"\x05\x01\xFF\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") tutils.raises(socks.SocksError, socks.Message.from_file, raw) def test_message_unknown_atyp(): - raw = tutils.treader("\x05\x02\x00\x02\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") + raw = tutils.treader(b"\x05\x02\x00\x02\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") tutils.raises(socks.SocksError, socks.Message.from_file, raw) m = socks.Message(5, 1, 0x02, tcp.Address(("example.com", 5050))) - tutils.raises(socks.SocksError, m.to_file, StringIO()) + tutils.raises(socks.SocksError, m.to_file, BytesIO()) diff --git a/test/test_utils.py b/test/test_utils.py index 0db75578..ff27486c 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -84,10 +84,10 @@ def test_parse_url(): def test_unparse_url(): - assert utils.unparse_url("http", "foo.com", 99, "") == "http://foo.com:99" - assert utils.unparse_url("http", "foo.com", 80, "") == "http://foo.com" - assert utils.unparse_url("https", "foo.com", 80, "") == "https://foo.com:80" - assert utils.unparse_url("https", "foo.com", 443, "") == "https://foo.com" + assert utils.unparse_url(b"http", b"foo.com", 99, b"") == b"http://foo.com:99" + assert utils.unparse_url(b"http", b"foo.com", 80, b"/bar") == b"http://foo.com/bar" + assert utils.unparse_url(b"https", b"foo.com", 80, b"") == b"https://foo.com:80" + assert utils.unparse_url(b"https", b"foo.com", 443, b"") == b"https://foo.com" def test_urlencode(): @@ -122,7 +122,7 @@ def test_multipartdecode(): "--{0}\n" "Content-Disposition: form-data; name=\"field2\"\n\n" "value2\n" - "--{0}--".format(boundary).encode("ascii") + "--{0}--".format(boundary.decode()).encode() ) form = utils.multipartdecode(headers, content) -- cgit v1.2.3 From 266b80238db34cfa91f9018c951394492bbde593 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 17 Sep 2015 17:29:55 +0200 Subject: fix tests --- netlib/tutils.py | 4 +++- test/test_utils.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/netlib/tutils.py b/netlib/tutils.py index 746e1488..4903d63b 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -123,7 +123,9 @@ def tresp(**kwargs): status_code=200, msg=b"OK", headers=Headers(header_response=b"svalue"), - body=b"message" + body=b"message", + timestamp_start=time.time(), + timestamp_end=time.time(), ) default.update(kwargs) return Response(**default) diff --git a/test/test_utils.py b/test/test_utils.py index ff27486c..fb7d357a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -11,7 +11,7 @@ def test_bidi(): def test_hexdump(): - assert utils.hexdump("one\0" * 10) + assert list(utils.hexdump("one\0" * 10)) def test_clean_bin(): -- cgit v1.2.3 From f2c87cff8adc8099ef8c3a85adf314e303c475b7 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 17 Sep 2015 17:32:59 +0200 Subject: fix py3 tests --- test/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_utils.py b/test/test_utils.py index fb7d357a..8f4b4059 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -11,7 +11,7 @@ def test_bidi(): def test_hexdump(): - assert list(utils.hexdump("one\0" * 10)) + assert list(utils.hexdump(b"one\0" * 10)) def test_clean_bin(): -- cgit v1.2.3 From 7b6b15754754b45552d0872d36f3f30f5fa1a783 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 18 Sep 2015 15:35:02 +0200 Subject: properly handle SNI IPs fixes mitmproxy/mitmproxy#772 We must use the ipaddress package here, because that's what cryptography uses. If we opt for something else, we have nasty namespace conflicts. --- netlib/certutils.py | 11 +++++++++-- setup.py | 23 ++++++++++++++--------- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index cc143a50..c3b795ac 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -4,6 +4,7 @@ import ssl import time import datetime import itertools +import ipaddress from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error @@ -85,8 +86,13 @@ def dummy_cert(privkey, cacert, commonname, sans): """ ss = [] for i in sans: - ss.append("DNS: %s" % i) - ss = ", ".join(ss) + try: + ipaddress.ip_address(i.decode("ascii")) + except ValueError: + ss.append(b"DNS: %s" % i) + else: + ss.append(b"IP: %s" % i) + ss = b", ".join(ss) cert = OpenSSL.crypto.X509() cert.gmtime_adj_notBefore(-3600 * 48) @@ -335,6 +341,7 @@ class CertStore(object): class _GeneralName(univ.Choice): # We are only interested in dNSNames. We use a default handler to ignore # other types. + # TODO: We should also handle iPAddresses. componentType = namedtype.NamedTypes( namedtype.NamedType('dNSName', char.IA5String().subtype( implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) diff --git a/setup.py b/setup.py index d3c09ceb..0c9fb07b 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ from setuptools import setup, find_packages from codecs import open import os +import sys from netlib import version @@ -13,6 +14,18 @@ here = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(here, 'README.mkd'), encoding='utf-8') as f: long_description = f.read() +deps = { + "pyasn1>=0.1.7", + "pyOpenSSL>=0.15.1", + "cryptography>=1.0", + "passlib>=1.6.2", + "hpack>=1.0.1", + "six>=1.9.0", + "certifi>=2015.9.6.2", +} +if sys.version_info < (3, 0): + deps.add("ipaddress>=1.0.14") + setup( name="netlib", version=version.VERSION, @@ -40,15 +53,7 @@ setup( packages=find_packages(), include_package_data=True, zip_safe=False, - install_requires=[ - "pyasn1>=0.1.7", - "pyOpenSSL>=0.15.1", - "cryptography>=1.0", - "passlib>=1.6.2", - "hpack>=1.0.1", - "six>=1.9.0", - "certifi" - ], + install_requires=list(deps), extras_require={ 'dev': [ "mock>=1.0.1", -- cgit v1.2.3 From d1904c2f52dfc7409ae275bb081f23635c94acc9 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 18 Sep 2015 15:38:31 +0200 Subject: python3++ --- netlib/certutils.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index c3b795ac..9193b757 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -12,7 +12,7 @@ import OpenSSL DEFAULT_EXP = 157680000 # = 24 * 60 * 60 * 365 * 5 # Generated with "openssl dhparam". It's too slow to generate this on startup. -DEFAULT_DHPARAM = """ +DEFAULT_DHPARAM = b""" -----BEGIN DH PARAMETERS----- MIICCAKCAgEAyT6LzpwVFS3gryIo29J5icvgxCnCebcdSe/NHMkD8dKJf8suFCg3 O2+dguLakSVif/t6dhImxInJk230HmfC8q93hdcg/j8rLGJYDKu3ik6H//BAHKIv @@ -43,29 +43,29 @@ def create_ca(o, cn, exp): cert.set_pubkey(key) cert.add_extensions([ OpenSSL.crypto.X509Extension( - "basicConstraints", + b"basicConstraints", True, - "CA:TRUE" + b"CA:TRUE" ), OpenSSL.crypto.X509Extension( - "nsCertType", + b"nsCertType", False, - "sslCA" + b"sslCA" ), OpenSSL.crypto.X509Extension( - "extendedKeyUsage", + b"extendedKeyUsage", False, - "serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" + b"serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" ), OpenSSL.crypto.X509Extension( - "keyUsage", + b"keyUsage", True, - "keyCertSign, cRLSign" + b"keyCertSign, cRLSign" ), OpenSSL.crypto.X509Extension( - "subjectKeyIdentifier", + b"subjectKeyIdentifier", False, - "hash", + b"hash", subject=cert ), ]) @@ -103,7 +103,7 @@ def dummy_cert(privkey, cacert, commonname, sans): if ss: cert.set_version(2) cert.add_extensions( - [OpenSSL.crypto.X509Extension("subjectAltName", False, ss)]) + [OpenSSL.crypto.X509Extension(b"subjectAltName", False, ss)]) cert.set_pubkey(cacert.get_pubkey()) cert.sign(privkey, "sha256") return SSLCert(cert) @@ -291,14 +291,14 @@ class CertStore(object): @staticmethod def asterisk_forms(dn): - parts = dn.split(".") + parts = dn.split(b".") parts.reverse() - curr_dn = "" - dn_forms = ["*"] + curr_dn = b"" + dn_forms = [b"*"] for part in parts[:-1]: - curr_dn = "." + part + curr_dn # .example.com - dn_forms.append("*" + curr_dn) # *.example.com - if parts[-1] != "*": + curr_dn = b"." + part + curr_dn # .example.com + dn_forms.append(b"*" + curr_dn) # *.example.com + if parts[-1] != b"*": dn_forms.append(parts[-1] + curr_dn) return dn_forms @@ -430,7 +430,7 @@ class SSLCert(object): def cn(self): c = None for i in self.subject: - if i[0] == "CN": + if i[0] == b"CN": c = i[1] return c @@ -439,7 +439,7 @@ class SSLCert(object): altnames = [] for i in range(self.x509.get_extension_count()): ext = self.x509.get_extension(i) - if ext.get_short_name() == "subjectAltName": + if ext.get_short_name() == b"subjectAltName": try: dec = decode(ext.get_data(), asn1Spec=_GeneralNames()) except PyAsn1Error: -- cgit v1.2.3 From 551d9f11e571eac495674f1c23cfd0dfa8af2cb7 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 18 Sep 2015 18:05:50 +0200 Subject: experimental: don't interfere with headers --- netlib/http/http1/assemble.py | 20 +++++--------------- netlib/http/models.py | 21 ++++----------------- test/http/http1/test_assemble.py | 11 +---------- test/http/test_models.py | 11 +++++++---- 4 files changed, 17 insertions(+), 46 deletions(-) diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index b65a6be0..c2b60a0f 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -10,7 +10,8 @@ def assemble_request(request): if request.body == CONTENT_MISSING: raise HttpException("Cannot assemble flow with CONTENT_MISSING") head = assemble_request_head(request) - return head + request.body + body = b"".join(assemble_body(request.headers, [request.body])) + return head + body def assemble_request_head(request): @@ -23,7 +24,8 @@ def assemble_response(response): if response.body == CONTENT_MISSING: raise HttpException("Cannot assemble flow with CONTENT_MISSING") head = assemble_response_head(response) - return head + response.body + body = b"".join(assemble_body(response.headers, [response.body])) + return head + body def assemble_response_head(response): @@ -74,20 +76,12 @@ def _assemble_request_line(request, form=None): def _assemble_request_headers(request): headers = request.headers.copy() - for k in request._headers_to_strip_off: - headers.pop(k, None) if b"host" not in headers and request.scheme and request.host and request.port: headers[b"Host"] = utils.hostport( request.scheme, request.host, request.port ) - - # If content is defined (i.e. not None or CONTENT_MISSING), we always - # add a content-length header. - if request.body or request.body == b"": - headers[b"Content-Length"] = str(len(request.body)).encode("ascii") - return bytes(headers) @@ -100,8 +94,4 @@ def _assemble_response_line(response): def _assemble_response_headers(response): - headers = response.headers.copy() - for k in response._headers_to_strip_off: - headers.pop(k, None) - - return bytes(headers) + return bytes(response.headers) diff --git a/netlib/http/models.py b/netlib/http/models.py index bc681de3..ff854b13 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -92,7 +92,10 @@ class Headers(MutableMapping, object): self.update(headers) def __bytes__(self): - return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" + if self.fields: + return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" + else: + return b"" if six.PY2: __str__ = __bytes__ @@ -224,16 +227,6 @@ class Message(object): class Request(Message): - # This list is adopted legacy code. - # We probably don't need to strip off keep-alive. - _headers_to_strip_off = [ - 'Proxy-Connection', - 'Keep-Alive', - 'Connection', - 'Transfer-Encoding', - 'Upgrade', - ] - def __init__( self, form_in, @@ -467,12 +460,6 @@ class Request(Message): class Response(Message): - _headers_to_strip_off = [ - 'Proxy-Connection', - 'Alternate-Protocol', - 'Alt-Svc', - ] - def __init__( self, http_version, diff --git a/test/http/http1/test_assemble.py b/test/http/http1/test_assemble.py index cdc8bda9..2d250909 100644 --- a/test/http/http1/test_assemble.py +++ b/test/http/http1/test_assemble.py @@ -79,15 +79,10 @@ def test_assemble_request_headers(): r = treq(body=b"") r.headers[b"Transfer-Encoding"] = b"chunked" c = _assemble_request_headers(r) - assert b"Content-Length" in c - assert b"Transfer-Encoding" not in c + assert b"Transfer-Encoding" in c assert b"Host" in _assemble_request_headers(treq(headers=Headers())) - assert b"Proxy-Connection" not in _assemble_request_headers( - treq(headers=Headers(Proxy_Connection="42")) - ) - def test_assemble_response_headers(): # https://github.com/mitmproxy/mitmproxy/issues/186 @@ -95,7 +90,3 @@ def test_assemble_response_headers(): r.headers["Transfer-Encoding"] = b"chunked" c = _assemble_response_headers(r) assert b"Transfer-Encoding" in c - - assert b"Proxy-Connection" not in _assemble_response_headers( - tresp(headers=Headers(Proxy_Connection=b"42")) - ) diff --git a/test/http/test_models.py b/test/http/test_models.py index c3ab4d0f..6970a6e4 100644 --- a/test/http/test_models.py +++ b/test/http/test_models.py @@ -442,13 +442,16 @@ class TestHeaders(object): def test_str(self): headers = Headers(Host="example.com") - assert bytes(headers) == "Host: example.com\r\n" + assert bytes(headers) == b"Host: example.com\r\n" headers = Headers([ - ["Host", "example.com"], - ["Accept", "text/plain"] + [b"Host", b"example.com"], + [b"Accept", b"text/plain"] ]) - assert str(headers) == "Host: example.com\r\nAccept: text/plain\r\n" + assert bytes(headers) == b"Host: example.com\r\nAccept: text/plain\r\n" + + headers = Headers() + assert bytes(headers) == b"" def test_setitem(self): headers = Headers() -- cgit v1.2.3 From 91cdd78201497e89b9a17275a484d461f0143137 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 19 Sep 2015 11:59:40 +0200 Subject: improve http error messages --- netlib/http/http1/read.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 7f2b7bab..c6760ff3 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -215,9 +215,9 @@ def _get_first_line(rfile): # Possible leftover from previous message line = rfile.readline() except TcpDisconnect: - raise HttpReadDisconnect() + raise HttpReadDisconnect("Remote disconnected") if not line: - raise HttpReadDisconnect() + raise HttpReadDisconnect("Remote disconnected") line = line.strip() try: line.decode("ascii") @@ -227,7 +227,11 @@ def _get_first_line(rfile): def _read_request_line(rfile): - line = _get_first_line(rfile) + try: + line = _get_first_line(rfile) + except HttpReadDisconnect: + # We want to provide a better error message. + raise HttpReadDisconnect("Client disconnected") try: method, path, http_version = line.split(b" ") @@ -270,7 +274,11 @@ def _parse_authority_form(hostport): def _read_response_line(rfile): - line = _get_first_line(rfile) + try: + line = _get_first_line(rfile) + except HttpReadDisconnect: + # We want to provide a better error message. + raise HttpReadDisconnect("Server disconnected") try: -- cgit v1.2.3 From 3f1ca556d14ce71331b8dbc69be4db670863271a Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 20 Sep 2015 18:12:55 +0200 Subject: python3++ --- .travis.yml | 1 + netlib/certutils.py | 11 ++++++----- netlib/tcp.py | 4 ++-- netlib/wsgi.py | 31 +++++++++++++++++-------------- test/test_certutils.py | 30 +++++++++++++++--------------- test/test_socks.py | 2 +- test/test_tcp.py | 30 +++++++++++++++--------------- test/test_wsgi.py | 6 +++--- 8 files changed, 60 insertions(+), 55 deletions(-) diff --git a/.travis.yml b/.travis.yml index 7e18176c..00f8b4db 100644 --- a/.travis.yml +++ b/.travis.yml @@ -21,6 +21,7 @@ matrix: - nosetests --with-cov --cov-report term-missing test/test_utils.py - nosetests --with-cov --cov-report term-missing test/test_encoding.py - nosetests --with-cov --cov-report term-missing test/test_odict.py + - nosetests --with-cov --cov-report term-missing test/test_certutils.py - python: pypy - python: pypy env: OPENSSL=1.0.2 diff --git a/netlib/certutils.py b/netlib/certutils.py index 9193b757..df793537 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -5,6 +5,8 @@ import time import datetime import itertools import ipaddress + +import sys from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error @@ -184,7 +186,7 @@ class CertStore(object): with open(path, "wb") as f: f.write(DEFAULT_DHPARAM) - bio = OpenSSL.SSL._lib.BIO_new_file(path, b"r") + bio = OpenSSL.SSL._lib.BIO_new_file(path.encode(sys.getfilesystemencoding()), b"r") if bio != OpenSSL.SSL._ffi.NULL: bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free) dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams( @@ -318,10 +320,9 @@ class CertStore(object): potential_keys.append((commonname, tuple(sans))) name = next( - itertools.ifilter( - lambda key: key in self.certs, - potential_keys), - None) + filter(lambda key: key in self.certs, potential_keys), + None + ) if name: entry = self.certs[name] else: diff --git a/netlib/tcp.py b/netlib/tcp.py index 707e11e0..6dcc8c72 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -76,7 +76,7 @@ class SSLKeyLogger(object): d = os.path.dirname(self.filename) if not os.path.isdir(d): os.makedirs(d) - self.f = open(self.filename, "ab") + self.f = open(self.filename, "a") self.f.write("\r\n") client_random = connection.client_random().encode("hex") masterkey = connection.master_key().encode("hex") @@ -184,7 +184,7 @@ class Reader(_FileLike): """ If length is -1, we read until connection closes. """ - result = '' + result = b'' start = time.time() while length == -1 or length > 0: if length == -1 or length > self.BLOCKSIZE: diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 8a98884a..fba9f388 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,8 +1,11 @@ from __future__ import (absolute_import, print_function, division) -import cStringIO +from io import BytesIO import urllib import time import traceback + +import six + from . import http, tcp @@ -58,7 +61,7 @@ class WSGIAdaptor(object): environ = { 'wsgi.version': (1, 0), 'wsgi.url_scheme': flow.request.scheme, - 'wsgi.input': cStringIO.StringIO(flow.request.body or ""), + 'wsgi.input': BytesIO(flow.request.body or b""), 'wsgi.errors': errsoc, 'wsgi.multithread': True, 'wsgi.multiprocess': False, @@ -91,17 +94,17 @@ class WSGIAdaptor(object): Make a best-effort attempt to write an error page. If headers are already sent, we just bung the error into the page. """ - c = """ + c = b"""

Internal Server Error

%s"
- """ % s + """.strip() % s if not headers_sent: - soc.write("HTTP/1.1 500 Internal Server Error\r\n") - soc.write("Content-Type: text/html\r\n") - soc.write("Content-Length: %s\r\n" % len(c)) - soc.write("\r\n") + soc.write(b"HTTP/1.1 500 Internal Server Error\r\n") + soc.write(b"Content-Type: text/html\r\n") + soc.write(b"Content-Length: %s\r\n" % len(c)) + soc.write(b"\r\n") soc.write(c) def serve(self, request, soc, **env): @@ -114,14 +117,14 @@ class WSGIAdaptor(object): def write(data): if not state["headers_sent"]: - soc.write("HTTP/1.1 %s\r\n" % state["status"]) + soc.write(b"HTTP/1.1 %s\r\n" % state["status"]) headers = state["headers"] if 'server' not in headers: headers["Server"] = self.sversion if 'date' not in headers: headers["Date"] = date_time_string() - soc.write(str(headers)) - soc.write("\r\n") + soc.write(bytes(headers)) + soc.write(b"\r\n") state["headers_sent"] = True if data: soc.write(data) @@ -131,7 +134,7 @@ class WSGIAdaptor(object): if exc_info: try: if state["headers_sent"]: - raise exc_info[0], exc_info[1], exc_info[2] + six.reraise(*exc_info) finally: exc_info = None elif state["status"]: @@ -140,7 +143,7 @@ class WSGIAdaptor(object): state["headers"] = http.Headers(headers) return write - errs = cStringIO.StringIO() + errs = BytesIO() try: dataiter = self.app( self.make_environ(request, errs, **env), start_response @@ -148,7 +151,7 @@ class WSGIAdaptor(object): for i in dataiter: write(i) if not state["headers_sent"]: - write("") + write(b"") except Exception as e: try: s = traceback.format_exc() diff --git a/test/test_certutils.py b/test/test_certutils.py index b44879f6..fc91609e 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -36,10 +36,10 @@ class TestCertStore: def test_create_explicit(self): with tutils.tmpdir() as d: ca = certutils.CertStore.from_store(d, "test") - assert ca.get_cert("foo", []) + assert ca.get_cert(b"foo", []) ca2 = certutils.CertStore.from_store(d, "test") - assert ca2.get_cert("foo", []) + assert ca2.get_cert(b"foo", []) assert ca.default_ca.get_serial_number( ) == ca2.default_ca.get_serial_number() @@ -47,11 +47,11 @@ class TestCertStore: def test_create_tmp(self): with tutils.tmpdir() as d: ca = certutils.CertStore.from_store(d, "test") - assert ca.get_cert("foo.com", []) - assert ca.get_cert("foo.com", []) - assert ca.get_cert("*.foo.com", []) + assert ca.get_cert(b"foo.com", []) + assert ca.get_cert(b"foo.com", []) + assert ca.get_cert(b"*.foo.com", []) - r = ca.get_cert("*.foo.com", []) + r = ca.get_cert(b"*.foo.com", []) assert r[1] == ca.default_privatekey def test_add_cert(self): @@ -61,18 +61,18 @@ class TestCertStore: def test_sans(self): with tutils.tmpdir() as d: ca = certutils.CertStore.from_store(d, "test") - c1 = ca.get_cert("foo.com", ["*.bar.com"]) - ca.get_cert("foo.bar.com", []) + c1 = ca.get_cert(b"foo.com", [b"*.bar.com"]) + ca.get_cert(b"foo.bar.com", []) # assert c1 == c2 - c3 = ca.get_cert("bar.com", []) + c3 = ca.get_cert(b"bar.com", []) assert not c1 == c3 def test_sans_change(self): with tutils.tmpdir() as d: ca = certutils.CertStore.from_store(d, "test") - ca.get_cert("foo.com", ["*.bar.com"]) - cert, key, chain_file = ca.get_cert("foo.bar.com", ["*.baz.com"]) - assert "*.baz.com" in cert.altnames + ca.get_cert(b"foo.com", [b"*.bar.com"]) + cert, key, chain_file = ca.get_cert(b"foo.bar.com", [b"*.baz.com"]) + assert b"*.baz.com" in cert.altnames def test_overrides(self): with tutils.tmpdir() as d: @@ -81,14 +81,14 @@ class TestCertStore: assert not ca1.default_ca.get_serial_number( ) == ca2.default_ca.get_serial_number() - dc = ca2.get_cert("foo.com", ["sans.example.com"]) + dc = ca2.get_cert(b"foo.com", [b"sans.example.com"]) dcp = os.path.join(d, "dc") f = open(dcp, "wb") f.write(dc[0].to_pem()) f.close() - ca1.add_cert_file("foo.com", dcp) + ca1.add_cert_file(b"foo.com", dcp) - ret = ca1.get_cert("foo.com", []) + ret = ca1.get_cert(b"foo.com", []) assert ret[0].serial == dc[0].serial diff --git a/test/test_socks.py b/test/test_socks.py index 65a0f0eb..f2fb9b98 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -23,7 +23,7 @@ def test_client_greeting_assert_socks5(): msg = socks.ClientGreeting.from_file(raw) tutils.raises(socks.SocksError, msg.assert_socks5) - raw = tutils.treader(b"HTTP/1.1 200 OK" + " " * 100) + raw = tutils.treader(b"HTTP/1.1 200 OK" + b" " * 100) msg = socks.ClientGreeting.from_file(raw) try: msg.assert_socks5() diff --git a/test/test_tcp.py b/test/test_tcp.py index 615900ce..dc0efeb0 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -1,5 +1,5 @@ -import cStringIO -import Queue +from io import BytesIO +from six.moves import queue import time import socket import random @@ -435,7 +435,7 @@ class TestSSLDisconnect(tservers.ServerTestBase): c.rfile.read(10) c.close() tutils.raises(TcpDisconnect, c.wfile.write, "foo") - tutils.raises(Queue.Empty, self.q.get_nowait) + tutils.raises(queue.Empty, self.q.get_nowait) class TestSSLHardDisconnect(tservers.ServerTestBase): @@ -578,7 +578,7 @@ class TestTCPClient: class TestFileLike: def test_blocksize(self): - s = cStringIO.StringIO("1234567890abcdefghijklmnopqrstuvwxyz") + s = BytesIO(b"1234567890abcdefghijklmnopqrstuvwxyz") s = tcp.Reader(s) s.BLOCKSIZE = 2 assert s.read(1) == "1" @@ -589,7 +589,7 @@ class TestFileLike: assert d.startswith("abc") and d.endswith("xyz") def test_wrap(self): - s = cStringIO.StringIO("foobar\nfoobar") + s = BytesIO(b"foobar\nfoobar") s.flush() s = tcp.Reader(s) assert s.readline() == "foobar\n" @@ -598,18 +598,18 @@ class TestFileLike: assert s.isatty def test_limit(self): - s = cStringIO.StringIO("foobar\nfoobar") + s = BytesIO(b"foobar\nfoobar") s = tcp.Reader(s) assert s.readline(3) == "foo" def test_limitless(self): - s = cStringIO.StringIO("f" * (50 * 1024)) + s = BytesIO(b"f" * (50 * 1024)) s = tcp.Reader(s) ret = s.read(-1) assert len(ret) == 50 * 1024 def test_readlog(self): - s = cStringIO.StringIO("foobar\nfoobar") + s = BytesIO(b"foobar\nfoobar") s = tcp.Reader(s) assert not s.is_logging() s.start_log() @@ -626,7 +626,7 @@ class TestFileLike: tutils.raises(ValueError, s.get_log) def test_writelog(self): - s = cStringIO.StringIO() + s = BytesIO() s = tcp.Writer(s) s.start_log() assert s.is_logging() @@ -636,7 +636,7 @@ class TestFileLike: assert s.get_log() == "xx" def test_writer_flush_error(self): - s = cStringIO.StringIO() + s = BytesIO() s = tcp.Writer(s) o = mock.MagicMock() o.flush = mock.MagicMock(side_effect=socket.error) @@ -644,7 +644,7 @@ class TestFileLike: tutils.raises(TcpDisconnect, s.flush) def test_reader_read_error(self): - s = cStringIO.StringIO("foobar\nfoobar") + s = BytesIO(b"foobar\nfoobar") s = tcp.Reader(s) o = mock.MagicMock() o.read = mock.MagicMock(side_effect=socket.error) @@ -652,14 +652,14 @@ class TestFileLike: tutils.raises(TcpDisconnect, s.read, 10) def test_reset_timestamps(self): - s = cStringIO.StringIO("foobar\nfoobar") + s = BytesIO(b"foobar\nfoobar") s = tcp.Reader(s) s.first_byte_timestamp = 500 s.reset_timestamps() assert not s.first_byte_timestamp def test_first_byte_timestamp_updated_on_read(self): - s = cStringIO.StringIO("foobar\nfoobar") + s = BytesIO(b"foobar\nfoobar") s = tcp.Reader(s) s.read(1) assert s.first_byte_timestamp @@ -668,7 +668,7 @@ class TestFileLike: assert s.first_byte_timestamp == expected def test_first_byte_timestamp_updated_on_readline(self): - s = cStringIO.StringIO("foobar\nfoobar\nfoobar") + s = BytesIO(b"foobar\nfoobar\nfoobar") s = tcp.Reader(s) s.readline() assert s.first_byte_timestamp @@ -695,7 +695,7 @@ class TestFileLike: tutils.raises(TcpDisconnect, s.readline, 10) def test_reader_incomplete_error(self): - s = cStringIO.StringIO("foobar") + s = BytesIO(b"foobar") s = tcp.Reader(s) tutils.raises(TcpReadIncomplete, s.safe_read, 10) diff --git a/test/test_wsgi.py b/test/test_wsgi.py index e26e1413..856967af 100644 --- a/test/test_wsgi.py +++ b/test/test_wsgi.py @@ -1,4 +1,4 @@ -import cStringIO +from io import BytesIO import sys from netlib import wsgi from netlib.http import Headers @@ -41,7 +41,7 @@ class TestWSGI: f.request.host = "foo" f.request.port = 80 - wfile = cStringIO.StringIO() + wfile = BytesIO() err = w.serve(f, wfile) assert ta.called assert not err @@ -55,7 +55,7 @@ class TestWSGI: f = tflow() f.request.host = "foo" f.request.port = 80 - wfile = cStringIO.StringIO() + wfile = BytesIO() w.serve(f, wfile) return wfile.getvalue() -- cgit v1.2.3 From 693cdfc6d75e460a00585ccc9b734b80d6eba74d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 20 Sep 2015 19:40:09 +0200 Subject: python3++ --- .gitignore | 1 + .travis.yml | 1 + netlib/certutils.py | 6 +++--- netlib/socks.py | 22 +++++++++++++--------- netlib/utils.py | 6 ++++++ test/test_certutils.py | 10 +++++----- test/test_socks.py | 18 +++++++----------- 7 files changed, 36 insertions(+), 28 deletions(-) diff --git a/.gitignore b/.gitignore index c3c6f1cb..d8ffb588 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ _cffi__* .eggs/ netlib.egg-info/ pathod/ +.cache/ \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index 00f8b4db..c8cbeaa2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -22,6 +22,7 @@ matrix: - nosetests --with-cov --cov-report term-missing test/test_encoding.py - nosetests --with-cov --cov-report term-missing test/test_odict.py - nosetests --with-cov --cov-report term-missing test/test_certutils.py + - nosetests --with-cov --cov-report term-missing test/test_socks.py - python: pypy - python: pypy env: OPENSSL=1.0.2 diff --git a/netlib/certutils.py b/netlib/certutils.py index df793537..b3ddcbe4 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -3,7 +3,7 @@ import os import ssl import time import datetime -import itertools +from six.moves import filter import ipaddress import sys @@ -396,12 +396,12 @@ class SSLCert(object): @property def notbefore(self): t = self.x509.get_notBefore() - return datetime.datetime.strptime(t, "%Y%m%d%H%M%SZ") + return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ") @property def notafter(self): t = self.x509.get_notAfter() - return datetime.datetime.strptime(t, "%Y%m%d%H%M%SZ") + return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ") @property def has_expired(self): diff --git a/netlib/socks.py b/netlib/socks.py index d38b88c8..51ad1c63 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -1,7 +1,7 @@ from __future__ import (absolute_import, print_function, division) -import socket import struct import array +import ipaddress from . import tcp, utils @@ -133,19 +133,23 @@ class Message(object): def from_file(cls, f): ver, msg, rsv, atyp = struct.unpack("!BBBB", f.safe_read(4)) if rsv != 0x00: - raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, - "Socks Request: Invalid reserved byte: %s" % rsv) - + raise SocksError( + REP.GENERAL_SOCKS_SERVER_FAILURE, + "Socks Request: Invalid reserved byte: %s" % rsv + ) if atyp == ATYP.IPV4_ADDRESS: # We use tnoa here as ntop is not commonly available on Windows. - host = socket.inet_ntoa(f.safe_read(4)) + host = ipaddress.IPv4Address(f.safe_read(4)).compressed use_ipv6 = False elif atyp == ATYP.IPV6_ADDRESS: - host = socket.inet_ntop(socket.AF_INET6, f.safe_read(16)) + host = ipaddress.IPv6Address(f.safe_read(16)).compressed use_ipv6 = True elif atyp == ATYP.DOMAINNAME: length, = struct.unpack("!B", f.safe_read(1)) host = f.safe_read(length) + if not utils.is_valid_host(host): + raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Invalid hostname: %s" % host) + host = host.decode("idna") use_ipv6 = False else: raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, @@ -158,12 +162,12 @@ class Message(object): def to_file(self, f): f.write(struct.pack("!BBBB", self.ver, self.msg, 0x00, self.atyp)) if self.atyp == ATYP.IPV4_ADDRESS: - f.write(socket.inet_aton(self.addr.host)) + f.write(ipaddress.IPv4Address(self.addr.host).packed) elif self.atyp == ATYP.IPV6_ADDRESS: - f.write(socket.inet_pton(socket.AF_INET6, self.addr.host)) + f.write(ipaddress.IPv6Address(self.addr.host).packed) elif self.atyp == ATYP.DOMAINNAME: f.write(struct.pack("!B", len(self.addr.host))) - f.write(self.addr.host) + f.write(self.addr.host.encode("idna")) else: raise SocksError( REP.ADDRESS_TYPE_NOT_SUPPORTED, diff --git a/netlib/utils.py b/netlib/utils.py index 6fed44b6..799b0d42 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -141,6 +141,12 @@ _label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(? Date: Sun, 20 Sep 2015 19:56:45 +0200 Subject: python3++ --- netlib/tcp.py | 8 +++++--- test/test_socks.py | 2 +- test/test_tcp.py | 20 ++++++++++---------- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 6dcc8c72..f6f7d06f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -7,6 +7,8 @@ import threading import time import traceback +from six.moves import range + import certifi import six import OpenSSL @@ -227,7 +229,7 @@ class Reader(_FileLike): return result def readline(self, size=None): - result = '' + result = b'' bytes_read = 0 while True: if size is not None and bytes_read >= size: @@ -399,7 +401,7 @@ def close_socket(sock): sock.settimeout(sock.gettimeout() or 20) # limit at a megabyte so that we don't read infinitely - for _ in xrange(1024 ** 3 // 4096): + for _ in range(1024 ** 3 // 4096): # may raise a timeout/disconnect exception. if not sock.recv(4096): break @@ -649,7 +651,7 @@ class TCPClient(_Connection): if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() else: - return "" + return b"" class BaseHandler(_Connection): diff --git a/test/test_socks.py b/test/test_socks.py index dd8e2807..d95dee41 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -121,7 +121,7 @@ def test_message_ipv4(): def test_message_ipv6(): # Test ATYP=0x04 (IPV6) - ipv6_addr = "2001:db8:85a3:8d3:1319:8a2e:370:7344" + ipv6_addr = u"2001:db8:85a3:8d3:1319:8a2e:370:7344" raw = tutils.treader( b"\x05\x01\x00\x04" + diff --git a/test/test_tcp.py b/test/test_tcp.py index dc0efeb0..725aa8b0 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -49,9 +49,9 @@ class ALPNHandler(tcp.BaseHandler): def handle(self): alp = self.get_alpn_proto_negotiated() if alp: - self.wfile.write("%s" % alp) + self.wfile.write(("%s" % alp).encode("ascii")) else: - self.wfile.write("NONE") + self.wfile.write(b"NONE") self.wfile.flush() @@ -503,24 +503,24 @@ class TestALPNClient(tservers.ServerTestBase): def test_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=["foo", "bar", "fasel"]) - assert c.get_alpn_proto_negotiated() == "bar" - assert c.rfile.readline().strip() == "bar" + c.convert_to_ssl(alpn_protos=[b"foo", b"bar", b"fasel"]) + assert c.get_alpn_proto_negotiated() == b"bar" + assert c.rfile.readline().strip() == b"bar" def test_no_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() - assert c.get_alpn_proto_negotiated() == "" - assert c.rfile.readline().strip() == "NONE" + assert c.get_alpn_proto_negotiated() == b"" + assert c.rfile.readline().strip() == b"NONE" else: def test_none_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=["foo", "bar", "fasel"]) - assert c.get_alpn_proto_negotiated() == "" - assert c.rfile.readline() == "NONE" + c.convert_to_ssl(alpn_protos=[b"foo", b"bar", b"fasel"]) + assert c.get_alpn_proto_negotiated() == b"" + assert c.rfile.readline() == b"NONE" class TestNoSSLNoALPNClient(tservers.ServerTestBase): -- cgit v1.2.3 From 292a0aa9e671748f0ad77a5e8f8f21d40314b030 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 20 Sep 2015 19:56:57 +0200 Subject: make tests compatible with py.test --- setup.py | 2 + test/test_odict.py | 144 ++++++++++++++++++++++++++++------------------------- test/tservers.py | 4 +- 3 files changed, 79 insertions(+), 71 deletions(-) diff --git a/setup.py b/setup.py index 0c9fb07b..ac0d36cf 100644 --- a/setup.py +++ b/setup.py @@ -57,6 +57,8 @@ setup( extras_require={ 'dev': [ "mock>=1.0.1", + "pytest>=2.8.0", + "pytest-xdist>=1.13.1", "nose>=1.3.0", "nose-cov>=1.6", "coveralls>=0.4.1", diff --git a/test/test_odict.py b/test/test_odict.py index 962c0daa..88197026 100644 --- a/test/test_odict.py +++ b/test/test_odict.py @@ -3,9 +3,6 @@ from netlib import odict, tutils class TestODict(object): - def setUp(self): - self.od = odict.ODict() - def test_repr(self): h = odict.ODict() h["one"] = ["two"] @@ -19,72 +16,81 @@ class TestODict(object): h["key"] = b"foo" def test_getset_state(self): - self.od.add("foo", 1) - self.od.add("foo", 2) - self.od.add("bar", 3) - state = self.od.get_state() + od = odict.ODict() + od.add("foo", 1) + od.add("foo", 2) + od.add("bar", 3) + state = od.get_state() nd = odict.ODict.from_state(state) - assert nd == self.od + assert nd == od b = odict.ODict() b.load_state(state) - assert b == self.od + assert b == od def test_in_any(self): - self.od["one"] = ["atwoa", "athreea"] - assert self.od.in_any("one", "two") - assert self.od.in_any("one", "three") - assert not self.od.in_any("one", "four") - assert not self.od.in_any("nonexistent", "foo") - assert not self.od.in_any("one", "TWO") - assert self.od.in_any("one", "TWO", True) + od = odict.ODict() + od["one"] = ["atwoa", "athreea"] + assert od.in_any("one", "two") + assert od.in_any("one", "three") + assert not od.in_any("one", "four") + assert not od.in_any("nonexistent", "foo") + assert not od.in_any("one", "TWO") + assert od.in_any("one", "TWO", True) def test_iter(self): - assert not [i for i in self.od] - self.od.add("foo", 1) - assert [i for i in self.od] + od = odict.ODict() + assert not [i for i in od] + od.add("foo", 1) + assert [i for i in od] def test_keys(self): - assert not self.od.keys() - self.od.add("foo", 1) - assert self.od.keys() == ["foo"] - self.od.add("foo", 2) - assert self.od.keys() == ["foo"] - self.od.add("bar", 2) - assert len(self.od.keys()) == 2 + od = odict.ODict() + assert not od.keys() + od.add("foo", 1) + assert od.keys() == ["foo"] + od.add("foo", 2) + assert od.keys() == ["foo"] + od.add("bar", 2) + assert len(od.keys()) == 2 def test_copy(self): - self.od.add("foo", 1) - self.od.add("foo", 2) - self.od.add("bar", 3) - assert self.od == self.od.copy() - assert not self.od != self.od.copy() + od = odict.ODict() + od.add("foo", 1) + od.add("foo", 2) + od.add("bar", 3) + assert od == od.copy() + assert not od != od.copy() def test_del(self): - self.od.add("foo", 1) - self.od.add("Foo", 2) - self.od.add("bar", 3) - del self.od["foo"] - assert len(self.od.lst) == 2 + od = odict.ODict() + od.add("foo", 1) + od.add("Foo", 2) + od.add("bar", 3) + del od["foo"] + assert len(od.lst) == 2 def test_replace(self): - self.od.add("one", "two") - self.od.add("two", "one") - assert self.od.replace("one", "vun") == 2 - assert self.od.lst == [ + od = odict.ODict() + od.add("one", "two") + od.add("two", "one") + assert od.replace("one", "vun") == 2 + assert od.lst == [ ["vun", "two"], ["two", "vun"], ] def test_get(self): - self.od.add("one", "two") - assert self.od.get("one") == ["two"] - assert self.od.get("two") is None + od = odict.ODict() + od.add("one", "two") + assert od.get("one") == ["two"] + assert od.get("two") is None def test_get_first(self): - self.od.add("one", "two") - self.od.add("one", "three") - assert self.od.get_first("one") == "two" - assert self.od.get_first("two") is None + od = odict.ODict() + od.add("one", "two") + od.add("one", "three") + assert od.get_first("one") == "two" + assert od.get_first("two") is None def test_extend(self): a = odict.ODict([["a", "b"], ["c", "d"]]) @@ -96,9 +102,6 @@ class TestODict(object): class TestODictCaseless(object): - def setUp(self): - self.od = odict.ODictCaseless() - def test_override(self): o = odict.ODictCaseless() o.add('T', 'application/x-www-form-urlencoded; charset=UTF-8') @@ -106,29 +109,32 @@ class TestODictCaseless(object): assert o["T"] == ["foo"] def test_case_preservation(self): - self.od["Foo"] = ["1"] - assert "foo" in self.od - assert self.od.items()[0][0] == "Foo" - assert self.od.get("foo") == ["1"] - assert self.od.get("foo", [""]) == ["1"] - assert self.od.get("Foo", [""]) == ["1"] - assert self.od.get("xx", "yy") == "yy" + od = odict.ODictCaseless() + od["Foo"] = ["1"] + assert "foo" in od + assert od.items()[0][0] == "Foo" + assert od.get("foo") == ["1"] + assert od.get("foo", [""]) == ["1"] + assert od.get("Foo", [""]) == ["1"] + assert od.get("xx", "yy") == "yy" def test_del(self): - self.od.add("foo", 1) - self.od.add("Foo", 2) - self.od.add("bar", 3) - del self.od["foo"] - assert len(self.od) == 1 + od = odict.ODictCaseless() + od.add("foo", 1) + od.add("Foo", 2) + od.add("bar", 3) + del od["foo"] + assert len(od) == 1 def test_keys(self): - assert not self.od.keys() - self.od.add("foo", 1) - assert self.od.keys() == ["foo"] - self.od.add("Foo", 2) - assert self.od.keys() == ["foo"] - self.od.add("bar", 2) - assert len(self.od.keys()) == 2 + od = odict.ODictCaseless() + assert not od.keys() + od.add("foo", 1) + assert od.keys() == ["foo"] + od.add("Foo", 2) + assert od.keys() == ["foo"] + od.add("bar", 2) + assert len(od.keys()) == 2 def test_add_order(self): od = odict.ODict( diff --git a/test/tservers.py b/test/tservers.py index 1f4ce725..c47d6a5f 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -26,7 +26,7 @@ class ServerTestBase(object): addr = ("localhost", 0) @classmethod - def setupAll(cls): + def setup_class(cls): cls.q = queue.Queue() s = cls.makeserver() cls.port = s.address.port @@ -38,7 +38,7 @@ class ServerTestBase(object): return TServer(cls.ssl, cls.q, cls.handler, cls.addr) @classmethod - def teardownAll(cls): + def teardown_class(cls): cls.server.shutdown() @property -- cgit v1.2.3 From daebd1bd275a398d42cc4dbfe5c6399c7fe3b3a0 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 20 Sep 2015 20:35:45 +0200 Subject: python3++ --- .travis.yml | 7 +-- netlib/http/authentication.py | 4 +- netlib/tcp.py | 28 +++++------ test/http/test_authentication.py | 10 ++-- test/test_tcp.py | 105 ++++++++++++++++++++------------------- 5 files changed, 73 insertions(+), 81 deletions(-) diff --git a/.travis.yml b/.travis.yml index c8cbeaa2..2edd2558 100644 --- a/.travis.yml +++ b/.travis.yml @@ -17,12 +17,7 @@ matrix: - libssl-dev - python: 3.5 script: - - nosetests --with-cov --cov-report term-missing test/http/http1 - - nosetests --with-cov --cov-report term-missing test/test_utils.py - - nosetests --with-cov --cov-report term-missing test/test_encoding.py - - nosetests --with-cov --cov-report term-missing test/test_odict.py - - nosetests --with-cov --cov-report term-missing test/test_certutils.py - - nosetests --with-cov --cov-report term-missing test/test_socks.py + - py.test3 -n 4 -k "not http2 and not websockets and not wsgi and not models" . - python: pypy - python: pypy env: OPENSSL=1.0.2 diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index 2055f843..5831660b 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -12,7 +12,7 @@ def parse_http_basic_auth(s): user = binascii.a2b_base64(words[1]) except binascii.Error: return None - parts = user.split(':') + parts = user.split(b':') if len(parts) != 2: return None return scheme, parts[0], parts[1] @@ -69,7 +69,7 @@ class BasicProxyAuth(NullProxyAuth): if not parts: return False scheme, username, password = parts - if scheme.lower() != 'basic': + if scheme.lower() != b'basic': return False if not self.password_manager.test(username, password): return False diff --git a/netlib/tcp.py b/netlib/tcp.py index f6f7d06f..40ffbd48 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -7,6 +7,7 @@ import threading import time import traceback +import binascii from six.moves import range import certifi @@ -78,14 +79,11 @@ class SSLKeyLogger(object): d = os.path.dirname(self.filename) if not os.path.isdir(d): os.makedirs(d) - self.f = open(self.filename, "a") - self.f.write("\r\n") - client_random = connection.client_random().encode("hex") - masterkey = connection.master_key().encode("hex") - self.f.write( - "CLIENT_RANDOM {} {}\r\n".format( - client_random, - masterkey)) + self.f = open(self.filename, "ab") + self.f.write(b"\r\n") + client_random = binascii.hexlify(connection.client_random()) + masterkey = binascii.hexlify(connection.master_key()) + self.f.write(b"CLIENT_RANDOM %s %s\r\n" % (client_random, masterkey)) self.f.flush() def close(self): @@ -140,7 +138,7 @@ class _FileLike(object): """ if not self.is_logging(): raise ValueError("Not logging!") - return "".join(self._log) + return b"".join(self._log) def add_log(self, v): if self.is_logging(): @@ -216,9 +214,9 @@ class Reader(_FileLike): except SSL.SysCallError as e: if e.args == (-1, 'Unexpected EOF'): break - raise TlsException(e.message) + raise TlsException(str(e)) except SSL.Error as e: - raise TlsException(e.message) + raise TlsException(str(e)) self.first_byte_timestamp = self.first_byte_timestamp or time.time() if not data: break @@ -240,7 +238,7 @@ class Reader(_FileLike): break else: result += ch - if ch == '\n': + if ch == b'\n': break return result @@ -757,7 +755,7 @@ class BaseHandler(_Connection): if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() else: - return "" + return b"" class TCPServer(object): @@ -829,9 +827,7 @@ class TCPServer(object): exc = six.text_type(traceback.format_exc()) print(u'-' * 40, file=fp) print( - u"Error in processing of request from %s:%s" % ( - client_address.host, client_address.port - ), file=fp) + u"Error in processing of request from %s" % repr(client_address), file=fp) print(exc, file=fp) print(u'-' * 40, file=fp) diff --git a/test/http/test_authentication.py b/test/http/test_authentication.py index ee192dd7..a2aa774a 100644 --- a/test/http/test_authentication.py +++ b/test/http/test_authentication.py @@ -11,7 +11,7 @@ def test_parse_http_basic_auth(): ) == vals assert not authentication.parse_http_basic_auth("") assert not authentication.parse_http_basic_auth("foo bar") - v = "basic " + binascii.b2a_base64("foo") + v = b"basic " + binascii.b2a_base64(b"foo") assert not authentication.parse_http_basic_auth(v) @@ -34,7 +34,7 @@ class TestPassManHtpasswd: def test_simple(self): pm = authentication.PassManHtpasswd(tutils.test_data.path("data/htpasswd")) - vals = ("basic", "test", "test") + vals = (b"basic", b"test", b"test") authentication.assemble_http_basic_auth(*vals) assert pm.test("test", "test") assert not pm.test("test", "foo") @@ -73,7 +73,7 @@ class TestBasicProxyAuth: ba = authentication.BasicProxyAuth(authentication.PassManNonAnon(), "test") headers = Headers() - vals = ("basic", "foo", "bar") + vals = (b"basic", b"foo", b"bar") headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals) assert ba.authenticate(headers) @@ -86,12 +86,12 @@ class TestBasicProxyAuth: headers[ba.AUTH_HEADER] = "foo" assert not ba.authenticate(headers) - vals = ("foo", "foo", "bar") + vals = (b"foo", b"foo", b"bar") headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals) assert not ba.authenticate(headers) ba = authentication.BasicProxyAuth(authentication.PassMan(), "test") - vals = ("basic", "foo", "bar") + vals = (b"basic", b"foo", b"bar") headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals) assert not ba.authenticate(headers) diff --git a/test/test_tcp.py b/test/test_tcp.py index 725aa8b0..c87bebb3 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -49,7 +49,7 @@ class ALPNHandler(tcp.BaseHandler): def handle(self): alp = self.get_alpn_proto_negotiated() if alp: - self.wfile.write(("%s" % alp).encode("ascii")) + self.wfile.write(alp) else: self.wfile.write(b"NONE") self.wfile.flush() @@ -59,7 +59,7 @@ class TestServer(tservers.ServerTestBase): handler = EchoHandler def test_echo(self): - testval = "echo!\n" + testval = b"echo!\n" c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.wfile.write(testval) @@ -81,7 +81,7 @@ class TestServerBind(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): - self.wfile.write(str(self.connection.getpeername())) + self.wfile.write(str(self.connection.getpeername()).encode()) self.wfile.flush() def test_bind(self): @@ -93,7 +93,7 @@ class TestServerBind(tservers.ServerTestBase): ("127.0.0.1", self.port), source_address=( "127.0.0.1", random_port)) c.connect() - assert c.rfile.readline() == str(("127.0.0.1", random_port)) + assert c.rfile.readline() == str(("127.0.0.1", random_port)).encode() return except TcpException: # port probably already in use pass @@ -104,7 +104,7 @@ class TestServerIPv6(tservers.ServerTestBase): addr = tcp.Address(("localhost", 0), use_ipv6=True) def test_echo(self): - testval = "echo!\n" + testval = b"echo!\n" c = tcp.TCPClient(tcp.Address(("::1", self.port), use_ipv6=True)) c.connect() c.wfile.write(testval) @@ -116,7 +116,7 @@ class TestEcho(tservers.ServerTestBase): handler = EchoHandler def test_echo(self): - testval = "echo!\n" + testval = b"echo!\n" c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.wfile.write(testval) @@ -141,7 +141,7 @@ class TestFinishFail(tservers.ServerTestBase): def test_disconnect_in_finish(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.wfile.write("foo\n") + c.wfile.write(b"foo\n") c.wfile.flush = mock.Mock(side_effect=TcpDisconnect) c.finish() @@ -156,8 +156,8 @@ class TestServerSSL(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(sni="foo.com", options=SSL.OP_ALL) - testval = "echo!\n" + c.convert_to_ssl(sni=b"foo.com", options=SSL.OP_ALL) + testval = b"echo!\n" c.wfile.write(testval) c.wfile.flush() assert c.rfile.readline() == testval @@ -166,7 +166,7 @@ class TestServerSSL(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() assert not c.get_current_cipher() - c.convert_to_ssl(sni="foo.com") + c.convert_to_ssl(sni=b"foo.com") ret = c.get_current_cipher() assert ret assert "AES" in ret[0] @@ -182,7 +182,7 @@ class TestSSLv3Only(tservers.ServerTestBase): def test_failure(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - tutils.raises(TlsException, c.convert_to_ssl, sni="foo.com") + tutils.raises(TlsException, c.convert_to_ssl, sni=b"foo.com") class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): @@ -190,7 +190,8 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): ssl = dict( cert=tutils.test_data.path("data/verificationcerts/untrusted.crt"), - key=tutils.test_data.path("data/verificationcerts/verification-server.key")) + key=tutils.test_data.path("data/verificationcerts/verification-server.key") + ) def test_mode_default_should_pass(self): c = tcp.TCPClient(("127.0.0.1", self.port)) @@ -202,7 +203,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): # aborted assert c.ssl_verification_error is not None - testval = "echo!\n" + testval = b"echo!\n" c.wfile.write(testval) c.wfile.flush() assert c.rfile.readline() == testval @@ -216,7 +217,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): # Verification errors should be saved even if connection isn't aborted assert c.ssl_verification_error is not None - testval = "echo!\n" + testval = b"echo!\n" c.wfile.write(testval) c.wfile.flush() assert c.rfile.readline() == testval @@ -280,7 +281,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase): assert c.ssl_verification_error is None - testval = "echo!\n" + testval = b"echo!\n" c.wfile.write(testval) c.wfile.flush() assert c.rfile.readline() == testval @@ -295,7 +296,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase): assert c.ssl_verification_error is None - testval = "echo!\n" + testval = b"echo!\n" c.wfile.write(testval) c.wfile.flush() assert c.rfile.readline() == testval @@ -310,7 +311,7 @@ class TestSSLClientCert(tservers.ServerTestBase): self.sni = connection.get_servername() def handle(self): - self.wfile.write("%s\n" % self.clientcert.serial) + self.wfile.write(b"%d\n" % self.clientcert.serial) self.wfile.flush() ssl = dict( @@ -323,7 +324,7 @@ class TestSSLClientCert(tservers.ServerTestBase): c.connect() c.convert_to_ssl( cert=tutils.test_data.path("data/clientcert/client.pem")) - assert c.rfile.readline().strip() == "1" + assert c.rfile.readline().strip() == b"1" def test_clientcert_err(self): c = tcp.TCPClient(("127.0.0.1", self.port)) @@ -352,9 +353,9 @@ class TestSNI(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(sni="foo.com") - assert c.sni == "foo.com" - assert c.rfile.readline() == "foo.com" + c.convert_to_ssl(sni=b"foo.com") + assert c.sni == b"foo.com" + assert c.rfile.readline() == b"foo.com" class TestServerCipherList(tservers.ServerTestBase): @@ -366,8 +367,8 @@ class TestServerCipherList(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(sni="foo.com") - assert c.rfile.readline() == "['RC4-SHA']" + c.convert_to_ssl(sni=b"foo.com") + assert c.rfile.readline() == b"['RC4-SHA']" class TestServerCurrentCipher(tservers.ServerTestBase): @@ -376,7 +377,7 @@ class TestServerCurrentCipher(tservers.ServerTestBase): sni = None def handle(self): - self.wfile.write("%s" % str(self.get_current_cipher())) + self.wfile.write(str(self.get_current_cipher()).encode()) self.wfile.flush() ssl = dict( @@ -386,8 +387,8 @@ class TestServerCurrentCipher(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(sni="foo.com") - assert "RC4-SHA" in c.rfile.readline() + c.convert_to_ssl(sni=b"foo.com") + assert b"RC4-SHA" in c.rfile.readline() class TestServerCipherListError(tservers.ServerTestBase): @@ -399,7 +400,7 @@ class TestServerCipherListError(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - tutils.raises("handshake error", c.convert_to_ssl, sni="foo.com") + tutils.raises("handshake error", c.convert_to_ssl, sni=b"foo.com") class TestClientCipherListError(tservers.ServerTestBase): @@ -414,7 +415,7 @@ class TestClientCipherListError(tservers.ServerTestBase): tutils.raises( "cipher specification", c.convert_to_ssl, - sni="foo.com", + sni=b"foo.com", cipher_list="bogus") @@ -434,7 +435,7 @@ class TestSSLDisconnect(tservers.ServerTestBase): # Excercise SSL.ZeroReturnError c.rfile.read(10) c.close() - tutils.raises(TcpDisconnect, c.wfile.write, "foo") + tutils.raises(TcpDisconnect, c.wfile.write, b"foo") tutils.raises(queue.Empty, self.q.get_nowait) @@ -449,7 +450,7 @@ class TestSSLHardDisconnect(tservers.ServerTestBase): # Exercise SSL.SysCallError c.rfile.read(10) c.close() - tutils.raises(TcpDisconnect, c.wfile.write, "foo") + tutils.raises(TcpDisconnect, c.wfile.write, b"foo") class TestDisconnect(tservers.ServerTestBase): @@ -458,7 +459,7 @@ class TestDisconnect(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.rfile.read(10) - c.wfile.write("foo") + c.wfile.write(b"foo") c.close() c.close() @@ -496,7 +497,7 @@ class TestTimeOut(tservers.ServerTestBase): class TestALPNClient(tservers.ServerTestBase): handler = ALPNHandler ssl = dict( - alpn_select="bar" + alpn_select=b"bar" ) if OpenSSL._util.lib.Cryptography_HAS_ALPN: @@ -529,8 +530,8 @@ class TestNoSSLNoALPNClient(tservers.ServerTestBase): def test_no_ssl_no_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - assert c.get_alpn_proto_negotiated() == "" - assert c.rfile.readline().strip() == "NONE" + assert c.get_alpn_proto_negotiated() == b"" + assert c.rfile.readline().strip() == b"NONE" class TestSSLTimeOut(tservers.ServerTestBase): @@ -581,26 +582,26 @@ class TestFileLike: s = BytesIO(b"1234567890abcdefghijklmnopqrstuvwxyz") s = tcp.Reader(s) s.BLOCKSIZE = 2 - assert s.read(1) == "1" - assert s.read(2) == "23" - assert s.read(3) == "456" - assert s.read(4) == "7890" + assert s.read(1) == b"1" + assert s.read(2) == b"23" + assert s.read(3) == b"456" + assert s.read(4) == b"7890" d = s.read(-1) - assert d.startswith("abc") and d.endswith("xyz") + assert d.startswith(b"abc") and d.endswith(b"xyz") def test_wrap(self): s = BytesIO(b"foobar\nfoobar") s.flush() s = tcp.Reader(s) - assert s.readline() == "foobar\n" - assert s.readline() == "foobar" + assert s.readline() == b"foobar\n" + assert s.readline() == b"foobar" # Test __getattr__ assert s.isatty def test_limit(self): s = BytesIO(b"foobar\nfoobar") s = tcp.Reader(s) - assert s.readline(3) == "foo" + assert s.readline(3) == b"foo" def test_limitless(self): s = BytesIO(b"f" * (50 * 1024)) @@ -615,13 +616,13 @@ class TestFileLike: s.start_log() assert s.is_logging() s.readline() - assert s.get_log() == "foobar\n" + assert s.get_log() == b"foobar\n" s.read(1) - assert s.get_log() == "foobar\nf" + assert s.get_log() == b"foobar\nf" s.start_log() - assert s.get_log() == "" + assert s.get_log() == b"" s.read(1) - assert s.get_log() == "o" + assert s.get_log() == b"o" s.stop_log() tutils.raises(ValueError, s.get_log) @@ -630,10 +631,10 @@ class TestFileLike: s = tcp.Writer(s) s.start_log() assert s.is_logging() - s.write("x") - assert s.get_log() == "x" - s.write("x") - assert s.get_log() == "xx" + s.write(b"x") + assert s.get_log() == b"x" + s.write(b"x") + assert s.get_log() == b"xx" def test_writer_flush_error(self): s = BytesIO() @@ -721,7 +722,7 @@ class TestSSLKeyLogger(tservers.ServerTestBase): ) def test_log(self): - testval = "echo!\n" + testval = b"echo!\n" _logfun = tcp.log_ssl_key with tutils.tmpdir() as d: @@ -738,7 +739,7 @@ class TestSSLKeyLogger(tservers.ServerTestBase): tcp.log_ssl_key.close() with open(logfile, "rb") as f: - assert f.read().count("CLIENT_RANDOM") == 2 + assert f.read().count(b"CLIENT_RANDOM") == 2 tcp.log_ssl_key = _logfun -- cgit v1.2.3 From 73586b1be95d97f0be76e85223b53d1f4ed697d6 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 21 Sep 2015 00:44:17 +0200 Subject: python 3++ --- .travis.yml | 2 +- netlib/encoding.py | 16 ++++---- netlib/http/models.py | 30 +++++++-------- netlib/tutils.py | 5 +-- netlib/utils.py | 53 +++++++++++++++++--------- netlib/websockets/frame.py | 77 +++++++++++++++++++++++++------------- netlib/websockets/protocol.py | 52 +++++++++++++------------ netlib/wsgi.py | 55 ++++++++++++++------------- test/http/http1/test_protocol.py | 0 test/http/test_models.py | 72 +++++++++++++++++------------------ test/test_encoding.py | 2 - test/test_wsgi.py | 21 ++++++----- test/websockets/test_websockets.py | 71 +++++++++++++++++------------------ 13 files changed, 250 insertions(+), 206 deletions(-) delete mode 100644 test/http/http1/test_protocol.py diff --git a/.travis.yml b/.travis.yml index 2edd2558..c5634a10 100644 --- a/.travis.yml +++ b/.travis.yml @@ -17,7 +17,7 @@ matrix: - libssl-dev - python: 3.5 script: - - py.test3 -n 4 -k "not http2 and not websockets and not wsgi and not models" . + - py.test -n 4 -k "not http2" . - python: pypy - python: pypy env: OPENSSL=1.0.2 diff --git a/netlib/encoding.py b/netlib/encoding.py index 8ac59905..4c11273b 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -8,27 +8,25 @@ import zlib from .utils import always_byte_args -ENCODINGS = {b"identity", b"gzip", b"deflate"} +ENCODINGS = {"identity", "gzip", "deflate"} -@always_byte_args("ascii", "ignore") def decode(e, content): encoding_map = { - b"identity": identity, - b"gzip": decode_gzip, - b"deflate": decode_deflate, + "identity": identity, + "gzip": decode_gzip, + "deflate": decode_deflate, } if e not in encoding_map: return None return encoding_map[e](content) -@always_byte_args("ascii", "ignore") def encode(e, content): encoding_map = { - b"identity": identity, - b"gzip": encode_gzip, - b"deflate": encode_deflate, + "identity": identity, + "gzip": encode_gzip, + "deflate": encode_deflate, } if e not in encoding_map: return None diff --git a/netlib/http/models.py b/netlib/http/models.py index ff854b13..3c360a37 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -3,7 +3,7 @@ import copy from ..odict import ODict from .. import utils, encoding -from ..utils import always_bytes, always_byte_args +from ..utils import always_bytes, always_byte_args, native from . import cookies import six @@ -254,7 +254,7 @@ class Request(Message): def __repr__(self): if self.host and self.port: - hostport = "{}:{}".format(self.host, self.port) + hostport = "{}:{}".format(native(self.host,"idna"), self.port) else: hostport = "" path = self.path or "" @@ -279,14 +279,14 @@ class Request(Message): Modifies this request to remove headers that will compress the resource's data. """ - self.headers["Accept-Encoding"] = b"identity" + self.headers["Accept-Encoding"] = "identity" def constrain_encoding(self): """ Limits the permissible Accept-Encoding values, based on what we can decode appropriately. """ - accept_encoding = self.headers.get(b"Accept-Encoding") + accept_encoding = native(self.headers.get("Accept-Encoding"), "ascii") if accept_encoding: self.headers["Accept-Encoding"] = ( ', '.join( @@ -309,9 +309,9 @@ class Request(Message): indicates non-form data. """ if self.body: - if HDR_FORM_URLENCODED in self.headers.get("Content-Type", "").lower(): + if HDR_FORM_URLENCODED in self.headers.get("Content-Type", b"").lower(): return self.get_form_urlencoded() - elif HDR_FORM_MULTIPART in self.headers.get("Content-Type", "").lower(): + elif HDR_FORM_MULTIPART in self.headers.get("Content-Type", b"").lower(): return self.get_form_multipart() return ODict([]) @@ -321,12 +321,12 @@ class Request(Message): Returns an empty ODict if there is no data or the content-type indicates non-form data. """ - if self.body and HDR_FORM_URLENCODED in self.headers.get("Content-Type", "").lower(): + if self.body and HDR_FORM_URLENCODED in self.headers.get("Content-Type", b"").lower(): return ODict(utils.urldecode(self.body)) return ODict([]) def get_form_multipart(self): - if self.body and HDR_FORM_MULTIPART in self.headers.get("Content-Type", "").lower(): + if self.body and HDR_FORM_MULTIPART in self.headers.get("Content-Type", b"").lower(): return ODict( utils.multipartdecode( self.headers, @@ -351,7 +351,7 @@ class Request(Message): Components are unquoted. """ _, _, path, _, _, _ = urllib.parse.urlparse(self.url) - return [urllib.parse.unquote(i) for i in path.split(b"/") if i] + return [urllib.parse.unquote(native(i,"ascii")) for i in path.split(b"/") if i] def set_path_components(self, lst): """ @@ -360,7 +360,7 @@ class Request(Message): Components are quoted. """ lst = [urllib.parse.quote(i, safe="") for i in lst] - path = b"/" + b"/".join(lst) + path = always_bytes("/" + "/".join(lst)) scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) self.url = urllib.parse.urlunparse( [scheme, netloc, path, params, query, fragment] @@ -408,11 +408,11 @@ class Request(Message): def pretty_url(self, hostheader): if self.form_out == "authority": # upstream proxy mode - return "%s:%s" % (self.pretty_host(hostheader), self.port) + return b"%s:%d" % (always_bytes(self.pretty_host(hostheader)), self.port) return utils.unparse_url(self.scheme, self.pretty_host(hostheader), self.port, - self.path).encode('ascii') + self.path) def get_cookies(self): """ @@ -420,7 +420,7 @@ class Request(Message): """ ret = ODict() for i in self.headers.get_all("Cookie"): - ret.extend(cookies.parse_cookie_header(i)) + ret.extend(cookies.parse_cookie_header(native(i,"ascii"))) return ret def set_cookies(self, odict): @@ -441,7 +441,7 @@ class Request(Message): self.host, self.port, self.path - ).encode('ascii') + ) @url.setter def url(self, url): @@ -499,7 +499,7 @@ class Response(Message): """ ret = [] for header in self.headers.get_all("Set-Cookie"): - v = cookies.parse_set_cookie_header(header) + v = cookies.parse_set_cookie_header(native(header, "ascii")) if v: name, value, attrs = v ret.append([name, [value, attrs]]) diff --git a/netlib/tutils.py b/netlib/tutils.py index 4903d63b..1665a792 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -7,7 +7,7 @@ from contextlib import contextmanager import six import sys -from . import utils +from . import utils, tcp from .http import Request, Response, Headers @@ -15,7 +15,6 @@ def treader(bytes): """ Construct a tcp.Read object from bytes. """ - from . import tcp # TODO: move to top once cryptography is on Python 3.5 fp = BytesIO(bytes) return tcp.Reader(fp) @@ -106,7 +105,7 @@ def treq(**kwargs): port=22, path=b"/path", http_version=b"HTTP/1.1", - headers=Headers(header=b"qvalue"), + headers=Headers(header="qvalue"), body=b"content" ) default.update(kwargs) diff --git a/netlib/utils.py b/netlib/utils.py index 799b0d42..8d11bd5b 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -9,6 +9,41 @@ import six from six.moves import urllib +def always_bytes(unicode_or_bytes, *encode_args): + if isinstance(unicode_or_bytes, six.text_type): + return unicode_or_bytes.encode(*encode_args) + return unicode_or_bytes + + +def always_byte_args(*encode_args): + """Decorator that transparently encodes all arguments passed as unicode""" + def decorator(fun): + def _fun(*args, **kwargs): + args = [always_bytes(arg, *encode_args) for arg in args] + kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} + return fun(*args, **kwargs) + return _fun + return decorator + + +def native(s, encoding="latin-1"): + """ + Convert :py:class:`bytes` or :py:class:`unicode` to the native + :py:class:`str` type, using latin1 encoding if conversion is necessary. + + https://www.python.org/dev/peps/pep-3333/#a-note-on-string-types + """ + if not isinstance(s, (six.binary_type, six.text_type)): + raise TypeError("%r is neither bytes nor unicode" % s) + if six.PY3: + if isinstance(s, six.binary_type): + return s.decode(encoding) + else: + if isinstance(s, six.text_type): + return s.encode(encoding) + return s + + def isascii(bytes): try: bytes.decode("ascii") @@ -238,6 +273,7 @@ def get_header_tokens(headers, key): return [token.strip() for token in tokens] +@always_byte_args() def hostport(scheme, host, port): """ Returns the host component, with a port specifcation if needed. @@ -323,20 +359,3 @@ def multipartdecode(headers, content): r.append((key, value)) return r return [] - - -def always_bytes(unicode_or_bytes, *encode_args): - if isinstance(unicode_or_bytes, six.text_type): - return unicode_or_bytes.encode(*encode_args) - return unicode_or_bytes - - -def always_byte_args(*encode_args): - """Decorator that transparently encodes all arguments passed as unicode""" - def decorator(fun): - def _fun(*args, **kwargs): - args = [always_bytes(arg, *encode_args) for arg in args] - kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} - return fun(*args, **kwargs) - return _fun - return decorator diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index ceddd273..55eeaf41 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -2,13 +2,14 @@ from __future__ import absolute_import import os import struct import io +import warnings + import six from .protocol import Masker from netlib import tcp from netlib import utils -DEFAULT = object() MAX_16_BIT_INT = (1 << 16) MAX_64_BIT_INT = (1 << 64) @@ -33,9 +34,9 @@ class FrameHeader(object): rsv1=False, rsv2=False, rsv3=False, - masking_key=DEFAULT, - mask=DEFAULT, - length_code=DEFAULT + masking_key=None, + mask=None, + length_code=None ): if not 0 <= opcode < 2 ** 4: raise ValueError("opcode must be 0-16") @@ -46,18 +47,18 @@ class FrameHeader(object): self.rsv2 = rsv2 self.rsv3 = rsv3 - if length_code is DEFAULT: + if length_code is None: self.length_code = self._make_length_code(self.payload_length) else: self.length_code = length_code - if mask is DEFAULT and masking_key is DEFAULT: + if mask is None and masking_key is None: self.mask = False - self.masking_key = "" - elif mask is DEFAULT: + self.masking_key = b"" + elif mask is None: self.mask = 1 self.masking_key = masking_key - elif masking_key is DEFAULT: + elif masking_key is None: self.mask = mask self.masking_key = os.urandom(4) else: @@ -81,7 +82,7 @@ class FrameHeader(object): else: return 127 - def human_readable(self): + def __repr__(self): vals = [ "ws frame:", OPCODE.get_name(self.opcode, hex(self.opcode)).lower() @@ -98,7 +99,11 @@ class FrameHeader(object): vals.append(" %s" % utils.pretty_size(self.payload_length)) return "".join(vals) - def to_bytes(self): + def human_readable(self): + warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) + return repr(self) + + def __bytes__(self): first_byte = utils.setbit(0, 7, self.fin) first_byte = utils.setbit(first_byte, 6, self.rsv1) first_byte = utils.setbit(first_byte, 5, self.rsv2) @@ -107,7 +112,7 @@ class FrameHeader(object): second_byte = utils.setbit(self.length_code, 7, self.mask) - b = chr(first_byte) + chr(second_byte) + b = six.int2byte(first_byte) + six.int2byte(second_byte) if self.payload_length < 126: pass @@ -119,10 +124,17 @@ class FrameHeader(object): # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length b += struct.pack('!Q', self.payload_length) - if self.masking_key is not None: + if self.masking_key: b += self.masking_key return b + if six.PY2: + __str__ = __bytes__ + + def to_bytes(self): + warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) + return bytes(self) + @classmethod def from_file(cls, fp): """ @@ -154,7 +166,7 @@ class FrameHeader(object): if mask_bit == 1: masking_key = fp.safe_read(4) else: - masking_key = None + masking_key = False return cls( fin=fin, @@ -169,7 +181,9 @@ class FrameHeader(object): ) def __eq__(self, other): - return self.to_bytes() == other.to_bytes() + if isinstance(other, FrameHeader): + return bytes(self) == bytes(other) + return False class Frame(object): @@ -200,7 +214,7 @@ class Frame(object): +---------------------------------------------------------------+ """ - def __init__(self, payload="", **kwargs): + def __init__(self, payload=b"", **kwargs): self.payload = payload kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) self.header = FrameHeader(**kwargs) @@ -216,7 +230,7 @@ class Frame(object): masking_key = os.urandom(4) else: mask_bit = 0 - masking_key = None + masking_key = False return cls( message, @@ -234,28 +248,37 @@ class Frame(object): """ return cls.from_file(tcp.Reader(io.BytesIO(bytestring))) - def human_readable(self): - ret = self.header.human_readable() + def __repr__(self): + ret = repr(self.header) if self.payload: - ret = ret + "\nPayload:\n" + utils.clean_bin(self.payload) + ret = ret + "\nPayload:\n" + utils.clean_bin(self.payload).decode("ascii") return ret - def __repr__(self): - return self.header.human_readable() + def human_readable(self): + warnings.warn("Frame.to_bytes is deprecated, use bytes(frame) instead.", DeprecationWarning) + return repr(self) - def to_bytes(self): + def __bytes__(self): """ Serialize the frame to wire format. Returns a string. """ - b = self.header.to_bytes() + b = bytes(self.header) if self.header.masking_key: b += Masker(self.header.masking_key)(self.payload) else: b += self.payload return b + if six.PY2: + __str__ = __bytes__ + + def to_bytes(self): + warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) + return bytes(self) + def to_file(self, writer): - writer.write(self.to_bytes()) + warnings.warn("Frame.to_file is deprecated, use wfile.write(bytes(frame)) instead.", DeprecationWarning) + writer.write(bytes(self)) writer.flush() @classmethod @@ -286,4 +309,6 @@ class Frame(object): ) def __eq__(self, other): - return self.to_bytes() == other.to_bytes() + if isinstance(other, Frame): + return bytes(self) == bytes(other) + return False \ No newline at end of file diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index 68d827a5..778fe7e7 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -17,11 +17,12 @@ from __future__ import absolute_import import base64 import hashlib import os + +import binascii import six from ..http import Headers -from .. import utils -websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' VERSION = "13" HEADER_WEBSOCKET_KEY = 'sec-websocket-key' @@ -41,14 +42,21 @@ class Masker(object): def __init__(self, key): self.key = key - self.masks = [six.byte2int(byte) for byte in key] self.offset = 0 def mask(self, offset, data): - result = "" - for c in data: - result += chr(ord(c) ^ self.masks[offset % 4]) - offset += 1 + result = bytearray(data) + if six.PY2: + for i in range(len(data)): + result[i] ^= ord(self.key[offset % 4]) + offset += 1 + result = str(result) + else: + + for i in range(len(data)): + result[i] ^= self.key[offset % 4] + offset += 1 + result = bytes(result) return result def __call__(self, data): @@ -73,37 +81,35 @@ class WebsocketsProtocol(object): """ if not key: key = base64.b64encode(os.urandom(16)).decode('utf-8') - return Headers([ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - (HEADER_WEBSOCKET_KEY, key), - (HEADER_WEBSOCKET_VERSION, version) - ]) + return Headers(**{ + HEADER_WEBSOCKET_KEY: key, + HEADER_WEBSOCKET_VERSION: version, + "Connection": "Upgrade", + "Upgrade": "websocket", + }) @classmethod def server_handshake_headers(self, key): """ The server response is a valid HTTP 101 response. """ - return Headers( - [ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - (HEADER_WEBSOCKET_ACCEPT, self.create_server_nonce(key)) - ] - ) + return Headers(**{ + HEADER_WEBSOCKET_ACCEPT: self.create_server_nonce(key), + "Connection": "Upgrade", + "Upgrade": "websocket", + }) @classmethod def check_client_handshake(self, headers): - if headers.get("upgrade") != "websocket": + if headers.get("upgrade") != b"websocket": return return headers.get(HEADER_WEBSOCKET_KEY) @classmethod def check_server_handshake(self, headers): - if headers.get("upgrade") != "websocket": + if headers.get("upgrade") != b"websocket": return return headers.get(HEADER_WEBSOCKET_ACCEPT) @@ -111,5 +117,5 @@ class WebsocketsProtocol(object): @classmethod def create_server_nonce(self, client_nonce): return base64.b64encode( - hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') + binascii.unhexlify(hashlib.sha1(client_nonce + websockets_magic).hexdigest()) ) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index fba9f388..8fb09008 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,14 +1,15 @@ from __future__ import (absolute_import, print_function, division) -from io import BytesIO +from io import BytesIO, StringIO import urllib import time import traceback import six +from six.moves import urllib +from netlib.utils import always_bytes, native from . import http, tcp - class ClientConn(object): def __init__(self, address): @@ -24,9 +25,10 @@ class Flow(object): class Request(object): - def __init__(self, scheme, method, path, headers, body): + def __init__(self, scheme, method, path, http_version, headers, body): self.scheme, self.method, self.path = scheme, method, path self.headers, self.body = headers, body + self.http_version = http_version def date_time_string(): @@ -53,38 +55,38 @@ class WSGIAdaptor(object): self.app, self.domain, self.port, self.sversion = app, domain, port, sversion def make_environ(self, flow, errsoc, **extra): - if '?' in flow.request.path: - path_info, query = flow.request.path.split('?', 1) + path = native(flow.request.path) + if '?' in path: + path_info, query = native(path).split('?', 1) else: - path_info = flow.request.path + path_info = path query = '' environ = { 'wsgi.version': (1, 0), - 'wsgi.url_scheme': flow.request.scheme, + 'wsgi.url_scheme': native(flow.request.scheme), 'wsgi.input': BytesIO(flow.request.body or b""), 'wsgi.errors': errsoc, 'wsgi.multithread': True, 'wsgi.multiprocess': False, 'wsgi.run_once': False, 'SERVER_SOFTWARE': self.sversion, - 'REQUEST_METHOD': flow.request.method, + 'REQUEST_METHOD': native(flow.request.method), 'SCRIPT_NAME': '', - 'PATH_INFO': urllib.unquote(path_info), + 'PATH_INFO': urllib.parse.unquote(path_info), 'QUERY_STRING': query, - 'CONTENT_TYPE': flow.request.headers.get('Content-Type', ''), - 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', ''), + 'CONTENT_TYPE': native(flow.request.headers.get('Content-Type', '')), + 'CONTENT_LENGTH': native(flow.request.headers.get('Content-Length', '')), 'SERVER_NAME': self.domain, 'SERVER_PORT': str(self.port), - # FIXME: We need to pick up the protocol read from the request. - 'SERVER_PROTOCOL': "HTTP/1.1", + 'SERVER_PROTOCOL': native(flow.request.http_version), } environ.update(extra) if flow.client_conn.address: - environ["REMOTE_ADDR"], environ[ - "REMOTE_PORT"] = flow.client_conn.address() + environ["REMOTE_ADDR"] = native(flow.client_conn.address.host) + environ["REMOTE_PORT"] = flow.client_conn.address.port for key, value in flow.request.headers.items(): - key = 'HTTP_' + key.upper().replace('-', '_') + key = 'HTTP_' + native(key).upper().replace('-', '_') if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'): environ[key] = value return environ @@ -99,7 +101,7 @@ class WSGIAdaptor(object):

Internal Server Error

%s"
- """.strip() % s + """.strip() % s.encode() if not headers_sent: soc.write(b"HTTP/1.1 500 Internal Server Error\r\n") soc.write(b"Content-Type: text/html\r\n") @@ -117,7 +119,7 @@ class WSGIAdaptor(object): def write(data): if not state["headers_sent"]: - soc.write(b"HTTP/1.1 %s\r\n" % state["status"]) + soc.write(b"HTTP/1.1 %s\r\n" % state["status"].encode()) headers = state["headers"] if 'server' not in headers: headers["Server"] = self.sversion @@ -132,18 +134,17 @@ class WSGIAdaptor(object): def start_response(status, headers, exc_info=None): if exc_info: - try: - if state["headers_sent"]: - six.reraise(*exc_info) - finally: - exc_info = None + if state["headers_sent"]: + six.reraise(*exc_info) elif state["status"]: raise AssertionError('Response already started') state["status"] = status - state["headers"] = http.Headers(headers) - return write + state["headers"] = http.Headers([[always_bytes(k), always_bytes(v)] for k,v in headers]) + if exc_info: + self.error_page(soc, state["headers_sent"], traceback.format_tb(exc_info[2])) + state["headers_sent"] = True - errs = BytesIO() + errs = six.BytesIO() try: dataiter = self.app( self.make_environ(request, errs, **env), start_response @@ -155,7 +156,7 @@ class WSGIAdaptor(object): except Exception as e: try: s = traceback.format_exc() - errs.write(s) + errs.write(s.encode("utf-8", "replace")) self.error_page(soc, state["headers_sent"], s) except Exception: # pragma: no cover pass diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py deleted file mode 100644 index e69de29b..00000000 diff --git a/test/http/test_models.py b/test/http/test_models.py index 6970a6e4..d420b22b 100644 --- a/test/http/test_models.py +++ b/test/http/test_models.py @@ -58,20 +58,20 @@ class TestRequest(object): req = tutils.treq() req.headers["Accept-Encoding"] = "foobar" req.anticomp() - assert req.headers["Accept-Encoding"] == "identity" + assert req.headers["Accept-Encoding"] == b"identity" def test_constrain_encoding(self): req = tutils.treq() req.headers["Accept-Encoding"] = "identity, gzip, foo" req.constrain_encoding() - assert "foo" not in req.headers["Accept-Encoding"] + assert b"foo" not in req.headers["Accept-Encoding"] def test_update_host(self): req = tutils.treq() req.headers["Host"] = "" req.host = "foobar" req.update_host_header() - assert req.headers["Host"] == "foobar" + assert req.headers["Host"] == b"foobar" def test_get_form(self): req = tutils.treq() @@ -132,7 +132,7 @@ class TestRequest(object): def test_set_path_components(self): req = tutils.treq() - req.set_path_components(["foo", "bar"]) + req.set_path_components([b"foo", b"bar"]) # TODO: add meaningful assertions def test_get_query(self): @@ -140,7 +140,7 @@ class TestRequest(object): assert req.get_query().lst == [] req.url = "http://localhost:80/foo?bar=42" - assert req.get_query().lst == [("bar", "42")] + assert req.get_query().lst == [(b"bar", b"42")] def test_set_query(self): req = tutils.treq() @@ -167,12 +167,12 @@ class TestRequest(object): def test_pretty_url(self): req = tutils.treq() req.form_out = "authority" - assert req.pretty_url(True) == "address:22" - assert req.pretty_url(False) == "address:22" + assert req.pretty_url(True) == b"address:22" + assert req.pretty_url(False) == b"address:22" req.form_out = "relative" - assert req.pretty_url(True) == "http://address:22/path" - assert req.pretty_url(False) == "http://address:22/path" + assert req.pretty_url(True) == b"http://address:22/path" + assert req.pretty_url(False) == b"http://address:22/path" def test_get_cookies_none(self): headers = Headers() @@ -213,11 +213,11 @@ class TestRequest(object): def test_set_url(self): r = tutils.treq(form_in="absolute") - r.url = "https://otheraddress:42/ORLY" - assert r.scheme == "https" - assert r.host == "otheraddress" + r.url = b"https://otheraddress:42/ORLY" + assert r.scheme == b"https" + assert r.host == b"otheraddress" assert r.port == 42 - assert r.path == "/ORLY" + assert r.path == b"/ORLY" try: r.url = "//localhost:80/foo@bar" @@ -374,8 +374,8 @@ class TestResponse(object): def test_get_cookies_twocookies(self): resp = tutils.tresp() resp.headers = Headers([ - ["Set-Cookie", "cookiename=cookievalue"], - ["Set-Cookie", "othercookie=othervalue"] + [b"Set-Cookie", b"cookiename=cookievalue"], + [b"Set-Cookie", b"othercookie=othervalue"] ]) result = resp.get_cookies() assert len(result) == 2 @@ -399,8 +399,8 @@ class TestHeaders(object): def _2host(self): return Headers( [ - ["Host", "example.com"], - ["host", "example.org"] + [b"Host", b"example.com"], + [b"host", b"example.org"] ] ) @@ -408,37 +408,37 @@ class TestHeaders(object): headers = Headers() assert len(headers) == 0 - headers = Headers([["Host", "example.com"]]) + headers = Headers([[b"Host", b"example.com"]]) assert len(headers) == 1 - assert headers["Host"] == "example.com" + assert headers["Host"] == b"example.com" headers = Headers(Host="example.com") assert len(headers) == 1 - assert headers["Host"] == "example.com" + assert headers["Host"] == b"example.com" headers = Headers( - [["Host", "invalid"]], + [[b"Host", b"invalid"]], Host="example.com" ) assert len(headers) == 1 - assert headers["Host"] == "example.com" + assert headers["Host"] == b"example.com" headers = Headers( - [["Host", "invalid"], ["Accept", "text/plain"]], + [[b"Host", b"invalid"], [b"Accept", b"text/plain"]], Host="example.com" ) assert len(headers) == 2 - assert headers["Host"] == "example.com" - assert headers["Accept"] == "text/plain" + assert headers["Host"] == b"example.com" + assert headers["Accept"] == b"text/plain" def test_getitem(self): headers = Headers(Host="example.com") - assert headers["Host"] == "example.com" - assert headers["host"] == "example.com" + assert headers["Host"] == b"example.com" + assert headers["host"] == b"example.com" tutils.raises(KeyError, headers.__getitem__, "Accept") headers = self._2host() - assert headers["Host"] == "example.com, example.org" + assert headers["Host"] == b"example.com, example.org" def test_str(self): headers = Headers(Host="example.com") @@ -458,12 +458,12 @@ class TestHeaders(object): headers["Host"] = "example.com" assert "Host" in headers assert "host" in headers - assert headers["Host"] == "example.com" + assert headers["Host"] == b"example.com" headers["host"] = "example.org" assert "Host" in headers assert "host" in headers - assert headers["Host"] == "example.org" + assert headers["Host"] == b"example.org" headers["accept"] = "text/plain" assert len(headers) == 2 @@ -494,12 +494,10 @@ class TestHeaders(object): def test_keys(self): headers = Headers(Host="example.com") - assert len(headers.keys()) == 1 - assert headers.keys()[0] == "Host" + assert list(headers.keys()) == [b"Host"] headers = self._2host() - assert len(headers.keys()) == 1 - assert headers.keys()[0] == "Host" + assert list(headers.keys()) == [b"Host"] def test_eq_ne(self): headers1 = Headers(Host="example.com") @@ -516,7 +514,7 @@ class TestHeaders(object): def test_get_all(self): headers = self._2host() - assert headers.get_all("host") == ["example.com", "example.org"] + assert headers.get_all("host") == [b"example.com", b"example.org"] assert headers.get_all("accept") == [] def test_set_all(self): @@ -527,10 +525,10 @@ class TestHeaders(object): headers = self._2host() headers.set_all("Host", ["example.org"]) - assert headers["host"] == "example.org" + assert headers["host"] == b"example.org" headers.set_all("Host", ["example.org", "example.net"]) - assert headers["host"] == "example.org, example.net" + assert headers["host"] == b"example.org, example.net" def test_state(self): headers = self._2host() diff --git a/test/test_encoding.py b/test/test_encoding.py index 90f99338..0ff1aad1 100644 --- a/test/test_encoding.py +++ b/test/test_encoding.py @@ -4,8 +4,6 @@ from netlib import encoding def test_identity(): assert b"string" == encoding.decode("identity", b"string") assert b"string" == encoding.encode("identity", b"string") - assert b"string" == encoding.encode(b"identity", b"string") - assert b"string" == encoding.decode(b"identity", b"string") assert not encoding.encode("nonexistent", b"string") assert not encoding.decode("nonexistent encoding", b"string") diff --git a/test/test_wsgi.py b/test/test_wsgi.py index 856967af..fe6f09b5 100644 --- a/test/test_wsgi.py +++ b/test/test_wsgi.py @@ -5,8 +5,8 @@ from netlib.http import Headers def tflow(): - headers = Headers(test="value") - req = wsgi.Request("http", "GET", "/", headers, "") + headers = Headers(test=b"value") + req = wsgi.Request("http", "GET", "/", "HTTP/1.1", headers, "") return wsgi.Flow(("127.0.0.1", 8888), req) @@ -20,7 +20,7 @@ class TestApp: status = '200 OK' response_headers = [('Content-type', 'text/plain')] start_response(status, response_headers) - return ['Hello', ' world!\n'] + return [b'Hello', b' world!\n'] class TestWSGI: @@ -47,8 +47,8 @@ class TestWSGI: assert not err val = wfile.getvalue() - assert "Hello world" in val - assert "Server:" in val + assert b"Hello world" in val + assert b"Server:" in val def _serve(self, app): w = wsgi.WSGIAdaptor(app, "foo", 80, "version") @@ -77,7 +77,7 @@ class TestWSGI: response_headers = [('Content-type', 'text/plain')] start_response(status, response_headers) start_response(status, response_headers) - assert "Internal Server Error" in self._serve(app) + assert b"Internal Server Error" in self._serve(app) def test_serve_single_err(self): def app(environ, start_response): @@ -88,7 +88,8 @@ class TestWSGI: status = '200 OK' response_headers = [('Content-type', 'text/plain')] start_response(status, response_headers, ei) - assert "Internal Server Error" in self._serve(app) + yield b"" + assert b"Internal Server Error" in self._serve(app) def test_serve_double_err(self): def app(environ, start_response): @@ -99,7 +100,7 @@ class TestWSGI: status = '200 OK' response_headers = [('Content-type', 'text/plain')] start_response(status, response_headers) - yield "aaa" + yield b"aaa" start_response(status, response_headers, ei) - yield "bbb" - assert "Internal Server Error" in self._serve(app) + yield b"bbb" + assert b"Internal Server Error" in self._serve(app) diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 3af5dc9c..6f67b84d 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -41,7 +41,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): key = self.protocol.check_client_handshake(req.headers) preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) - self.wfile.write(preamble + "\r\n") + self.wfile.write(preamble.encode() + b"\r\n") headers = self.protocol.server_handshake_headers(key) self.wfile.write(str(headers) + "\r\n") self.wfile.flush() @@ -62,11 +62,11 @@ class WebSocketsClient(tcp.TCPClient): def connect(self): super(WebSocketsClient, self).connect() - preamble = 'GET / HTTP/1.1' - self.wfile.write(preamble + "\r\n") + preamble = b'GET / HTTP/1.1' + self.wfile.write(preamble + b"\r\n") headers = self.protocol.client_handshake_headers() self.client_nonce = headers["sec-websocket-key"] - self.wfile.write(str(headers) + "\r\n") + self.wfile.write(bytes(headers) + b"\r\n") self.wfile.flush() resp = read_response(self.rfile, treq(method="GET")) @@ -101,7 +101,7 @@ class TestWebSockets(tservers.ServerTestBase): assert response == msg def test_simple_echo(self): - self.echo("hello I'm the client") + self.echo(b"hello I'm the client") def test_frame_sizes(self): # length can fit in the the 7 bit payload length @@ -161,10 +161,10 @@ class BadHandshakeHandler(WebSocketsEchoHandler): client_hs = read_request(self.rfile) self.protocol.check_client_handshake(client_hs.headers) - preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) - self.wfile.write(preamble + "\r\n") - headers = self.protocol.server_handshake_headers("malformed key") - self.wfile.write(str(headers) + "\r\n") + preamble = 'HTTP/1.1 101 %s\r\n' % status_codes.RESPONSES.get(101) + self.wfile.write(preamble.encode()) + headers = self.protocol.server_handshake_headers(b"malformed key") + self.wfile.write(bytes(headers) + b"\r\n") self.wfile.flush() self.handshake_done = True @@ -180,7 +180,7 @@ class TestBadHandshake(tservers.ServerTestBase): def test(self): client = WebSocketsClient(("127.0.0.1", self.port)) client.connect() - client.send_message("hello") + client.send_message(b"hello") class TestFrameHeader: @@ -188,8 +188,7 @@ class TestFrameHeader: def test_roundtrip(self): def round(*args, **kwargs): f = websockets.FrameHeader(*args, **kwargs) - bytes = f.to_bytes() - f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) + f2 = websockets.FrameHeader.from_file(tutils.treader(bytes(f))) assert f == f2 round() round(fin=1) @@ -201,11 +200,11 @@ class TestFrameHeader: round(payload_length=1000) round(payload_length=10000) round(opcode=websockets.OPCODE.PING) - round(masking_key="test") + round(masking_key=b"test") def test_human_readable(self): f = websockets.FrameHeader( - masking_key="test", + masking_key=b"test", fin=True, payload_length=10 ) @@ -214,23 +213,23 @@ class TestFrameHeader: assert f.human_readable() def test_funky(self): - f = websockets.FrameHeader(masking_key="test", mask=False) + f = websockets.FrameHeader(masking_key=b"test", mask=False) bytes = f.to_bytes() f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) assert not f2.mask def test_violations(self): tutils.raises("opcode", websockets.FrameHeader, opcode=17) - tutils.raises("masking key", websockets.FrameHeader, masking_key="x") + tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x") def test_automask(self): f = websockets.FrameHeader(mask=True) assert f.masking_key - f = websockets.FrameHeader(masking_key="foob") + f = websockets.FrameHeader(masking_key=b"foob") assert f.mask - f = websockets.FrameHeader(masking_key="foob", mask=0) + f = websockets.FrameHeader(masking_key=b"foob", mask=0) assert not f.mask assert f.masking_key @@ -240,31 +239,31 @@ class TestFrame: def test_roundtrip(self): def round(*args, **kwargs): f = websockets.Frame(*args, **kwargs) - bytes = f.to_bytes() - f2 = websockets.Frame.from_file(tutils.treader(bytes)) + raw = bytes(f) + f2 = websockets.Frame.from_file(tutils.treader(raw)) assert f == f2 - round("test") - round("test", fin=1) - round("test", rsv1=1) - round("test", opcode=websockets.OPCODE.PING) - round("test", masking_key="test") + round(b"test") + round(b"test", fin=1) + round(b"test", rsv1=1) + round(b"test", opcode=websockets.OPCODE.PING) + round(b"test", masking_key=b"test") def test_human_readable(self): f = websockets.Frame() - assert f.human_readable() + assert repr(f) def test_masker(): tests = [ - ["a"], - ["four"], - ["fourf"], - ["fourfive"], - ["a", "aasdfasdfa", "asdf"], - ["a" * 50, "aasdfasdfa", "asdf"], + [b"a"], + [b"four"], + [b"fourf"], + [b"fourfive"], + [b"a", b"aasdfasdfa", b"asdf"], + [b"a" * 50, b"aasdfasdfa", b"asdf"], ] for i in tests: - m = websockets.Masker("abcd") - data = "".join([m(t) for t in i]) - data2 = websockets.Masker("abcd")(data) - assert data2 == "".join(i) + m = websockets.Masker(b"abcd") + data = b"".join([m(t) for t in i]) + data2 = websockets.Masker(b"abcd")(data) + assert data2 == b"".join(i) -- cgit v1.2.3 From f2e3e6af6de401632eb15783b93fdda768de6ab2 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 21 Sep 2015 00:45:52 +0200 Subject: test on pypy3 --- .travis.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.travis.yml b/.travis.yml index c5634a10..7835bb64 100644 --- a/.travis.yml +++ b/.travis.yml @@ -18,6 +18,9 @@ matrix: - python: 3.5 script: - py.test -n 4 -k "not http2" . + - python: pypy3 + script: + - py.test -n 4 -k "not http2" . - python: pypy - python: pypy env: OPENSSL=1.0.2 -- cgit v1.2.3 From eaf66550b0561f57f7cec5b569017047c0427ede Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 21 Sep 2015 01:08:19 +0200 Subject: always use py.test --- .coveragerc | 10 +++++++++- .travis.yml | 3 ++- setup.py | 1 + 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/.coveragerc b/.coveragerc index 8076aebe..ccbebf8c 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,3 +1,11 @@ +[run] +branch = True + [report] -omit = *contrib* +show_missing = True include = *netlib/netlib* +exclude_lines = + pragma: nocover + pragma: no cover + raise NotImplementedError() +omit = *contrib* diff --git a/.travis.yml b/.travis.yml index 7835bb64..0f2b1431 100644 --- a/.travis.yml +++ b/.travis.yml @@ -36,6 +36,7 @@ matrix: # We allow pypy to fail until Travis fixes their infrastructure to a pypy # with a recent enought CFFI library to run cryptography 1.0+. - python: pypy + - python: pypy3 install: - "pip install --src . -r requirements.txt" @@ -44,7 +45,7 @@ before_script: - "openssl version -a" script: - - "nosetests --with-cov --cov-report term-missing" + - "py.test -n 4 --cov netlib" after_success: - coveralls diff --git a/setup.py b/setup.py index ac0d36cf..a661e6f7 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ setup( "mock>=1.0.1", "pytest>=2.8.0", "pytest-xdist>=1.13.1", + "pytest-cov>=2.1.0", "nose>=1.3.0", "nose-cov>=1.6", "coveralls>=0.4.1", -- cgit v1.2.3 From f0ff68023d428dfcfe10dc01965cbb840e0f2267 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 21 Sep 2015 01:11:42 +0200 Subject: remove nose as a dependency --- setup.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/setup.py b/setup.py index a661e6f7..30c80f5b 100644 --- a/setup.py +++ b/setup.py @@ -60,8 +60,6 @@ setup( "pytest>=2.8.0", "pytest-xdist>=1.13.1", "pytest-cov>=2.1.0", - "nose>=1.3.0", - "nose-cov>=1.6", "coveralls>=0.4.1", "autopep8>=1.0.3", "autoflake>=0.6.6", -- cgit v1.2.3 From 151942d7aefe652f0cb019cf8914d0ad8ee2cdd6 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 21 Sep 2015 01:13:59 +0200 Subject: update appveyor --- .appveyor.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.appveyor.yml b/.appveyor.yml index 4e690c06..dbb6d2fa 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -8,4 +8,4 @@ install: - "%PYTHON%\\python -c \"from OpenSSL import SSL; print(SSL.SSLeay_version(SSL.SSLEAY_VERSION))\"" build: off # Not a C# project test_script: - - "%PYTHON%\\Scripts\\nosetests" \ No newline at end of file + - "%PYTHON%\\Scripts\\py.test -n 4" \ No newline at end of file -- cgit v1.2.3 From 9dea36e43913a642c7f379a55828e1fb7745ba6b Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 21 Sep 2015 01:22:05 +0200 Subject: remove nose references --- test/http/http2/test_frames.py | 332 +++++++++++++++++-------------------- test/http/test_cookies.py | 24 ++- test/websockets/test_websockets.py | 9 +- 3 files changed, 167 insertions(+), 198 deletions(-) diff --git a/test/http/http2/test_frames.py b/test/http/http2/test_frames.py index 4c89b023..9f41c74d 100644 --- a/test/http/http2/test_frames.py +++ b/test/http/http2/test_frames.py @@ -1,5 +1,4 @@ from io import BytesIO -from nose.tools import assert_equal from netlib import tcp, tutils from netlib.http.http2.frame import * @@ -30,7 +29,7 @@ def test_frame_equality(): flags=Frame.FLAG_END_STREAM, stream_id=0x1234567, payload='foobar') - assert_equal(a, b) + assert a == b def test_too_large_frames(): @@ -48,7 +47,7 @@ def test_data_frame_to_bytes(): flags=Frame.FLAG_END_STREAM, stream_id=0x1234567, payload='foobar') - assert_equal(f.to_bytes().encode('hex'), '000006000101234567666f6f626172') + assert f.to_bytes().encode('hex') == '000006000101234567666f6f626172' f = DataFrame( length=11, @@ -56,9 +55,7 @@ def test_data_frame_to_bytes(): stream_id=0x1234567, payload='foobar', pad_length=3) - assert_equal( - f.to_bytes().encode('hex'), - '00000a00090123456703666f6f626172000000') + assert f.to_bytes().encode('hex') == '00000a00090123456703666f6f626172000000' f = DataFrame( length=6, @@ -71,19 +68,19 @@ def test_data_frame_to_bytes(): def test_data_frame_from_bytes(): f = Frame.from_file(hex_to_file('000006000101234567666f6f626172')) assert isinstance(f, DataFrame) - assert_equal(f.length, 6) - assert_equal(f.TYPE, DataFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_END_STREAM) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.payload, 'foobar') + assert f.length == 6 + assert f.TYPE == DataFrame.TYPE + assert f.flags == Frame.FLAG_END_STREAM + assert f.stream_id == 0x1234567 + assert f.payload == 'foobar' f = Frame.from_file(hex_to_file('00000a00090123456703666f6f626172000000')) assert isinstance(f, DataFrame) - assert_equal(f.length, 10) - assert_equal(f.TYPE, DataFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_END_STREAM | Frame.FLAG_PADDED) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.payload, 'foobar') + assert f.length == 10 + assert f.TYPE == DataFrame.TYPE + assert f.flags == Frame.FLAG_END_STREAM | Frame.FLAG_PADDED + assert f.stream_id == 0x1234567 + assert f.payload == 'foobar' def test_data_frame_human_readable(): @@ -102,7 +99,7 @@ def test_headers_frame_to_bytes(): flags=(Frame.FLAG_NO_FLAGS), stream_id=0x1234567, header_block_fragment='668594e75e31d9'.decode('hex')) - assert_equal(f.to_bytes().encode('hex'), '000007010001234567668594e75e31d9') + assert f.to_bytes().encode('hex') == '000007010001234567668594e75e31d9' f = HeadersFrame( length=10, @@ -110,9 +107,7 @@ def test_headers_frame_to_bytes(): stream_id=0x1234567, header_block_fragment='668594e75e31d9'.decode('hex'), pad_length=3) - assert_equal( - f.to_bytes().encode('hex'), - '00000b01080123456703668594e75e31d9000000') + assert f.to_bytes().encode('hex') == '00000b01080123456703668594e75e31d9000000' f = HeadersFrame( length=10, @@ -122,9 +117,7 @@ def test_headers_frame_to_bytes(): exclusive=True, stream_dependency=0x7654321, weight=42) - assert_equal( - f.to_bytes().encode('hex'), - '00000c012001234567876543212a668594e75e31d9') + assert f.to_bytes().encode('hex') == '00000c012001234567876543212a668594e75e31d9' f = HeadersFrame( length=14, @@ -135,9 +128,7 @@ def test_headers_frame_to_bytes(): exclusive=True, stream_dependency=0x7654321, weight=42) - assert_equal( - f.to_bytes().encode('hex'), - '00001001280123456703876543212a668594e75e31d9000000') + assert f.to_bytes().encode('hex') == '00001001280123456703876543212a668594e75e31d9000000' f = HeadersFrame( length=14, @@ -148,9 +139,7 @@ def test_headers_frame_to_bytes(): exclusive=False, stream_dependency=0x7654321, weight=42) - assert_equal( - f.to_bytes().encode('hex'), - '00001001280123456703076543212a668594e75e31d9000000') + assert f.to_bytes().encode('hex') == '00001001280123456703076543212a668594e75e31d9000000' f = HeadersFrame( length=6, @@ -164,56 +153,56 @@ def test_headers_frame_from_bytes(): f = Frame.from_file(hex_to_file( '000007010001234567668594e75e31d9')) assert isinstance(f, HeadersFrame) - assert_equal(f.length, 7) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + assert f.length == 7 + assert f.TYPE == HeadersFrame.TYPE + assert f.flags == Frame.FLAG_NO_FLAGS + assert f.stream_id == 0x1234567 + assert f.header_block_fragment == '668594e75e31d9'.decode('hex') f = Frame.from_file(hex_to_file( '00000b01080123456703668594e75e31d9000000')) assert isinstance(f, HeadersFrame) - assert_equal(f.length, 11) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, HeadersFrame.FLAG_PADDED) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + assert f.length == 11 + assert f.TYPE == HeadersFrame.TYPE + assert f.flags == HeadersFrame.FLAG_PADDED + assert f.stream_id == 0x1234567 + assert f.header_block_fragment == '668594e75e31d9'.decode('hex') f = Frame.from_file(hex_to_file( '00000c012001234567876543212a668594e75e31d9')) assert isinstance(f, HeadersFrame) - assert_equal(f.length, 12) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, HeadersFrame.FLAG_PRIORITY) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - assert_equal(f.exclusive, True) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 42) + assert f.length == 12 + assert f.TYPE == HeadersFrame.TYPE + assert f.flags == HeadersFrame.FLAG_PRIORITY + assert f.stream_id == 0x1234567 + assert f.header_block_fragment == '668594e75e31d9'.decode('hex') + assert f.exclusive == True + assert f.stream_dependency == 0x7654321 + assert f.weight == 42 f = Frame.from_file(hex_to_file( '00001001280123456703876543212a668594e75e31d9000000')) assert isinstance(f, HeadersFrame) - assert_equal(f.length, 16) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - assert_equal(f.exclusive, True) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 42) + assert f.length == 16 + assert f.TYPE == HeadersFrame.TYPE + assert f.flags == HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY + assert f.stream_id == 0x1234567 + assert f.header_block_fragment == '668594e75e31d9'.decode('hex') + assert f.exclusive == True + assert f.stream_dependency == 0x7654321 + assert f.weight == 42 f = Frame.from_file(hex_to_file( '00001001280123456703076543212a668594e75e31d9000000')) assert isinstance(f, HeadersFrame) - assert_equal(f.length, 16) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - assert_equal(f.exclusive, False) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 42) + assert f.length == 16 + assert f.TYPE == HeadersFrame.TYPE + assert f.flags == HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY + assert f.stream_id == 0x1234567 + assert f.header_block_fragment == '668594e75e31d9'.decode('hex') + assert f.exclusive == False + assert f.stream_dependency == 0x7654321 + assert f.weight == 42 def test_headers_frame_human_readable(): @@ -248,7 +237,7 @@ def test_priority_frame_to_bytes(): exclusive=True, stream_dependency=0x0, weight=42) - assert_equal(f.to_bytes().encode('hex'), '000005020001234567800000002a') + assert f.to_bytes().encode('hex') == '000005020001234567800000002a' f = PriorityFrame( length=5, @@ -257,7 +246,7 @@ def test_priority_frame_to_bytes(): exclusive=False, stream_dependency=0x7654321, weight=21) - assert_equal(f.to_bytes().encode('hex'), '0000050200012345670765432115') + assert f.to_bytes().encode('hex') == '0000050200012345670765432115' f = PriorityFrame( length=5, @@ -270,23 +259,23 @@ def test_priority_frame_to_bytes(): def test_priority_frame_from_bytes(): f = Frame.from_file(hex_to_file('000005020001234567876543212a')) assert isinstance(f, PriorityFrame) - assert_equal(f.length, 5) - assert_equal(f.TYPE, PriorityFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.exclusive, True) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 42) + assert f.length == 5 + assert f.TYPE == PriorityFrame.TYPE + assert f.flags == Frame.FLAG_NO_FLAGS + assert f.stream_id == 0x1234567 + assert f.exclusive == True + assert f.stream_dependency == 0x7654321 + assert f.weight == 42 f = Frame.from_file(hex_to_file('0000050200012345670765432115')) assert isinstance(f, PriorityFrame) - assert_equal(f.length, 5) - assert_equal(f.TYPE, PriorityFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.exclusive, False) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 21) + assert f.length == 5 + assert f.TYPE == PriorityFrame.TYPE + assert f.flags == Frame.FLAG_NO_FLAGS + assert f.stream_id == 0x1234567 + assert f.exclusive == False + assert f.stream_dependency == 0x7654321 + assert f.weight == 21 def test_priority_frame_human_readable(): @@ -306,7 +295,7 @@ def test_rst_stream_frame_to_bytes(): flags=Frame.FLAG_NO_FLAGS, stream_id=0x1234567, error_code=0x7654321) - assert_equal(f.to_bytes().encode('hex'), '00000403000123456707654321') + assert f.to_bytes().encode('hex') == '00000403000123456707654321' f = RstStreamFrame( length=4, @@ -318,11 +307,11 @@ def test_rst_stream_frame_to_bytes(): def test_rst_stream_frame_from_bytes(): f = Frame.from_file(hex_to_file('00000403000123456707654321')) assert isinstance(f, RstStreamFrame) - assert_equal(f.length, 4) - assert_equal(f.TYPE, RstStreamFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.error_code, 0x07654321) + assert f.length == 4 + assert f.TYPE == RstStreamFrame.TYPE + assert f.flags == Frame.FLAG_NO_FLAGS + assert f.stream_id == 0x1234567 + assert f.error_code == 0x07654321 def test_rst_stream_frame_human_readable(): @@ -339,13 +328,13 @@ def test_settings_frame_to_bytes(): length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0) - assert_equal(f.to_bytes().encode('hex'), '000000040000000000') + assert f.to_bytes().encode('hex') == '000000040000000000' f = SettingsFrame( length=0, flags=SettingsFrame.FLAG_ACK, stream_id=0x0) - assert_equal(f.to_bytes().encode('hex'), '000000040100000000') + assert f.to_bytes().encode('hex') == '000000040100000000' f = SettingsFrame( length=6, @@ -353,7 +342,7 @@ def test_settings_frame_to_bytes(): stream_id=0x0, settings={ SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1}) - assert_equal(f.to_bytes().encode('hex'), '000006040100000000000200000001') + assert f.to_bytes().encode('hex') == '000006040100000000000200000001' f = SettingsFrame( length=12, @@ -362,9 +351,7 @@ def test_settings_frame_to_bytes(): settings={ SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) - assert_equal( - f.to_bytes().encode('hex'), - '00000c040000000000000200000001000312345678') + assert f.to_bytes().encode('hex') == '00000c040000000000000200000001000312345678' f = SettingsFrame( length=0, @@ -376,40 +363,37 @@ def test_settings_frame_to_bytes(): def test_settings_frame_from_bytes(): f = Frame.from_file(hex_to_file('000000040000000000')) assert isinstance(f, SettingsFrame) - assert_equal(f.length, 0) - assert_equal(f.TYPE, SettingsFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) + assert f.length == 0 + assert f.TYPE == SettingsFrame.TYPE + assert f.flags == Frame.FLAG_NO_FLAGS + assert f.stream_id == 0x0 f = Frame.from_file(hex_to_file('000000040100000000')) assert isinstance(f, SettingsFrame) - assert_equal(f.length, 0) - assert_equal(f.TYPE, SettingsFrame.TYPE) - assert_equal(f.flags, SettingsFrame.FLAG_ACK) - assert_equal(f.stream_id, 0x0) + assert f.length == 0 + assert f.TYPE == SettingsFrame.TYPE + assert f.flags == SettingsFrame.FLAG_ACK + assert f.stream_id == 0x0 f = Frame.from_file(hex_to_file('000006040100000000000200000001')) assert isinstance(f, SettingsFrame) - assert_equal(f.length, 6) - assert_equal(f.TYPE, SettingsFrame.TYPE) - assert_equal(f.flags, SettingsFrame.FLAG_ACK, 0x0) - assert_equal(f.stream_id, 0x0) - assert_equal(len(f.settings), 1) - assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) + assert f.length == 6 + assert f.TYPE == SettingsFrame.TYPE + assert f.flags == SettingsFrame.FLAG_ACK, 0x0 + assert f.stream_id == 0x0 + assert len(f.settings) == 1 + assert f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH] == 1 f = Frame.from_file(hex_to_file( '00000c040000000000000200000001000312345678')) assert isinstance(f, SettingsFrame) - assert_equal(f.length, 12) - assert_equal(f.TYPE, SettingsFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(len(f.settings), 2) - assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) - assert_equal( - f.settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS], - 0x12345678) + assert f.length == 12 + assert f.TYPE == SettingsFrame.TYPE + assert f.flags == Frame.FLAG_NO_FLAGS + assert f.stream_id == 0x0 + assert len(f.settings) == 2 + assert f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH] == 1 + assert f.settings[SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS] == 0x12345678 def test_settings_frame_human_readable(): @@ -437,9 +421,7 @@ def test_push_promise_frame_to_bytes(): stream_id=0x1234567, promised_stream=0x7654321, header_block_fragment='foobar') - assert_equal( - f.to_bytes().encode('hex'), - '00000a05000123456707654321666f6f626172') + assert f.to_bytes().encode('hex') == '00000a05000123456707654321666f6f626172' f = PushPromiseFrame( length=14, @@ -448,9 +430,7 @@ def test_push_promise_frame_to_bytes(): promised_stream=0x7654321, header_block_fragment='foobar', pad_length=3) - assert_equal( - f.to_bytes().encode('hex'), - '00000e0508012345670307654321666f6f626172000000') + assert f.to_bytes().encode('hex') == '00000e0508012345670307654321666f6f626172000000' f = PushPromiseFrame( length=4, @@ -470,20 +450,20 @@ def test_push_promise_frame_to_bytes(): def test_push_promise_frame_from_bytes(): f = Frame.from_file(hex_to_file('00000a05000123456707654321666f6f626172')) assert isinstance(f, PushPromiseFrame) - assert_equal(f.length, 10) - assert_equal(f.TYPE, PushPromiseFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, 'foobar') + assert f.length == 10 + assert f.TYPE == PushPromiseFrame.TYPE + assert f.flags == Frame.FLAG_NO_FLAGS + assert f.stream_id == 0x1234567 + assert f.header_block_fragment == 'foobar' f = Frame.from_file(hex_to_file( '00000e0508012345670307654321666f6f626172000000')) assert isinstance(f, PushPromiseFrame) - assert_equal(f.length, 14) - assert_equal(f.TYPE, PushPromiseFrame.TYPE) - assert_equal(f.flags, PushPromiseFrame.FLAG_PADDED) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, 'foobar') + assert f.length == 14 + assert f.TYPE == PushPromiseFrame.TYPE + assert f.flags == PushPromiseFrame.FLAG_PADDED + assert f.stream_id == 0x1234567 + assert f.header_block_fragment == 'foobar' def test_push_promise_frame_human_readable(): @@ -503,18 +483,14 @@ def test_ping_frame_to_bytes(): flags=PingFrame.FLAG_ACK, stream_id=0x0, payload=b'foobar') - assert_equal( - f.to_bytes().encode('hex'), - '000008060100000000666f6f6261720000') + assert f.to_bytes().encode('hex') == '000008060100000000666f6f6261720000' f = PingFrame( length=8, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, payload=b'foobardeadbeef') - assert_equal( - f.to_bytes().encode('hex'), - '000008060000000000666f6f6261726465') + assert f.to_bytes().encode('hex') == '000008060000000000666f6f6261726465' f = PingFrame( length=8, @@ -526,19 +502,19 @@ def test_ping_frame_to_bytes(): def test_ping_frame_from_bytes(): f = Frame.from_file(hex_to_file('000008060100000000666f6f6261720000')) assert isinstance(f, PingFrame) - assert_equal(f.length, 8) - assert_equal(f.TYPE, PingFrame.TYPE) - assert_equal(f.flags, PingFrame.FLAG_ACK) - assert_equal(f.stream_id, 0x0) - assert_equal(f.payload, b'foobar\0\0') + assert f.length == 8 + assert f.TYPE == PingFrame.TYPE + assert f.flags == PingFrame.FLAG_ACK + assert f.stream_id == 0x0 + assert f.payload == b'foobar\0\0' f = Frame.from_file(hex_to_file('000008060000000000666f6f6261726465')) assert isinstance(f, PingFrame) - assert_equal(f.length, 8) - assert_equal(f.TYPE, PingFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(f.payload, b'foobarde') + assert f.length == 8 + assert f.TYPE == PingFrame.TYPE + assert f.flags == Frame.FLAG_NO_FLAGS + assert f.stream_id == 0x0 + assert f.payload == b'foobarde' def test_ping_frame_human_readable(): @@ -558,9 +534,7 @@ def test_goaway_frame_to_bytes(): last_stream=0x1234567, error_code=0x87654321, data=b'') - assert_equal( - f.to_bytes().encode('hex'), - '0000080700000000000123456787654321') + assert f.to_bytes().encode('hex') == '0000080700000000000123456787654321' f = GoAwayFrame( length=14, @@ -569,9 +543,7 @@ def test_goaway_frame_to_bytes(): last_stream=0x1234567, error_code=0x87654321, data=b'foobar') - assert_equal( - f.to_bytes().encode('hex'), - '00000e0700000000000123456787654321666f6f626172') + assert f.to_bytes().encode('hex') == '00000e0700000000000123456787654321666f6f626172' f = GoAwayFrame( length=8, @@ -586,24 +558,24 @@ def test_goaway_frame_from_bytes(): f = Frame.from_file(hex_to_file( '0000080700000000000123456787654321')) assert isinstance(f, GoAwayFrame) - assert_equal(f.length, 8) - assert_equal(f.TYPE, GoAwayFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(f.last_stream, 0x1234567) - assert_equal(f.error_code, 0x87654321) - assert_equal(f.data, b'') + assert f.length == 8 + assert f.TYPE == GoAwayFrame.TYPE + assert f.flags == Frame.FLAG_NO_FLAGS + assert f.stream_id == 0x0 + assert f.last_stream == 0x1234567 + assert f.error_code == 0x87654321 + assert f.data == b'' f = Frame.from_file(hex_to_file( '00000e0700000000000123456787654321666f6f626172')) assert isinstance(f, GoAwayFrame) - assert_equal(f.length, 14) - assert_equal(f.TYPE, GoAwayFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(f.last_stream, 0x1234567) - assert_equal(f.error_code, 0x87654321) - assert_equal(f.data, b'foobar') + assert f.length == 14 + assert f.TYPE == GoAwayFrame.TYPE + assert f.flags == Frame.FLAG_NO_FLAGS + assert f.stream_id == 0x0 + assert f.last_stream == 0x1234567 + assert f.error_code == 0x87654321 + assert f.data == b'foobar' def test_go_away_frame_human_readable(): @@ -623,14 +595,14 @@ def test_window_update_frame_to_bytes(): flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, window_size_increment=0x1234567) - assert_equal(f.to_bytes().encode('hex'), '00000408000000000001234567') + assert f.to_bytes().encode('hex') == '00000408000000000001234567' f = WindowUpdateFrame( length=4, flags=Frame.FLAG_NO_FLAGS, stream_id=0x1234567, window_size_increment=0x7654321) - assert_equal(f.to_bytes().encode('hex'), '00000408000123456707654321') + assert f.to_bytes().encode('hex') == '00000408000123456707654321' f = WindowUpdateFrame( length=4, @@ -646,11 +618,11 @@ def test_window_update_frame_to_bytes(): def test_window_update_frame_from_bytes(): f = Frame.from_file(hex_to_file('00000408000000000001234567')) assert isinstance(f, WindowUpdateFrame) - assert_equal(f.length, 4) - assert_equal(f.TYPE, WindowUpdateFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(f.window_size_increment, 0x1234567) + assert f.length == 4 + assert f.TYPE == WindowUpdateFrame.TYPE + assert f.flags == Frame.FLAG_NO_FLAGS + assert f.stream_id == 0x0 + assert f.window_size_increment == 0x1234567 def test_window_update_frame_human_readable(): @@ -668,7 +640,7 @@ def test_continuation_frame_to_bytes(): flags=ContinuationFrame.FLAG_END_HEADERS, stream_id=0x1234567, header_block_fragment='foobar') - assert_equal(f.to_bytes().encode('hex'), '000006090401234567666f6f626172') + assert f.to_bytes().encode('hex') == '000006090401234567666f6f626172' f = ContinuationFrame( length=6, @@ -681,11 +653,11 @@ def test_continuation_frame_to_bytes(): def test_continuation_frame_from_bytes(): f = Frame.from_file(hex_to_file('000006090401234567666f6f626172')) assert isinstance(f, ContinuationFrame) - assert_equal(f.length, 6) - assert_equal(f.TYPE, ContinuationFrame.TYPE) - assert_equal(f.flags, ContinuationFrame.FLAG_END_HEADERS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, 'foobar') + assert f.length == 6 + assert f.TYPE == ContinuationFrame.TYPE + assert f.flags == ContinuationFrame.FLAG_END_HEADERS + assert f.stream_id == 0x1234567 + assert f.header_block_fragment == 'foobar' def test_continuation_frame_human_readable(): diff --git a/test/http/test_cookies.py b/test/http/test_cookies.py index 4f99593a..413b6241 100644 --- a/test/http/test_cookies.py +++ b/test/http/test_cookies.py @@ -1,5 +1,3 @@ -import nose.tools - from netlib.http import cookies @@ -13,7 +11,7 @@ def test_read_token(): [(" foo=bar", 1), ("foo", 4)], ] for q, a in tokens: - nose.tools.eq_(cookies._read_token(*q), a) + assert cookies._read_token(*q) == a def test_read_quoted_string(): @@ -25,7 +23,7 @@ def test_read_quoted_string(): [('"fo\\\"" x', 0), ("fo\"", 6)], ] for q, a in tokens: - nose.tools.eq_(cookies._read_quoted_string(*q), a) + assert cookies._read_quoted_string(*q) == a def test_read_pairs(): @@ -61,7 +59,7 @@ def test_read_pairs(): ] for s, lst in vals: ret, off = cookies._read_pairs(s) - nose.tools.eq_(ret, lst) + assert ret == lst def test_pairs_roundtrips(): @@ -109,10 +107,10 @@ def test_pairs_roundtrips(): ] for s, lst in pairs: ret, off = cookies._read_pairs(s) - nose.tools.eq_(ret, lst) + assert ret == lst s2 = cookies._format_pairs(lst) ret, off = cookies._read_pairs(s2) - nose.tools.eq_(ret, lst) + assert ret == lst def test_cookie_roundtrips(): @@ -128,10 +126,10 @@ def test_cookie_roundtrips(): ] for s, lst in pairs: ret = cookies.parse_cookie_header(s) - nose.tools.eq_(ret.lst, lst) + assert ret.lst == lst s2 = cookies.format_cookie_header(ret) ret = cookies.parse_cookie_header(s2) - nose.tools.eq_(ret.lst, lst) + assert ret.lst == lst def test_parse_set_cookie_pairs(): @@ -181,10 +179,10 @@ def test_parse_set_cookie_pairs(): ] for s, lst in pairs: ret = cookies._parse_set_cookie_pairs(s) - nose.tools.eq_(ret, lst) + assert ret == lst s2 = cookies._format_set_cookie_pairs(ret) ret2 = cookies._parse_set_cookie_pairs(s2) - nose.tools.eq_(ret2, lst) + assert ret2 == lst def test_parse_set_cookie_header(): @@ -209,11 +207,11 @@ def test_parse_set_cookie_header(): if expected: assert ret[0] == expected[0] assert ret[1] == expected[1] - nose.tools.eq_(ret[2].lst, expected[2]) + assert ret[2].lst == expected[2] s2 = cookies.format_set_cookie_header(*ret) ret2 = cookies.parse_set_cookie_header(s2) assert ret2[0] == expected[0] assert ret2[1] == expected[1] - nose.tools.eq_(ret2[2].lst, expected[2]) + assert ret2[2].lst == expected[2] else: assert ret is None diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 6f67b84d..48acc2d6 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -1,6 +1,5 @@ import os -from nose.tools import raises from netlib.http.http1 import read_response, read_request from netlib import tcp, tutils, websockets, http @@ -176,11 +175,11 @@ class TestBadHandshake(tservers.ServerTestBase): """ handler = BadHandshakeHandler - @raises(TcpDisconnect) def test(self): - client = WebSocketsClient(("127.0.0.1", self.port)) - client.connect() - client.send_message(b"hello") + with tutils.raises(TcpDisconnect): + client = WebSocketsClient(("127.0.0.1", self.port)) + client.connect() + client.send_message(b"hello") class TestFrameHeader: -- cgit v1.2.3 From 1ff8f294b459e03e113acb417678a6fd782c2685 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 21 Sep 2015 18:34:43 +0200 Subject: minor encoding fixes --- netlib/utils.py | 6 +++--- netlib/wsgi.py | 18 +++++++++--------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/netlib/utils.py b/netlib/utils.py index 8d11bd5b..b9848038 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -26,7 +26,7 @@ def always_byte_args(*encode_args): return decorator -def native(s, encoding="latin-1"): +def native(s, *encoding_opts): """ Convert :py:class:`bytes` or :py:class:`unicode` to the native :py:class:`str` type, using latin1 encoding if conversion is necessary. @@ -37,10 +37,10 @@ def native(s, encoding="latin-1"): raise TypeError("%r is neither bytes nor unicode" % s) if six.PY3: if isinstance(s, six.binary_type): - return s.decode(encoding) + return s.decode(*encoding_opts) else: if isinstance(s, six.text_type): - return s.encode(encoding) + return s.encode(*encoding_opts) return s diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 8fb09008..4fcd5178 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -55,38 +55,38 @@ class WSGIAdaptor(object): self.app, self.domain, self.port, self.sversion = app, domain, port, sversion def make_environ(self, flow, errsoc, **extra): - path = native(flow.request.path) + path = native(flow.request.path, "latin-1") if '?' in path: - path_info, query = native(path).split('?', 1) + path_info, query = native(path, "latin-1").split('?', 1) else: path_info = path query = '' environ = { 'wsgi.version': (1, 0), - 'wsgi.url_scheme': native(flow.request.scheme), + 'wsgi.url_scheme': native(flow.request.scheme, "latin-1"), 'wsgi.input': BytesIO(flow.request.body or b""), 'wsgi.errors': errsoc, 'wsgi.multithread': True, 'wsgi.multiprocess': False, 'wsgi.run_once': False, 'SERVER_SOFTWARE': self.sversion, - 'REQUEST_METHOD': native(flow.request.method), + 'REQUEST_METHOD': native(flow.request.method, "latin-1"), 'SCRIPT_NAME': '', 'PATH_INFO': urllib.parse.unquote(path_info), 'QUERY_STRING': query, - 'CONTENT_TYPE': native(flow.request.headers.get('Content-Type', '')), - 'CONTENT_LENGTH': native(flow.request.headers.get('Content-Length', '')), + 'CONTENT_TYPE': native(flow.request.headers.get('Content-Type', ''), "latin-1"), + 'CONTENT_LENGTH': native(flow.request.headers.get('Content-Length', ''), "latin-1"), 'SERVER_NAME': self.domain, 'SERVER_PORT': str(self.port), - 'SERVER_PROTOCOL': native(flow.request.http_version), + 'SERVER_PROTOCOL': native(flow.request.http_version, "latin-1"), } environ.update(extra) if flow.client_conn.address: - environ["REMOTE_ADDR"] = native(flow.client_conn.address.host) + environ["REMOTE_ADDR"] = native(flow.client_conn.address.host, "latin-1") environ["REMOTE_PORT"] = flow.client_conn.address.port for key, value in flow.request.headers.items(): - key = 'HTTP_' + native(key).upper().replace('-', '_') + key = 'HTTP_' + native(key, "latin-1").upper().replace('-', '_') if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'): environ[key] = value return environ -- cgit v1.2.3 From e9fe45f3f404bb1c762dfb13477072c06d4b74ec Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 21 Sep 2015 18:38:50 +0200 Subject: backport changes --- netlib/http/models.py | 36 ++++++++++++++++++------------------ netlib/tcp.py | 1 + 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/netlib/http/models.py b/netlib/http/models.py index 3c360a37..512a764d 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -216,7 +216,7 @@ class Message(object): def body(self, body): self._body = body if isinstance(body, bytes): - self.headers[b"Content-Length"] = str(len(body)).encode() + self.headers[b"content-length"] = str(len(body)).encode() content = body @@ -268,8 +268,8 @@ class Request(Message): response. That is, we remove ETags and If-Modified-Since headers. """ delheaders = [ - b"If-Modified-Since", - b"If-None-Match", + b"if-modified-since", + b"if-none-match", ] for i in delheaders: self.headers.pop(i, None) @@ -279,16 +279,16 @@ class Request(Message): Modifies this request to remove headers that will compress the resource's data. """ - self.headers["Accept-Encoding"] = "identity" + self.headers["accept-encoding"] = b"identity" def constrain_encoding(self): """ Limits the permissible Accept-Encoding values, based on what we can decode appropriately. """ - accept_encoding = native(self.headers.get("Accept-Encoding"), "ascii") + accept_encoding = native(self.headers.get("accept-encoding"), "ascii") if accept_encoding: - self.headers["Accept-Encoding"] = ( + self.headers["accept-encoding"] = ( ', '.join( e for e in encoding.ENCODINGS @@ -300,7 +300,7 @@ class Request(Message): """ Update the host header to reflect the current target. """ - self.headers["Host"] = self.host + self.headers["host"] = self.host def get_form(self): """ @@ -309,9 +309,9 @@ class Request(Message): indicates non-form data. """ if self.body: - if HDR_FORM_URLENCODED in self.headers.get("Content-Type", b"").lower(): + if HDR_FORM_URLENCODED in self.headers.get("content-type", b"").lower(): return self.get_form_urlencoded() - elif HDR_FORM_MULTIPART in self.headers.get("Content-Type", b"").lower(): + elif HDR_FORM_MULTIPART in self.headers.get("content-type", b"").lower(): return self.get_form_multipart() return ODict([]) @@ -321,12 +321,12 @@ class Request(Message): Returns an empty ODict if there is no data or the content-type indicates non-form data. """ - if self.body and HDR_FORM_URLENCODED in self.headers.get("Content-Type", b"").lower(): + if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type", b"").lower(): return ODict(utils.urldecode(self.body)) return ODict([]) def get_form_multipart(self): - if self.body and HDR_FORM_MULTIPART in self.headers.get("Content-Type", b"").lower(): + if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type", b"").lower(): return ODict( utils.multipartdecode( self.headers, @@ -341,7 +341,7 @@ class Request(Message): """ # FIXME: If there's an existing content-type header indicating a # url-encoded form, leave it alone. - self.headers[b"Content-Type"] = HDR_FORM_URLENCODED + self.headers[b"content-type"] = HDR_FORM_URLENCODED self.body = utils.urlencode(odict.lst) def get_path_components(self): @@ -398,9 +398,9 @@ class Request(Message): but not the resolved name. This is disabled by default, as an attacker may spoof the host header to confuse an analyst. """ - if hostheader and "Host" in self.headers: + if hostheader and "host" in self.headers: try: - return self.headers["Host"].decode("idna") + return self.headers["host"].decode("idna") except ValueError: pass if self.host: @@ -429,7 +429,7 @@ class Request(Message): headers. """ v = cookies.format_cookie_header(odict) - self.headers["Cookie"] = v + self.headers["cookie"] = v @property def url(self): @@ -485,7 +485,7 @@ class Response(Message): return "".format( status_code=self.status_code, msg=self.msg, - contenttype=self.headers.get("Content-Type", "unknown content type"), + contenttype=self.headers.get("content-type", "unknown content type"), size=size) def get_cookies(self): @@ -498,7 +498,7 @@ class Response(Message): attributes (e.g. HTTPOnly) are indicated by a Null value. """ ret = [] - for header in self.headers.get_all("Set-Cookie"): + for header in self.headers.get_all("set-cookie"): v = cookies.parse_set_cookie_header(native(header, "ascii")) if v: name, value, attrs = v @@ -521,4 +521,4 @@ class Response(Message): i[1][1] ) ) - self.headers.set_all("Set-Cookie", values) + self.headers.set_all("set-cookie", values) diff --git a/netlib/tcp.py b/netlib/tcp.py index 40ffbd48..b751d71f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -279,6 +279,7 @@ class Reader(_FileLike): if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15): return self.o.recv(length, socket.MSG_PEEK) else: + # TODO: remove once a new version is released # Polyfill for pyOpenSSL <= 0.15.1 # Taken from https://github.com/pyca/pyopenssl/commit/1d95dea7fea03c7c0df345a5ea30c12d8a0378d2 buf = SSL._ffi.new("char[]", length) -- cgit v1.2.3 From 9fbeac50ce3f6ae49b0f0270c508b6e81a1eaf17 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 21 Sep 2015 22:49:39 +0200 Subject: revert websocket changes from 73586b1b The DEFAULT construct is very weird, but with None we apparently break pathod in some difficult-to-debug ways. Revisit once we do more here. --- netlib/websockets/frame.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index 55eeaf41..fce2c9d3 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -14,6 +14,8 @@ from netlib import utils MAX_16_BIT_INT = (1 << 16) MAX_64_BIT_INT = (1 << 64) +DEFAULT=object() + OPCODE = utils.BiDi( CONTINUE=0x00, TEXT=0x01, @@ -34,9 +36,9 @@ class FrameHeader(object): rsv1=False, rsv2=False, rsv3=False, - masking_key=None, - mask=None, - length_code=None + masking_key=DEFAULT, + mask=DEFAULT, + length_code=DEFAULT ): if not 0 <= opcode < 2 ** 4: raise ValueError("opcode must be 0-16") @@ -47,18 +49,18 @@ class FrameHeader(object): self.rsv2 = rsv2 self.rsv3 = rsv3 - if length_code is None: + if length_code is DEFAULT: self.length_code = self._make_length_code(self.payload_length) else: self.length_code = length_code - if mask is None and masking_key is None: + if mask is DEFAULT and masking_key is DEFAULT: self.mask = False self.masking_key = b"" - elif mask is None: + elif mask is DEFAULT: self.mask = 1 self.masking_key = masking_key - elif masking_key is None: + elif masking_key is DEFAULT: self.mask = mask self.masking_key = os.urandom(4) else: @@ -166,7 +168,7 @@ class FrameHeader(object): if mask_bit == 1: masking_key = fp.safe_read(4) else: - masking_key = False + masking_key = None return cls( fin=fin, @@ -230,7 +232,7 @@ class Frame(object): masking_key = os.urandom(4) else: mask_bit = 0 - masking_key = False + masking_key = None return cls( message, @@ -311,4 +313,4 @@ class Frame(object): def __eq__(self, other): if isinstance(other, Frame): return bytes(self) == bytes(other) - return False \ No newline at end of file + return False -- cgit v1.2.3 From f93752277395d201fabefed8fae6d412f13da699 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 22 Sep 2015 01:48:35 +0200 Subject: Headers: return str on all Python versions --- netlib/http/__init__.py | 6 +- netlib/http/authentication.py | 10 +- netlib/http/headers.py | 205 +++++++++++++++++++++++++++++++++++ netlib/http/http1/assemble.py | 6 +- netlib/http/http1/read.py | 14 +-- netlib/http/models.py | 215 ++++--------------------------------- netlib/utils.py | 17 +-- netlib/websockets/protocol.py | 14 ++- test/http/http1/test_assemble.py | 6 +- test/http/http1/test_read.py | 22 ++-- test/http/test_authentication.py | 12 +-- test/http/test_headers.py | 149 +++++++++++++++++++++++++ test/http/test_models.py | 152 +------------------------- test/test_utils.py | 20 ++-- test/websockets/test_websockets.py | 13 ++- 15 files changed, 443 insertions(+), 418 deletions(-) create mode 100644 netlib/http/headers.py create mode 100644 test/http/test_headers.py diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index d72884b3..0ccf6b32 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,11 +1,13 @@ from __future__ import absolute_import, print_function, division -from .models import Request, Response, Headers +from .headers import Headers +from .models import Request, Response from .models import ALPN_PROTO_HTTP1, ALPN_PROTO_H2 from .models import HDR_FORM_MULTIPART, HDR_FORM_URLENCODED, CONTENT_MISSING from . import http1, http2 __all__ = [ - "Request", "Response", "Headers", + "Headers", + "Request", "Response", "ALPN_PROTO_HTTP1", "ALPN_PROTO_H2", "HDR_FORM_MULTIPART", "HDR_FORM_URLENCODED", "CONTENT_MISSING", "http1", "http2", diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index 5831660b..d769abe5 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -9,18 +9,18 @@ def parse_http_basic_auth(s): return None scheme = words[0] try: - user = binascii.a2b_base64(words[1]) + user = binascii.a2b_base64(words[1]).decode("utf8", "replace") except binascii.Error: return None - parts = user.split(b':') + parts = user.split(':') if len(parts) != 2: return None return scheme, parts[0], parts[1] def assemble_http_basic_auth(scheme, username, password): - v = binascii.b2a_base64(username + b":" + password) - return scheme + b" " + v + v = binascii.b2a_base64((username + ":" + password).encode("utf8")).decode("ascii") + return scheme + " " + v class NullProxyAuth(object): @@ -69,7 +69,7 @@ class BasicProxyAuth(NullProxyAuth): if not parts: return False scheme, username, password = parts - if scheme.lower() != b'basic': + if scheme.lower() != 'basic': return False if not self.password_manager.test(username, password): return False diff --git a/netlib/http/headers.py b/netlib/http/headers.py new file mode 100644 index 00000000..1511ea2d --- /dev/null +++ b/netlib/http/headers.py @@ -0,0 +1,205 @@ +""" + +Unicode Handling +---------------- +See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ +""" +from __future__ import absolute_import, print_function, division +import copy +try: + from collections.abc import MutableMapping +except ImportError: # Workaround for Python < 3.3 + from collections import MutableMapping + + +import six + +from netlib.utils import always_byte_args + +if six.PY2: + _native = lambda x: x + _asbytes = lambda x: x + _always_byte_args = lambda x: x +else: + # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. + _native = lambda x: x.decode("utf-8", "surrogateescape") + _asbytes = lambda x: x.encode("utf-8", "surrogateescape") + _always_byte_args = always_byte_args("utf-8", "surrogateescape") + + +class Headers(MutableMapping, object): + """ + Header class which allows both convenient access to individual headers as well as + direct access to the underlying raw data. Provides a full dictionary interface. + + Example: + + .. code-block:: python + + # Create header from a list of (header_name, header_value) tuples + >>> h = Headers([ + ["Host","example.com"], + ["Accept","text/html"], + ["accept","application/xml"] + ]) + + # Headers mostly behave like a normal dict. + >>> h["Host"] + "example.com" + + # HTTP Headers are case insensitive + >>> h["host"] + "example.com" + + # Multiple headers are folded into a single header as per RFC7230 + >>> h["Accept"] + "text/html, application/xml" + + # Setting a header removes all existing headers with the same name. + >>> h["Accept"] = "application/text" + >>> h["Accept"] + "application/text" + + # str(h) returns a HTTP1 header block. + >>> print(h) + Host: example.com + Accept: application/text + + # For full control, the raw header fields can be accessed + >>> h.fields + + # Headers can also be crated from keyword arguments + >>> h = Headers(host="example.com", content_type="application/xml") + + Caveats: + For use with the "Set-Cookie" header, see :py:meth:`get_all`. + """ + + @_always_byte_args + def __init__(self, fields=None, **headers): + """ + Args: + fields: (optional) list of ``(name, value)`` header tuples, + e.g. ``[("Host","example.com")]``. All names and values must be bytes. + **headers: Additional headers to set. Will overwrite existing values from `fields`. + For convenience, underscores in header names will be transformed to dashes - + this behaviour does not extend to other methods. + If ``**headers`` contains multiple keys that have equal ``.lower()`` s, + the behavior is undefined. + """ + self.fields = fields or [] + + for name, value in self.fields: + if not isinstance(name, bytes) or not isinstance(value, bytes): + raise ValueError("Headers passed as fields must be bytes.") + + # content_type -> content-type + headers = { + _asbytes(name).replace(b"_", b"-"): value + for name, value in six.iteritems(headers) + } + self.update(headers) + + def __bytes__(self): + if self.fields: + return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" + else: + return b"" + + if six.PY2: + __str__ = __bytes__ + + @_always_byte_args + def __getitem__(self, name): + values = self.get_all(name) + if not values: + raise KeyError(name) + return ", ".join(values) + + @_always_byte_args + def __setitem__(self, name, value): + idx = self._index(name) + + # To please the human eye, we insert at the same position the first existing header occured. + if idx is not None: + del self[name] + self.fields.insert(idx, [name, value]) + else: + self.fields.append([name, value]) + + @_always_byte_args + def __delitem__(self, name): + if name not in self: + raise KeyError(name) + name = name.lower() + self.fields = [ + field for field in self.fields + if name != field[0].lower() + ] + + def __iter__(self): + seen = set() + for name, _ in self.fields: + name_lower = name.lower() + if name_lower not in seen: + seen.add(name_lower) + yield _native(name) + + def __len__(self): + return len(set(name.lower() for name, _ in self.fields)) + + # __hash__ = object.__hash__ + + def _index(self, name): + name = name.lower() + for i, field in enumerate(self.fields): + if field[0].lower() == name: + return i + return None + + def __eq__(self, other): + if isinstance(other, Headers): + return self.fields == other.fields + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @_always_byte_args + def get_all(self, name): + """ + Like :py:meth:`get`, but does not fold multiple headers into a single one. + This is useful for Set-Cookie headers, which do not support folding. + + See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 + """ + name_lower = name.lower() + values = [_native(value) for n, value in self.fields if n.lower() == name_lower] + return values + + @_always_byte_args + def set_all(self, name, values): + """ + Explicitly set multiple headers for the given key. + See: :py:meth:`get_all` + """ + values = map(_asbytes, values) # _always_byte_args does not fix lists + if name in self: + del self[name] + self.fields.extend( + [name, value] for value in values + ) + + def copy(self): + return Headers(copy.copy(self.fields)) + + # Implement the StateObject protocol from mitmproxy + def get_state(self, short=False): + return tuple(tuple(field) for field in self.fields) + + def load_state(self, state): + self.fields = [list(field) for field in state] + + @classmethod + def from_state(cls, state): + return cls([list(field) for field in state]) \ No newline at end of file diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index c2b60a0f..88aeac05 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -35,7 +35,7 @@ def assemble_response_head(response): def assemble_body(headers, body_chunks): - if b"chunked" in headers.get(b"transfer-encoding", b"").lower(): + if "chunked" in headers.get("transfer-encoding", "").lower(): for chunk in body_chunks: if chunk: yield b"%x\r\n%s\r\n" % (len(chunk), chunk) @@ -76,8 +76,8 @@ def _assemble_request_line(request, form=None): def _assemble_request_headers(request): headers = request.headers.copy() - if b"host" not in headers and request.scheme and request.host and request.port: - headers[b"Host"] = utils.hostport( + if "host" not in headers and request.scheme and request.host and request.port: + headers["host"] = utils.hostport( request.scheme, request.host, request.port diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index c6760ff3..4c898348 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -146,11 +146,11 @@ def connection_close(http_version, headers): according to RFC 2616 Section 8.1. """ # At first, check if we have an explicit Connection header. - if b"connection" in headers: + if "connection" in headers: tokens = utils.get_header_tokens(headers, "connection") - if b"close" in tokens: + if "close" in tokens: return True - elif b"keep-alive" in tokens: + elif "keep-alive" in tokens: return False # If we don't have a Connection header, HTTP 1.1 connections are assumed to @@ -181,7 +181,7 @@ def expected_http_body_size(request, response=None): is_request = False if is_request: - if headers.get(b"expect", b"").lower() == b"100-continue": + if headers.get("expect", "").lower() == "100-continue": return 0 else: if request.method.upper() == b"HEAD": @@ -193,11 +193,11 @@ def expected_http_body_size(request, response=None): if response_code in (204, 304): return 0 - if b"chunked" in headers.get(b"transfer-encoding", b"").lower(): + if "chunked" in headers.get("transfer-encoding", "").lower(): return None - if b"content-length" in headers: + if "content-length" in headers: try: - size = int(headers[b"content-length"]) + size = int(headers["content-length"]) if size < 0: raise ValueError() return size diff --git a/netlib/http/models.py b/netlib/http/models.py index 512a764d..55664533 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -1,201 +1,22 @@ -from __future__ import absolute_import, print_function, division -import copy + from ..odict import ODict from .. import utils, encoding -from ..utils import always_bytes, always_byte_args, native +from ..utils import always_bytes, native from . import cookies +from .headers import Headers -import six from six.moves import urllib -try: - from collections import MutableMapping -except ImportError: - from collections.abc import MutableMapping # TODO: Move somewhere else? ALPN_PROTO_HTTP1 = b'http/1.1' ALPN_PROTO_H2 = b'h2' -HDR_FORM_URLENCODED = b"application/x-www-form-urlencoded" -HDR_FORM_MULTIPART = b"multipart/form-data" +HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" +HDR_FORM_MULTIPART = "multipart/form-data" CONTENT_MISSING = 0 -class Headers(MutableMapping, object): - """ - Header class which allows both convenient access to individual headers as well as - direct access to the underlying raw data. Provides a full dictionary interface. - - Example: - - .. code-block:: python - - # Create header from a list of (header_name, header_value) tuples - >>> h = Headers([ - ["Host","example.com"], - ["Accept","text/html"], - ["accept","application/xml"] - ]) - - # Headers mostly behave like a normal dict. - >>> h["Host"] - "example.com" - - # HTTP Headers are case insensitive - >>> h["host"] - "example.com" - - # Multiple headers are folded into a single header as per RFC7230 - >>> h["Accept"] - "text/html, application/xml" - - # Setting a header removes all existing headers with the same name. - >>> h["Accept"] = "application/text" - >>> h["Accept"] - "application/text" - - # str(h) returns a HTTP1 header block. - >>> print(h) - Host: example.com - Accept: application/text - - # For full control, the raw header fields can be accessed - >>> h.fields - - # Headers can also be crated from keyword arguments - >>> h = Headers(host="example.com", content_type="application/xml") - - Caveats: - For use with the "Set-Cookie" header, see :py:meth:`get_all`. - """ - - @always_byte_args("ascii") - def __init__(self, fields=None, **headers): - """ - Args: - fields: (optional) list of ``(name, value)`` header tuples, - e.g. ``[("Host","example.com")]``. All names and values must be bytes. - **headers: Additional headers to set. Will overwrite existing values from `fields`. - For convenience, underscores in header names will be transformed to dashes - - this behaviour does not extend to other methods. - If ``**headers`` contains multiple keys that have equal ``.lower()`` s, - the behavior is undefined. - """ - self.fields = fields or [] - - # content_type -> content-type - headers = { - name.encode("ascii").replace(b"_", b"-"): value - for name, value in six.iteritems(headers) - } - self.update(headers) - - def __bytes__(self): - if self.fields: - return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" - else: - return b"" - - if six.PY2: - __str__ = __bytes__ - - @always_byte_args("ascii") - def __getitem__(self, name): - values = self.get_all(name) - if not values: - raise KeyError(name) - return b", ".join(values) - - @always_byte_args("ascii") - def __setitem__(self, name, value): - idx = self._index(name) - - # To please the human eye, we insert at the same position the first existing header occured. - if idx is not None: - del self[name] - self.fields.insert(idx, [name, value]) - else: - self.fields.append([name, value]) - - @always_byte_args("ascii") - def __delitem__(self, name): - if name not in self: - raise KeyError(name) - name = name.lower() - self.fields = [ - field for field in self.fields - if name != field[0].lower() - ] - - def __iter__(self): - seen = set() - for name, _ in self.fields: - name_lower = name.lower() - if name_lower not in seen: - seen.add(name_lower) - yield name - - def __len__(self): - return len(set(name.lower() for name, _ in self.fields)) - - # __hash__ = object.__hash__ - - def _index(self, name): - name = name.lower() - for i, field in enumerate(self.fields): - if field[0].lower() == name: - return i - return None - - def __eq__(self, other): - if isinstance(other, Headers): - return self.fields == other.fields - return False - - def __ne__(self, other): - return not self.__eq__(other) - - @always_byte_args("ascii") - def get_all(self, name): - """ - Like :py:meth:`get`, but does not fold multiple headers into a single one. - This is useful for Set-Cookie headers, which do not support folding. - - See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 - """ - name_lower = name.lower() - values = [value for n, value in self.fields if n.lower() == name_lower] - return values - - def set_all(self, name, values): - """ - Explicitly set multiple headers for the given key. - See: :py:meth:`get_all` - """ - name = always_bytes(name, "ascii") - values = (always_bytes(value, "ascii") for value in values) - if name in self: - del self[name] - self.fields.extend( - [name, value] for value in values - ) - - def copy(self): - return Headers(copy.copy(self.fields)) - - # Implement the StateObject protocol from mitmproxy - def get_state(self, short=False): - return tuple(tuple(field) for field in self.fields) - - def load_state(self, state): - self.fields = [list(field) for field in state] - - @classmethod - def from_state(cls, state): - return cls([list(field) for field in state]) - - class Message(object): def __init__(self, http_version, headers, body, timestamp_start, timestamp_end): self.http_version = http_version @@ -216,7 +37,7 @@ class Message(object): def body(self, body): self._body = body if isinstance(body, bytes): - self.headers[b"content-length"] = str(len(body)).encode() + self.headers["content-length"] = str(len(body)).encode() content = body @@ -268,8 +89,8 @@ class Request(Message): response. That is, we remove ETags and If-Modified-Since headers. """ delheaders = [ - b"if-modified-since", - b"if-none-match", + "if-modified-since", + "if-none-match", ] for i in delheaders: self.headers.pop(i, None) @@ -279,14 +100,14 @@ class Request(Message): Modifies this request to remove headers that will compress the resource's data. """ - self.headers["accept-encoding"] = b"identity" + self.headers["accept-encoding"] = "identity" def constrain_encoding(self): """ Limits the permissible Accept-Encoding values, based on what we can decode appropriately. """ - accept_encoding = native(self.headers.get("accept-encoding"), "ascii") + accept_encoding = self.headers.get("accept-encoding") if accept_encoding: self.headers["accept-encoding"] = ( ', '.join( @@ -309,9 +130,9 @@ class Request(Message): indicates non-form data. """ if self.body: - if HDR_FORM_URLENCODED in self.headers.get("content-type", b"").lower(): + if HDR_FORM_URLENCODED in self.headers.get("content-type", "").lower(): return self.get_form_urlencoded() - elif HDR_FORM_MULTIPART in self.headers.get("content-type", b"").lower(): + elif HDR_FORM_MULTIPART in self.headers.get("content-type", "").lower(): return self.get_form_multipart() return ODict([]) @@ -321,12 +142,12 @@ class Request(Message): Returns an empty ODict if there is no data or the content-type indicates non-form data. """ - if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type", b"").lower(): + if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type", "").lower(): return ODict(utils.urldecode(self.body)) return ODict([]) def get_form_multipart(self): - if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type", b"").lower(): + if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type", "").lower(): return ODict( utils.multipartdecode( self.headers, @@ -341,7 +162,7 @@ class Request(Message): """ # FIXME: If there's an existing content-type header indicating a # url-encoded form, leave it alone. - self.headers[b"content-type"] = HDR_FORM_URLENCODED + self.headers["content-type"] = HDR_FORM_URLENCODED self.body = utils.urlencode(odict.lst) def get_path_components(self): @@ -400,7 +221,7 @@ class Request(Message): """ if hostheader and "host" in self.headers: try: - return self.headers["host"].decode("idna") + return self.headers["host"] except ValueError: pass if self.host: @@ -420,7 +241,7 @@ class Request(Message): """ ret = ODict() for i in self.headers.get_all("Cookie"): - ret.extend(cookies.parse_cookie_header(native(i,"ascii"))) + ret.extend(cookies.parse_cookie_header(i)) return ret def set_cookies(self, odict): @@ -499,7 +320,7 @@ class Response(Message): """ ret = [] for header in self.headers.get_all("set-cookie"): - v = cookies.parse_set_cookie_header(native(header, "ascii")) + v = cookies.parse_set_cookie_header(header) if v: name, value, attrs = v ret.append([name, [value, attrs]]) diff --git a/netlib/utils.py b/netlib/utils.py index b9848038..d5b30128 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -269,7 +269,7 @@ def get_header_tokens(headers, key): """ if key not in headers: return [] - tokens = headers[key].split(b",") + tokens = headers[key].split(",") return [token.strip() for token in tokens] @@ -320,14 +320,14 @@ def parse_content_type(c): ("text", "html", {"charset": "UTF-8"}) """ - parts = c.split(b";", 1) - ts = parts[0].split(b"/", 1) + parts = c.split(";", 1) + ts = parts[0].split("/", 1) if len(ts) != 2: return None d = {} if len(parts) == 2: - for i in parts[1].split(b";"): - clause = i.split(b"=", 1) + for i in parts[1].split(";"): + clause = i.split("=", 1) if len(clause) == 2: d[clause[0].strip()] = clause[1].strip() return ts[0].lower(), ts[1].lower(), d @@ -337,13 +337,14 @@ def multipartdecode(headers, content): """ Takes a multipart boundary encoded string and returns list of (key, value) tuples. """ - v = headers.get(b"Content-Type") + v = headers.get("Content-Type") if v: v = parse_content_type(v) if not v: return [] - boundary = v[2].get(b"boundary") - if not boundary: + try: + boundary = v[2]["boundary"].encode("ascii") + except (KeyError, UnicodeError): return [] rx = re.compile(br'\bname="([^"]+)"') diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index 778fe7e7..e62f8df6 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -80,7 +80,7 @@ class WebsocketsProtocol(object): Returns an instance of Headers """ if not key: - key = base64.b64encode(os.urandom(16)).decode('utf-8') + key = base64.b64encode(os.urandom(16)).decode('ascii') return Headers(**{ HEADER_WEBSOCKET_KEY: key, HEADER_WEBSOCKET_VERSION: version, @@ -95,27 +95,25 @@ class WebsocketsProtocol(object): """ return Headers(**{ HEADER_WEBSOCKET_ACCEPT: self.create_server_nonce(key), - "Connection": "Upgrade", - "Upgrade": "websocket", + "connection": "Upgrade", + "upgrade": "websocket", }) @classmethod def check_client_handshake(self, headers): - if headers.get("upgrade") != b"websocket": + if headers.get("upgrade") != "websocket": return return headers.get(HEADER_WEBSOCKET_KEY) @classmethod def check_server_handshake(self, headers): - if headers.get("upgrade") != b"websocket": + if headers.get("upgrade") != "websocket": return return headers.get(HEADER_WEBSOCKET_ACCEPT) @classmethod def create_server_nonce(self, client_nonce): - return base64.b64encode( - binascii.unhexlify(hashlib.sha1(client_nonce + websockets_magic).hexdigest()) - ) + return base64.b64encode(hashlib.sha1(client_nonce + websockets_magic).digest()) diff --git a/test/http/http1/test_assemble.py b/test/http/http1/test_assemble.py index 2d250909..963e7549 100644 --- a/test/http/http1/test_assemble.py +++ b/test/http/http1/test_assemble.py @@ -77,16 +77,16 @@ def test_assemble_request_line(): def test_assemble_request_headers(): # https://github.com/mitmproxy/mitmproxy/issues/186 r = treq(body=b"") - r.headers[b"Transfer-Encoding"] = b"chunked" + r.headers["Transfer-Encoding"] = "chunked" c = _assemble_request_headers(r) assert b"Transfer-Encoding" in c - assert b"Host" in _assemble_request_headers(treq(headers=Headers())) + assert b"host" in _assemble_request_headers(treq(headers=Headers())) def test_assemble_response_headers(): # https://github.com/mitmproxy/mitmproxy/issues/186 r = tresp(body=b"") - r.headers["Transfer-Encoding"] = b"chunked" + r.headers["Transfer-Encoding"] = "chunked" c = _assemble_response_headers(r) assert b"Transfer-Encoding" in c diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py index 55def2a5..9eb02a24 100644 --- a/test/http/http1/test_read.py +++ b/test/http/http1/test_read.py @@ -1,9 +1,7 @@ from __future__ import absolute_import, print_function, division from io import BytesIO import textwrap - from mock import Mock - from netlib.exceptions import HttpException, HttpSyntaxException, HttpReadDisconnect from netlib.http import Headers from netlib.http.http1.read import ( @@ -35,7 +33,7 @@ def test_read_request_head(): rfile.first_byte_timestamp = 42 r = read_request_head(rfile) assert r.method == b"GET" - assert r.headers["Content-Length"] == b"4" + assert r.headers["Content-Length"] == "4" assert r.body is None assert rfile.reset_timestamps.called assert r.timestamp_start == 42 @@ -62,7 +60,7 @@ def test_read_response_head(): rfile.first_byte_timestamp = 42 r = read_response_head(rfile) assert r.status_code == 418 - assert r.headers["Content-Length"] == b"4" + assert r.headers["Content-Length"] == "4" assert r.body is None assert rfile.reset_timestamps.called assert r.timestamp_start == 42 @@ -76,14 +74,12 @@ class TestReadBody(object): assert body == b"foo" assert rfile.read() == b"bar" - def test_known_size(self): rfile = BytesIO(b"foobar") body = b"".join(read_body(rfile, 3)) assert body == b"foo" assert rfile.read() == b"bar" - def test_known_size_limit(self): rfile = BytesIO(b"foobar") with raises(HttpException): @@ -99,7 +95,6 @@ class TestReadBody(object): body = b"".join(read_body(rfile, -1)) assert body == b"foobar" - def test_unknown_size_limit(self): rfile = BytesIO(b"foobar") with raises(HttpException): @@ -121,13 +116,13 @@ def test_connection_close(): def test_expected_http_body_size(): # Expect: 100-continue assert expected_http_body_size( - treq(headers=Headers(expect=b"100-continue", content_length=b"42")) + treq(headers=Headers(expect="100-continue", content_length="42")) ) == 0 # http://tools.ietf.org/html/rfc7230#section-3.3 assert expected_http_body_size( treq(method=b"HEAD"), - tresp(headers=Headers(content_length=b"42")) + tresp(headers=Headers(content_length="42")) ) == 0 assert expected_http_body_size( treq(method=b"CONNECT"), @@ -141,17 +136,17 @@ def test_expected_http_body_size(): # chunked assert expected_http_body_size( - treq(headers=Headers(transfer_encoding=b"chunked")), + treq(headers=Headers(transfer_encoding="chunked")), ) is None # explicit length - for l in (b"foo", b"-7"): + for val in (b"foo", b"-7"): with raises(HttpSyntaxException): expected_http_body_size( - treq(headers=Headers(content_length=l)) + treq(headers=Headers(content_length=val)) ) assert expected_http_body_size( - treq(headers=Headers(content_length=b"42")) + treq(headers=Headers(content_length="42")) ) == 42 # no length @@ -286,6 +281,7 @@ class TestReadHeaders(object): with raises(HttpSyntaxException): self._read(data) + def test_read_chunked(): req = treq(body=None) req.headers["Transfer-Encoding"] = "chunked" diff --git a/test/http/test_authentication.py b/test/http/test_authentication.py index a2aa774a..1df7cd9c 100644 --- a/test/http/test_authentication.py +++ b/test/http/test_authentication.py @@ -5,13 +5,13 @@ from netlib.http import authentication, Headers def test_parse_http_basic_auth(): - vals = (b"basic", b"foo", b"bar") + vals = ("basic", "foo", "bar") assert authentication.parse_http_basic_auth( authentication.assemble_http_basic_auth(*vals) ) == vals assert not authentication.parse_http_basic_auth("") assert not authentication.parse_http_basic_auth("foo bar") - v = b"basic " + binascii.b2a_base64(b"foo") + v = "basic " + binascii.b2a_base64(b"foo").decode("ascii") assert not authentication.parse_http_basic_auth(v) @@ -34,7 +34,7 @@ class TestPassManHtpasswd: def test_simple(self): pm = authentication.PassManHtpasswd(tutils.test_data.path("data/htpasswd")) - vals = (b"basic", b"test", b"test") + vals = ("basic", "test", "test") authentication.assemble_http_basic_auth(*vals) assert pm.test("test", "test") assert not pm.test("test", "foo") @@ -73,7 +73,7 @@ class TestBasicProxyAuth: ba = authentication.BasicProxyAuth(authentication.PassManNonAnon(), "test") headers = Headers() - vals = (b"basic", b"foo", b"bar") + vals = ("basic", "foo", "bar") headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals) assert ba.authenticate(headers) @@ -86,12 +86,12 @@ class TestBasicProxyAuth: headers[ba.AUTH_HEADER] = "foo" assert not ba.authenticate(headers) - vals = (b"foo", b"foo", b"bar") + vals = ("foo", "foo", "bar") headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals) assert not ba.authenticate(headers) ba = authentication.BasicProxyAuth(authentication.PassMan(), "test") - vals = (b"basic", b"foo", b"bar") + vals = ("basic", "foo", "bar") headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals) assert not ba.authenticate(headers) diff --git a/test/http/test_headers.py b/test/http/test_headers.py new file mode 100644 index 00000000..f1af1feb --- /dev/null +++ b/test/http/test_headers.py @@ -0,0 +1,149 @@ +from netlib.http import Headers +from netlib.tutils import raises + + +class TestHeaders(object): + def _2host(self): + return Headers( + [ + [b"Host", b"example.com"], + [b"host", b"example.org"] + ] + ) + + def test_init(self): + headers = Headers() + assert len(headers) == 0 + + headers = Headers([[b"Host", b"example.com"]]) + assert len(headers) == 1 + assert headers["Host"] == "example.com" + + headers = Headers(Host="example.com") + assert len(headers) == 1 + assert headers["Host"] == "example.com" + + headers = Headers( + [[b"Host", b"invalid"]], + Host="example.com" + ) + assert len(headers) == 1 + assert headers["Host"] == "example.com" + + headers = Headers( + [[b"Host", b"invalid"], [b"Accept", b"text/plain"]], + Host="example.com" + ) + assert len(headers) == 2 + assert headers["Host"] == "example.com" + assert headers["Accept"] == "text/plain" + + def test_getitem(self): + headers = Headers(Host="example.com") + assert headers["Host"] == "example.com" + assert headers["host"] == "example.com" + with raises(KeyError): + _ = headers["Accept"] + + headers = self._2host() + assert headers["Host"] == "example.com, example.org" + + def test_str(self): + headers = Headers(Host="example.com") + assert bytes(headers) == b"Host: example.com\r\n" + + headers = Headers([ + [b"Host", b"example.com"], + [b"Accept", b"text/plain"] + ]) + assert bytes(headers) == b"Host: example.com\r\nAccept: text/plain\r\n" + + headers = Headers() + assert bytes(headers) == b"" + + def test_setitem(self): + headers = Headers() + headers["Host"] = "example.com" + assert "Host" in headers + assert "host" in headers + assert headers["Host"] == "example.com" + + headers["host"] = "example.org" + assert "Host" in headers + assert "host" in headers + assert headers["Host"] == "example.org" + + headers["accept"] = "text/plain" + assert len(headers) == 2 + assert "Accept" in headers + assert "Host" in headers + + headers = self._2host() + assert len(headers.fields) == 2 + headers["Host"] = "example.com" + assert len(headers.fields) == 1 + assert "Host" in headers + + def test_delitem(self): + headers = Headers(Host="example.com") + assert len(headers) == 1 + del headers["host"] + assert len(headers) == 0 + try: + del headers["host"] + except KeyError: + assert True + else: + assert False + + headers = self._2host() + del headers["Host"] + assert len(headers) == 0 + + def test_keys(self): + headers = Headers(Host="example.com") + assert list(headers.keys()) == ["Host"] + + headers = self._2host() + assert list(headers.keys()) == ["Host"] + + def test_eq_ne(self): + headers1 = Headers(Host="example.com") + headers2 = Headers(host="example.com") + assert not (headers1 == headers2) + assert headers1 != headers2 + + headers1 = Headers(Host="example.com") + headers2 = Headers(Host="example.com") + assert headers1 == headers2 + assert not (headers1 != headers2) + + assert headers1 != 42 + + def test_get_all(self): + headers = self._2host() + assert headers.get_all("host") == ["example.com", "example.org"] + assert headers.get_all("accept") == [] + + def test_set_all(self): + headers = Headers(Host="example.com") + headers.set_all("Accept", ["text/plain"]) + assert len(headers) == 2 + assert "accept" in headers + + headers = self._2host() + headers.set_all("Host", ["example.org"]) + assert headers["host"] == "example.org" + + headers.set_all("Host", ["example.org", "example.net"]) + assert headers["host"] == "example.org, example.net" + + def test_state(self): + headers = self._2host() + assert len(headers.get_state()) == 2 + assert headers == Headers.from_state(headers.get_state()) + + headers2 = Headers() + assert headers != headers2 + headers2.load_state(headers.get_state()) + assert headers == headers2 diff --git a/test/http/test_models.py b/test/http/test_models.py index d420b22b..10e0795a 100644 --- a/test/http/test_models.py +++ b/test/http/test_models.py @@ -58,20 +58,20 @@ class TestRequest(object): req = tutils.treq() req.headers["Accept-Encoding"] = "foobar" req.anticomp() - assert req.headers["Accept-Encoding"] == b"identity" + assert req.headers["Accept-Encoding"] == "identity" def test_constrain_encoding(self): req = tutils.treq() req.headers["Accept-Encoding"] = "identity, gzip, foo" req.constrain_encoding() - assert b"foo" not in req.headers["Accept-Encoding"] + assert "foo" not in req.headers["Accept-Encoding"] def test_update_host(self): req = tutils.treq() req.headers["Host"] = "" req.host = "foobar" req.update_host_header() - assert req.headers["Host"] == b"foobar" + assert req.headers["Host"] == "foobar" def test_get_form(self): req = tutils.treq() @@ -393,149 +393,3 @@ class TestResponse(object): v = resp.get_cookies() assert len(v) == 1 assert v["foo"] == [["bar", ODictCaseless()]] - - -class TestHeaders(object): - def _2host(self): - return Headers( - [ - [b"Host", b"example.com"], - [b"host", b"example.org"] - ] - ) - - def test_init(self): - headers = Headers() - assert len(headers) == 0 - - headers = Headers([[b"Host", b"example.com"]]) - assert len(headers) == 1 - assert headers["Host"] == b"example.com" - - headers = Headers(Host="example.com") - assert len(headers) == 1 - assert headers["Host"] == b"example.com" - - headers = Headers( - [[b"Host", b"invalid"]], - Host="example.com" - ) - assert len(headers) == 1 - assert headers["Host"] == b"example.com" - - headers = Headers( - [[b"Host", b"invalid"], [b"Accept", b"text/plain"]], - Host="example.com" - ) - assert len(headers) == 2 - assert headers["Host"] == b"example.com" - assert headers["Accept"] == b"text/plain" - - def test_getitem(self): - headers = Headers(Host="example.com") - assert headers["Host"] == b"example.com" - assert headers["host"] == b"example.com" - tutils.raises(KeyError, headers.__getitem__, "Accept") - - headers = self._2host() - assert headers["Host"] == b"example.com, example.org" - - def test_str(self): - headers = Headers(Host="example.com") - assert bytes(headers) == b"Host: example.com\r\n" - - headers = Headers([ - [b"Host", b"example.com"], - [b"Accept", b"text/plain"] - ]) - assert bytes(headers) == b"Host: example.com\r\nAccept: text/plain\r\n" - - headers = Headers() - assert bytes(headers) == b"" - - def test_setitem(self): - headers = Headers() - headers["Host"] = "example.com" - assert "Host" in headers - assert "host" in headers - assert headers["Host"] == b"example.com" - - headers["host"] = "example.org" - assert "Host" in headers - assert "host" in headers - assert headers["Host"] == b"example.org" - - headers["accept"] = "text/plain" - assert len(headers) == 2 - assert "Accept" in headers - assert "Host" in headers - - headers = self._2host() - assert len(headers.fields) == 2 - headers["Host"] = "example.com" - assert len(headers.fields) == 1 - assert "Host" in headers - - def test_delitem(self): - headers = Headers(Host="example.com") - assert len(headers) == 1 - del headers["host"] - assert len(headers) == 0 - try: - del headers["host"] - except KeyError: - assert True - else: - assert False - - headers = self._2host() - del headers["Host"] - assert len(headers) == 0 - - def test_keys(self): - headers = Headers(Host="example.com") - assert list(headers.keys()) == [b"Host"] - - headers = self._2host() - assert list(headers.keys()) == [b"Host"] - - def test_eq_ne(self): - headers1 = Headers(Host="example.com") - headers2 = Headers(host="example.com") - assert not (headers1 == headers2) - assert headers1 != headers2 - - headers1 = Headers(Host="example.com") - headers2 = Headers(Host="example.com") - assert headers1 == headers2 - assert not (headers1 != headers2) - - assert headers1 != 42 - - def test_get_all(self): - headers = self._2host() - assert headers.get_all("host") == [b"example.com", b"example.org"] - assert headers.get_all("accept") == [] - - def test_set_all(self): - headers = Headers(Host="example.com") - headers.set_all("Accept", ["text/plain"]) - assert len(headers) == 2 - assert "accept" in headers - - headers = self._2host() - headers.set_all("Host", ["example.org"]) - assert headers["host"] == b"example.org" - - headers.set_all("Host", ["example.org", "example.net"]) - assert headers["host"] == b"example.org, example.net" - - def test_state(self): - headers = self._2host() - assert len(headers.get_state()) == 2 - assert headers == Headers.from_state(headers.get_state()) - - headers2 = Headers() - assert headers != headers2 - headers2.load_state(headers.get_state()) - assert headers == headers2 diff --git a/test/test_utils.py b/test/test_utils.py index 8f4b4059..17636cc4 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -103,17 +103,17 @@ def test_get_header_tokens(): headers = Headers() assert utils.get_header_tokens(headers, "foo") == [] headers["foo"] = "bar" - assert utils.get_header_tokens(headers, "foo") == [b"bar"] + assert utils.get_header_tokens(headers, "foo") == ["bar"] headers["foo"] = "bar, voing" - assert utils.get_header_tokens(headers, "foo") == [b"bar", b"voing"] + assert utils.get_header_tokens(headers, "foo") == ["bar", "voing"] headers.set_all("foo", ["bar, voing", "oink"]) - assert utils.get_header_tokens(headers, "foo") == [b"bar", b"voing", b"oink"] + assert utils.get_header_tokens(headers, "foo") == ["bar", "voing", "oink"] def test_multipartdecode(): - boundary = b'somefancyboundary' + boundary = 'somefancyboundary' headers = Headers( - content_type=b'multipart/form-data; boundary=' + boundary + content_type='multipart/form-data; boundary=' + boundary ) content = ( "--{0}\n" @@ -122,7 +122,7 @@ def test_multipartdecode(): "--{0}\n" "Content-Disposition: form-data; name=\"field2\"\n\n" "value2\n" - "--{0}--".format(boundary.decode()).encode() + "--{0}--".format(boundary).encode() ) form = utils.multipartdecode(headers, content) @@ -134,8 +134,8 @@ def test_multipartdecode(): def test_parse_content_type(): p = utils.parse_content_type - assert p(b"text/html") == (b"text", b"html", {}) - assert p(b"text") is None + assert p("text/html") == ("text", "html", {}) + assert p("text") is None - v = p(b"text/html; charset=UTF-8") - assert v == (b'text', b'html', {b'charset': b'UTF-8'}) + v = p("text/html; charset=UTF-8") + assert v == ('text', 'html', {'charset': 'UTF-8'}) diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 48acc2d6..4ae4cf45 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -64,15 +64,14 @@ class WebSocketsClient(tcp.TCPClient): preamble = b'GET / HTTP/1.1' self.wfile.write(preamble + b"\r\n") headers = self.protocol.client_handshake_headers() - self.client_nonce = headers["sec-websocket-key"] + self.client_nonce = headers["sec-websocket-key"].encode("ascii") self.wfile.write(bytes(headers) + b"\r\n") self.wfile.flush() resp = read_response(self.rfile, treq(method="GET")) server_nonce = self.protocol.check_server_handshake(resp.headers) - if not server_nonce == self.protocol.create_server_nonce( - self.client_nonce): + if not server_nonce == self.protocol.create_server_nonce(self.client_nonce): self.close() def read_next_message(self): @@ -207,14 +206,14 @@ class TestFrameHeader: fin=True, payload_length=10 ) - assert f.human_readable() + assert repr(f) f = websockets.FrameHeader() - assert f.human_readable() + assert repr(f) def test_funky(self): f = websockets.FrameHeader(masking_key=b"test", mask=False) - bytes = f.to_bytes() - f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) + raw = bytes(f) + f2 = websockets.FrameHeader.from_file(tutils.treader(raw)) assert not f2.mask def test_violations(self): -- cgit v1.2.3 From c7b83225001505b32905376703ec7ddaf200af44 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 22 Sep 2015 01:56:09 +0200 Subject: also accept bytes as arguments --- netlib/http/headers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 1511ea2d..613beb4f 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -14,16 +14,16 @@ except ImportError: # Workaround for Python < 3.3 import six -from netlib.utils import always_byte_args +from netlib.utils import always_byte_args, always_bytes if six.PY2: _native = lambda x: x - _asbytes = lambda x: x + _always_bytes = lambda x: x _always_byte_args = lambda x: x else: # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. _native = lambda x: x.decode("utf-8", "surrogateescape") - _asbytes = lambda x: x.encode("utf-8", "surrogateescape") + _always_bytes = lambda x: always_bytes(x, "utf-8", "surrogateescape") _always_byte_args = always_byte_args("utf-8", "surrogateescape") @@ -95,9 +95,9 @@ class Headers(MutableMapping, object): # content_type -> content-type headers = { - _asbytes(name).replace(b"_", b"-"): value + _always_bytes(name).replace(b"_", b"-"): value for name, value in six.iteritems(headers) - } + } self.update(headers) def __bytes__(self): @@ -183,7 +183,7 @@ class Headers(MutableMapping, object): Explicitly set multiple headers for the given key. See: :py:meth:`get_all` """ - values = map(_asbytes, values) # _always_byte_args does not fix lists + values = map(_always_bytes, values) # _always_byte_args does not fix lists if name in self: del self[name] self.fields.extend( -- cgit v1.2.3 From 45f2ea33b2fdb67ca89e7eedd860ebe683770497 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 25 Sep 2015 18:24:18 +0200 Subject: minor fixes --- netlib/utils.py | 2 +- netlib/websockets/protocol.py | 30 +++++++++++++----------------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/netlib/utils.py b/netlib/utils.py index d5b30128..6f6d1ea0 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -337,7 +337,7 @@ def multipartdecode(headers, content): """ Takes a multipart boundary encoded string and returns list of (key, value) tuples. """ - v = headers.get("Content-Type") + v = headers.get("content-type") if v: v = parse_content_type(v) if not v: diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index e62f8df6..1e95fa1c 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -25,10 +25,6 @@ from ..http import Headers websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' VERSION = "13" -HEADER_WEBSOCKET_KEY = 'sec-websocket-key' -HEADER_WEBSOCKET_ACCEPT = 'sec-websocket-accept' -HEADER_WEBSOCKET_VERSION = 'sec-websocket-version' - class Masker(object): @@ -81,37 +77,37 @@ class WebsocketsProtocol(object): """ if not key: key = base64.b64encode(os.urandom(16)).decode('ascii') - return Headers(**{ - HEADER_WEBSOCKET_KEY: key, - HEADER_WEBSOCKET_VERSION: version, - "Connection": "Upgrade", - "Upgrade": "websocket", - }) + return Headers( + sec_websocket_key=key, + sec_websocket_version=version, + connection="Upgrade", + upgrade="websocket", + ) @classmethod def server_handshake_headers(self, key): """ The server response is a valid HTTP 101 response. """ - return Headers(**{ - HEADER_WEBSOCKET_ACCEPT: self.create_server_nonce(key), - "connection": "Upgrade", - "upgrade": "websocket", - }) + return Headers( + sec_websocket_accept=self.create_server_nonce(key), + connection="Upgrade", + upgrade="websocket" + ) @classmethod def check_client_handshake(self, headers): if headers.get("upgrade") != "websocket": return - return headers.get(HEADER_WEBSOCKET_KEY) + return headers.get("sec-websocket-key") @classmethod def check_server_handshake(self, headers): if headers.get("upgrade") != "websocket": return - return headers.get(HEADER_WEBSOCKET_ACCEPT) + return headers.get("sec-websocket-accept") @classmethod -- cgit v1.2.3 From 106f7046d3862cb0e3cbb4f38335af0330b4e7e3 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 26 Sep 2015 00:39:04 +0200 Subject: refactor request model --- netlib/http/__init__.py | 5 +- netlib/http/headers.py | 2 +- netlib/http/http1/assemble.py | 65 ++++--- netlib/http/http1/read.py | 8 +- netlib/http/message.py | 146 +++++++++++++++ netlib/http/models.py | 233 ------------------------ netlib/http/request.py | 351 +++++++++++++++++++++++++++++++++++++ netlib/http/response.py | 3 + netlib/tutils.py | 4 +- netlib/utils.py | 15 +- test/http/http1/test_assemble.py | 12 +- test/http/http1/test_read.py | 8 +- test/http/test_models.py | 75 +++----- test/http/test_request.py | 3 + test/http/test_response.py | 3 + test/test_utils.py | 8 +- test/websockets/test_websockets.py | 2 +- 17 files changed, 598 insertions(+), 345 deletions(-) create mode 100644 netlib/http/message.py create mode 100644 netlib/http/request.py create mode 100644 netlib/http/response.py create mode 100644 test/http/test_request.py create mode 100644 test/http/test_response.py diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 0ccf6b32..e8c7ba20 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,12 +1,15 @@ from __future__ import absolute_import, print_function, division from .headers import Headers -from .models import Request, Response +from .message import decoded +from .request import Request +from .models import Response from .models import ALPN_PROTO_HTTP1, ALPN_PROTO_H2 from .models import HDR_FORM_MULTIPART, HDR_FORM_URLENCODED, CONTENT_MISSING from . import http1, http2 __all__ = [ "Headers", + "decoded", "Request", "Response", "ALPN_PROTO_HTTP1", "ALPN_PROTO_H2", "HDR_FORM_MULTIPART", "HDR_FORM_URLENCODED", "CONTENT_MISSING", diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 613beb4f..47ea923b 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -27,7 +27,7 @@ else: _always_byte_args = always_byte_args("utf-8", "surrogateescape") -class Headers(MutableMapping, object): +class Headers(MutableMapping): """ Header class which allows both convenient access to individual headers as well as direct access to the underlying raw data. Provides a full dictionary interface. diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index 88aeac05..864f6017 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -7,24 +7,24 @@ from .. import CONTENT_MISSING def assemble_request(request): - if request.body == CONTENT_MISSING: + if request.content == CONTENT_MISSING: raise HttpException("Cannot assemble flow with CONTENT_MISSING") head = assemble_request_head(request) - body = b"".join(assemble_body(request.headers, [request.body])) + body = b"".join(assemble_body(request.headers, [request.data.content])) return head + body def assemble_request_head(request): - first_line = _assemble_request_line(request) - headers = _assemble_request_headers(request) + first_line = _assemble_request_line(request.data) + headers = _assemble_request_headers(request.data) return b"%s\r\n%s\r\n" % (first_line, headers) def assemble_response(response): - if response.body == CONTENT_MISSING: + if response.content == CONTENT_MISSING: raise HttpException("Cannot assemble flow with CONTENT_MISSING") head = assemble_response_head(response) - body = b"".join(assemble_body(response.headers, [response.body])) + body = b"".join(assemble_body(response.headers, [response.content])) return head + body @@ -45,42 +45,49 @@ def assemble_body(headers, body_chunks): yield chunk -def _assemble_request_line(request, form=None): - if form is None: - form = request.form_out +def _assemble_request_line(request_data): + """ + Args: + request_data (netlib.http.request.RequestData) + """ + form = request_data.first_line_format if form == "relative": return b"%s %s %s" % ( - request.method, - request.path, - request.http_version + request_data.method, + request_data.path, + request_data.http_version ) elif form == "authority": return b"%s %s:%d %s" % ( - request.method, - request.host, - request.port, - request.http_version + request_data.method, + request_data.host, + request_data.port, + request_data.http_version ) elif form == "absolute": return b"%s %s://%s:%d%s %s" % ( - request.method, - request.scheme, - request.host, - request.port, - request.path, - request.http_version + request_data.method, + request_data.scheme, + request_data.host, + request_data.port, + request_data.path, + request_data.http_version ) - else: # pragma: nocover + else: raise RuntimeError("Invalid request form") -def _assemble_request_headers(request): - headers = request.headers.copy() - if "host" not in headers and request.scheme and request.host and request.port: +def _assemble_request_headers(request_data): + """ + Args: + request_data (netlib.http.request.RequestData) + """ + headers = request_data.headers.copy() + if "host" not in headers and request_data.scheme and request_data.host and request_data.port: headers["host"] = utils.hostport( - request.scheme, - request.host, - request.port + request_data.scheme, + request_data.host, + request_data.port ) return bytes(headers) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 4c898348..76721e06 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -11,7 +11,7 @@ from .. import Request, Response, Headers def read_request(rfile, body_size_limit=None): request = read_request_head(rfile) expected_body_size = expected_http_body_size(request) - request._body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) + request.data.content = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) request.timestamp_end = time.time() return request @@ -155,7 +155,7 @@ def connection_close(http_version, headers): # If we don't have a Connection header, HTTP 1.1 connections are assumed to # be persistent - return http_version != b"HTTP/1.1" + return http_version != "HTTP/1.1" and http_version != b"HTTP/1.1" # FIXME: Remove one case. def expected_http_body_size(request, response=None): @@ -184,11 +184,11 @@ def expected_http_body_size(request, response=None): if headers.get("expect", "").lower() == "100-continue": return 0 else: - if request.method.upper() == b"HEAD": + if request.method.upper() == "HEAD": return 0 if 100 <= response_code <= 199: return 0 - if response_code == 200 and request.method.upper() == b"CONNECT": + if response_code == 200 and request.method.upper() == "CONNECT": return 0 if response_code in (204, 304): return 0 diff --git a/netlib/http/message.py b/netlib/http/message.py new file mode 100644 index 00000000..20497bd5 --- /dev/null +++ b/netlib/http/message.py @@ -0,0 +1,146 @@ +from __future__ import absolute_import, print_function, division + +import warnings + +import six + +from .. import encoding, utils + +if six.PY2: + _native = lambda x: x + _always_bytes = lambda x: x +else: + # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. + _native = lambda x: x.decode("utf-8", "surrogateescape") + _always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape") + + +class Message(object): + def __init__(self, data): + self.data = data + + def __eq__(self, other): + if isinstance(other, Message): + return self.data == other.data + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @property + def http_version(self): + """ + Version string, e.g. "HTTP/1.1" + """ + return _native(self.data.http_version) + + @http_version.setter + def http_version(self, http_version): + self.data.http_version = _always_bytes(http_version) + + @property + def headers(self): + """ + Message headers object + + Returns: + netlib.http.Headers + """ + return self.data.headers + + @headers.setter + def headers(self, h): + self.data.headers = h + + @property + def timestamp_start(self): + """ + First byte timestamp + """ + return self.data.timestamp_start + + @timestamp_start.setter + def timestamp_start(self, timestamp_start): + self.data.timestamp_start = timestamp_start + + @property + def timestamp_end(self): + """ + Last byte timestamp + """ + return self.data.timestamp_end + + @timestamp_end.setter + def timestamp_end(self, timestamp_end): + self.data.timestamp_end = timestamp_end + + @property + def content(self): + """ + The raw (encoded) HTTP message body + + See also: :py:attr:`text` + """ + return self.data.content + + @content.setter + def content(self, content): + self.data.content = content + if isinstance(content, bytes): + self.headers["content-length"] = str(len(content)) + + @property + def text(self): + """ + The decoded HTTP message body. + Decoded contents are not cached, so this method is relatively expensive to call. + + See also: :py:attr:`content`, :py:class:`decoded` + """ + # This attribute should be called text, because that's what requests does. + raise NotImplementedError() + + @text.setter + def text(self, text): + raise NotImplementedError() + + @property + def body(self): + warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) + return self.content + + @body.setter + def body(self, body): + warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) + self.content = body + + +class decoded(object): + """ + A context manager that decodes a request or response, and then + re-encodes it with the same encoding after execution of the block. + + Example: + + .. code-block:: python + + with decoded(request): + request.content = request.content.replace("foo", "bar") + """ + + def __init__(self, message): + self.message = message + ce = message.headers.get("content-encoding") + if ce in encoding.ENCODINGS: + self.ce = ce + else: + self.ce = None + + def __enter__(self): + if self.ce: + if not self.message.decode(): + self.ce = None + + def __exit__(self, type, value, tb): + if self.ce: + self.message.encode(self.ce) \ No newline at end of file diff --git a/netlib/http/models.py b/netlib/http/models.py index 55664533..40f6e98c 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -47,239 +47,6 @@ class Message(object): return False -class Request(Message): - def __init__( - self, - form_in, - method, - scheme, - host, - port, - path, - http_version, - headers=None, - body=None, - timestamp_start=None, - timestamp_end=None, - form_out=None - ): - super(Request, self).__init__(http_version, headers, body, timestamp_start, timestamp_end) - - self.form_in = form_in - self.method = method - self.scheme = scheme - self.host = host - self.port = port - self.path = path - self.form_out = form_out or form_in - - def __repr__(self): - if self.host and self.port: - hostport = "{}:{}".format(native(self.host,"idna"), self.port) - else: - hostport = "" - path = self.path or "" - return "HTTPRequest({} {}{})".format( - self.method, hostport, path - ) - - def anticache(self): - """ - Modifies this request to remove headers that might produce a cached - response. That is, we remove ETags and If-Modified-Since headers. - """ - delheaders = [ - "if-modified-since", - "if-none-match", - ] - for i in delheaders: - self.headers.pop(i, None) - - def anticomp(self): - """ - Modifies this request to remove headers that will compress the - resource's data. - """ - self.headers["accept-encoding"] = "identity" - - def constrain_encoding(self): - """ - Limits the permissible Accept-Encoding values, based on what we can - decode appropriately. - """ - accept_encoding = self.headers.get("accept-encoding") - if accept_encoding: - self.headers["accept-encoding"] = ( - ', '.join( - e - for e in encoding.ENCODINGS - if e in accept_encoding - ) - ) - - def update_host_header(self): - """ - Update the host header to reflect the current target. - """ - self.headers["host"] = self.host - - def get_form(self): - """ - Retrieves the URL-encoded or multipart form data, returning an ODict object. - Returns an empty ODict if there is no data or the content-type - indicates non-form data. - """ - if self.body: - if HDR_FORM_URLENCODED in self.headers.get("content-type", "").lower(): - return self.get_form_urlencoded() - elif HDR_FORM_MULTIPART in self.headers.get("content-type", "").lower(): - return self.get_form_multipart() - return ODict([]) - - def get_form_urlencoded(self): - """ - Retrieves the URL-encoded form data, returning an ODict object. - Returns an empty ODict if there is no data or the content-type - indicates non-form data. - """ - if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type", "").lower(): - return ODict(utils.urldecode(self.body)) - return ODict([]) - - def get_form_multipart(self): - if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type", "").lower(): - return ODict( - utils.multipartdecode( - self.headers, - self.body)) - return ODict([]) - - def set_form_urlencoded(self, odict): - """ - Sets the body to the URL-encoded form data, and adds the - appropriate content-type header. Note that this will destory the - existing body if there is one. - """ - # FIXME: If there's an existing content-type header indicating a - # url-encoded form, leave it alone. - self.headers["content-type"] = HDR_FORM_URLENCODED - self.body = utils.urlencode(odict.lst) - - def get_path_components(self): - """ - Returns the path components of the URL as a list of strings. - - Components are unquoted. - """ - _, _, path, _, _, _ = urllib.parse.urlparse(self.url) - return [urllib.parse.unquote(native(i,"ascii")) for i in path.split(b"/") if i] - - def set_path_components(self, lst): - """ - Takes a list of strings, and sets the path component of the URL. - - Components are quoted. - """ - lst = [urllib.parse.quote(i, safe="") for i in lst] - path = always_bytes("/" + "/".join(lst)) - scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) - self.url = urllib.parse.urlunparse( - [scheme, netloc, path, params, query, fragment] - ) - - def get_query(self): - """ - Gets the request query string. Returns an ODict object. - """ - _, _, _, _, query, _ = urllib.parse.urlparse(self.url) - if query: - return ODict(utils.urldecode(query)) - return ODict([]) - - def set_query(self, odict): - """ - Takes an ODict object, and sets the request query string. - """ - scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) - query = utils.urlencode(odict.lst) - self.url = urllib.parse.urlunparse( - [scheme, netloc, path, params, query, fragment] - ) - - def pretty_host(self, hostheader): - """ - Heuristic to get the host of the request. - - Note that pretty_host() does not always return the TCP destination - of the request, e.g. if an upstream proxy is in place - - If hostheader is set to True, the Host: header will be used as - additional (and preferred) data source. This is handy in - transparent mode, where only the IO of the destination is known, - but not the resolved name. This is disabled by default, as an - attacker may spoof the host header to confuse an analyst. - """ - if hostheader and "host" in self.headers: - try: - return self.headers["host"] - except ValueError: - pass - if self.host: - return self.host.decode("idna") - - def pretty_url(self, hostheader): - if self.form_out == "authority": # upstream proxy mode - return b"%s:%d" % (always_bytes(self.pretty_host(hostheader)), self.port) - return utils.unparse_url(self.scheme, - self.pretty_host(hostheader), - self.port, - self.path) - - def get_cookies(self): - """ - Returns a possibly empty netlib.odict.ODict object. - """ - ret = ODict() - for i in self.headers.get_all("Cookie"): - ret.extend(cookies.parse_cookie_header(i)) - return ret - - def set_cookies(self, odict): - """ - Takes an netlib.odict.ODict object. Over-writes any existing Cookie - headers. - """ - v = cookies.format_cookie_header(odict) - self.headers["cookie"] = v - - @property - def url(self): - """ - Returns a URL string, constructed from the Request's URL components. - """ - return utils.unparse_url( - self.scheme, - self.host, - self.port, - self.path - ) - - @url.setter - def url(self, url): - """ - Parses a URL specification, and updates the Request's information - accordingly. - - Raises: - ValueError if the URL was invalid - """ - # TODO: Should handle incoming unicode here. - parts = utils.parse_url(url) - if not parts: - raise ValueError("Invalid URL: %s" % url) - self.scheme, self.host, self.port, self.path = parts - - class Response(Message): def __init__( self, diff --git a/netlib/http/request.py b/netlib/http/request.py new file mode 100644 index 00000000..6830ca40 --- /dev/null +++ b/netlib/http/request.py @@ -0,0 +1,351 @@ +from __future__ import absolute_import, print_function, division + +import warnings + +import six +from six.moves import urllib + +from netlib import utils +from netlib.http import cookies +from netlib.odict import ODict +from .. import encoding +from .headers import Headers +from .message import Message, _native, _always_bytes + + +class RequestData(object): + def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None, + timestamp_start=None, timestamp_end=None): + if not headers: + headers = Headers() + assert isinstance(headers, Headers) + + self.first_line_format = first_line_format + self.method = method + self.scheme = scheme + self.host = host + self.port = port + self.path = path + self.http_version = http_version + self.headers = headers + self.content = content + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + + def __eq__(self, other): + if isinstance(other, RequestData): + return self.__dict__ == other.__dict__ + return False + + def __ne__(self, other): + return not self.__eq__(other) + + +class Request(Message): + """ + An HTTP request. + """ + def __init__(self, *args, **kwargs): + data = RequestData(*args, **kwargs) + super(Request, self).__init__(data) + + def __repr__(self): + if self.host and self.port: + hostport = "{}:{}".format(self.host, self.port) + else: + hostport = "" + path = self.path or "" + return "HTTPRequest({} {}{})".format( + self.method, hostport, path + ) + + @property + def first_line_format(self): + """ + HTTP request form as defined in `RFC7230 `_. + + origin-form and asterisk-form are subsumed as "relative". + """ + return self.data.first_line_format + + @first_line_format.setter + def first_line_format(self, first_line_format): + self.data.first_line_format = first_line_format + + @property + def method(self): + """ + HTTP request method, e.g. "GET". + """ + return _native(self.data.method) + + @method.setter + def method(self, method): + self.data.method = _always_bytes(method) + + @property + def scheme(self): + """ + HTTP request scheme, which should be "http" or "https". + """ + return _native(self.data.scheme) + + @scheme.setter + def scheme(self, scheme): + self.data.scheme = _always_bytes(scheme) + + @property + def host(self): + """ + Target host for the request. This may be directly taken in the request (e.g. "GET http://example.com/ HTTP/1.1") + or inferred from the proxy mode (e.g. an IP in transparent mode). + """ + + if six.PY2: + return self.data.host + + if not self.data.host: + return self.data.host + try: + return self.data.host.decode("idna") + except UnicodeError: + return self.data.host.decode("utf8", "surrogateescape") + + @host.setter + def host(self, host): + if isinstance(host, six.text_type): + try: + # There's no non-strict mode for IDNA encoding. + # We don't want this operation to fail though, so we try + # utf8 as a last resort. + host = host.encode("idna", "strict") + except UnicodeError: + host = host.encode("utf8", "surrogateescape") + + self.data.host = host + + # Update host header + if "host" in self.headers: + if host: + self.headers["host"] = host + else: + self.headers.pop("host") + + @property + def port(self): + """ + Target port + """ + return self.data.port + + @port.setter + def port(self, port): + self.data.port = port + + @property + def path(self): + """ + HTTP request path, e.g. "/index.html". + Guaranteed to start with a slash. + """ + return _native(self.data.path) + + @path.setter + def path(self, path): + self.data.path = _always_bytes(path) + + def anticache(self): + """ + Modifies this request to remove headers that might produce a cached + response. That is, we remove ETags and If-Modified-Since headers. + """ + delheaders = [ + "if-modified-since", + "if-none-match", + ] + for i in delheaders: + self.headers.pop(i, None) + + def anticomp(self): + """ + Modifies this request to remove headers that will compress the + resource's data. + """ + self.headers["accept-encoding"] = "identity" + + def constrain_encoding(self): + """ + Limits the permissible Accept-Encoding values, based on what we can + decode appropriately. + """ + accept_encoding = self.headers.get("accept-encoding") + if accept_encoding: + self.headers["accept-encoding"] = ( + ', '.join( + e + for e in encoding.ENCODINGS + if e in accept_encoding + ) + ) + + @property + def urlencoded_form(self): + """ + The URL-encoded form data as an ODict object. + None if there is no data or the content-type indicates non-form data. + """ + is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() + if self.content and is_valid_content_type: + return ODict(utils.urldecode(self.content)) + return None + + @urlencoded_form.setter + def urlencoded_form(self, odict): + """ + Sets the body to the URL-encoded form data, and adds the appropriate content-type header. + This will overwrite the existing content if there is one. + """ + self.headers["content-type"] = "application/x-www-form-urlencoded" + self.content = utils.urlencode(odict.lst) + + @property + def multipart_form(self): + """ + The multipart form data as an ODict object. + None if there is no data or the content-type indicates non-form data. + """ + is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() + if self.content and is_valid_content_type: + return ODict(utils.multipartdecode(self.headers,self.content)) + return None + + @multipart_form.setter + def multipart_form(self): + raise NotImplementedError() + + @property + def path_components(self): + """ + The URL's path components as a list of strings. + Components are unquoted. + """ + _, _, path, _, _, _ = urllib.parse.urlparse(self.url) + return [urllib.parse.unquote(i) for i in path.split("/") if i] + + @path_components.setter + def path_components(self, components): + components = map(lambda x: urllib.parse.quote(x, safe=""), components) + path = "/" + "/".join(components) + scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) + self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]) + + @property + def query(self): + """ + The request query string as an ODict object. + None, if there is no query. + """ + _, _, _, _, query, _ = urllib.parse.urlparse(self.url) + if query: + return ODict(utils.urldecode(query)) + return None + + @query.setter + def query(self, odict): + query = utils.urlencode(odict.lst) + scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) + self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]) + + @property + def cookies(self): + """ + The request cookies. + An empty ODict object if the cookie monster ate them all. + """ + ret = ODict() + for i in self.headers.get_all("Cookie"): + ret.extend(cookies.parse_cookie_header(i)) + return ret + + @cookies.setter + def cookies(self, odict): + self.headers["cookie"] = cookies.format_cookie_header(odict) + + @property + def url(self): + """ + The URL string, constructed from the request's URL components + """ + return utils.unparse_url(self.scheme, self.host, self.port, self.path) + + @url.setter + def url(self, url): + self.scheme, self.host, self.port, self.path = utils.parse_url(url) + + @property + def pretty_host(self): + return self.headers.get("host", self.host) + + @property + def pretty_url(self): + if self.first_line_format == "authority": + return "%s:%d" % (self.pretty_host, self.port) + return utils.unparse_url(self.scheme, self.pretty_host, self.port, self.path) + + # Legacy + + def get_cookies(self): + warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning) + return self.cookies + + def set_cookies(self, odict): + warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning) + self.cookies = odict + + def get_query(self): + warnings.warn(".get_query is deprecated, use .query instead.", DeprecationWarning) + return self.query or ODict([]) + + def set_query(self, odict): + warnings.warn(".set_query is deprecated, use .query instead.", DeprecationWarning) + self.query = odict + + def get_path_components(self): + warnings.warn(".get_path_components is deprecated, use .path_components instead.", DeprecationWarning) + return self.path_components + + def set_path_components(self, lst): + warnings.warn(".set_path_components is deprecated, use .path_components instead.", DeprecationWarning) + self.path_components = lst + + def get_form_urlencoded(self): + warnings.warn(".get_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) + return self.urlencoded_form or ODict([]) + + def set_form_urlencoded(self, odict): + warnings.warn(".set_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) + self.urlencoded_form = odict + + def get_form_multipart(self): + warnings.warn(".get_form_multipart is deprecated, use .multipart_form instead.", DeprecationWarning) + return self.multipart_form or ODict([]) + + @property + def form_in(self): + warnings.warn(".form_in is deprecated, use .first_line_format instead.", DeprecationWarning) + return self.first_line_format + + @form_in.setter + def form_in(self, form_in): + warnings.warn(".form_in is deprecated, use .first_line_format instead.", DeprecationWarning) + self.first_line_format = form_in + + @property + def form_out(self): + warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) + return self.first_line_format + + @form_out.setter + def form_out(self, form_out): + warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) + self.first_line_format = form_out \ No newline at end of file diff --git a/netlib/http/response.py b/netlib/http/response.py new file mode 100644 index 00000000..02fac3df --- /dev/null +++ b/netlib/http/response.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import, print_function, division + +# TODO \ No newline at end of file diff --git a/netlib/tutils.py b/netlib/tutils.py index 1665a792..ff63c33c 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -98,7 +98,7 @@ def treq(**kwargs): netlib.http.Request """ default = dict( - form_in="relative", + first_line_format="relative", method=b"GET", scheme=b"http", host=b"address", @@ -106,7 +106,7 @@ def treq(**kwargs): path=b"/path", http_version=b"HTTP/1.1", headers=Headers(header="qvalue"), - body=b"content" + content=b"content" ) default.update(kwargs) return Request(**default) diff --git a/netlib/utils.py b/netlib/utils.py index 6f6d1ea0..3ec60890 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -273,22 +273,27 @@ def get_header_tokens(headers, key): return [token.strip() for token in tokens] -@always_byte_args() def hostport(scheme, host, port): """ Returns the host component, with a port specifcation if needed. """ - if (port, scheme) in [(80, b"http"), (443, b"https")]: + if (port, scheme) in [(80, "http"), (443, "https"), (80, b"http"), (443, b"https")]: return host else: - return b"%s:%d" % (host, port) + if isinstance(host, six.binary_type): + return b"%s:%d" % (host, port) + else: + return "%s:%d" % (host, port) def unparse_url(scheme, host, port, path=""): """ - Returns a URL string, constructed from the specified compnents. + Returns a URL string, constructed from the specified components. + + Args: + All args must be str. """ - return b"%s://%s%s" % (scheme, hostport(scheme, host, port), path) + return "%s://%s%s" % (scheme, hostport(scheme, host, port), path) def urlencode(s): diff --git a/test/http/http1/test_assemble.py b/test/http/http1/test_assemble.py index 963e7549..47d11d33 100644 --- a/test/http/http1/test_assemble.py +++ b/test/http/http1/test_assemble.py @@ -20,7 +20,7 @@ def test_assemble_request(): ) with raises(HttpException): - assemble_request(treq(body=CONTENT_MISSING)) + assemble_request(treq(content=CONTENT_MISSING)) def test_assemble_request_head(): @@ -62,21 +62,21 @@ def test_assemble_body(): def test_assemble_request_line(): - assert _assemble_request_line(treq()) == b"GET /path HTTP/1.1" + assert _assemble_request_line(treq().data) == b"GET /path HTTP/1.1" - authority_request = treq(method=b"CONNECT", form_in="authority") + authority_request = treq(method=b"CONNECT", first_line_format="authority").data assert _assemble_request_line(authority_request) == b"CONNECT address:22 HTTP/1.1" - absolute_request = treq(form_in="absolute") + absolute_request = treq(first_line_format="absolute").data assert _assemble_request_line(absolute_request) == b"GET http://address:22/path HTTP/1.1" with raises(RuntimeError): - _assemble_request_line(treq(), "invalid_form") + _assemble_request_line(treq(first_line_format="invalid_form").data) def test_assemble_request_headers(): # https://github.com/mitmproxy/mitmproxy/issues/186 - r = treq(body=b"") + r = treq(content=b"") r.headers["Transfer-Encoding"] = "chunked" c = _assemble_request_headers(r) assert b"Transfer-Encoding" in c diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py index 9eb02a24..c3f744bf 100644 --- a/test/http/http1/test_read.py +++ b/test/http/http1/test_read.py @@ -16,8 +16,8 @@ from netlib.tutils import treq, tresp, raises def test_read_request(): rfile = BytesIO(b"GET / HTTP/1.1\r\n\r\nskip") r = read_request(rfile) - assert r.method == b"GET" - assert r.body == b"" + assert r.method == "GET" + assert r.content == b"" assert r.timestamp_end assert rfile.read() == b"skip" @@ -32,7 +32,7 @@ def test_read_request_head(): rfile.reset_timestamps = Mock() rfile.first_byte_timestamp = 42 r = read_request_head(rfile) - assert r.method == b"GET" + assert r.method == "GET" assert r.headers["Content-Length"] == "4" assert r.body is None assert rfile.reset_timestamps.called @@ -283,7 +283,7 @@ class TestReadHeaders(object): def test_read_chunked(): - req = treq(body=None) + req = treq(content=None) req.headers["Transfer-Encoding"] = "chunked" data = b"1\r\na\r\n0\r\n" diff --git a/test/http/test_models.py b/test/http/test_models.py index 10e0795a..3c196847 100644 --- a/test/http/test_models.py +++ b/test/http/test_models.py @@ -39,6 +39,7 @@ class TestRequest(object): a = tutils.treq(timestamp_start=42, timestamp_end=43) b = tutils.treq(timestamp_start=42, timestamp_end=43) assert a == b + assert not a != b assert not a == 'foo' assert not b == 'foo' @@ -70,45 +71,17 @@ class TestRequest(object): req = tutils.treq() req.headers["Host"] = "" req.host = "foobar" - req.update_host_header() assert req.headers["Host"] == "foobar" - def test_get_form(self): - req = tutils.treq() - assert req.get_form() == ODict() - - @mock.patch("netlib.http.Request.get_form_multipart") - @mock.patch("netlib.http.Request.get_form_urlencoded") - def test_get_form_with_url_encoded(self, mock_method_urlencoded, mock_method_multipart): - req = tutils.treq() - assert req.get_form() == ODict() - - req = tutils.treq() - req.body = "foobar" - req.headers["Content-Type"] = HDR_FORM_URLENCODED - req.get_form() - assert req.get_form_urlencoded.called - assert not req.get_form_multipart.called - - @mock.patch("netlib.http.Request.get_form_multipart") - @mock.patch("netlib.http.Request.get_form_urlencoded") - def test_get_form_with_multipart(self, mock_method_urlencoded, mock_method_multipart): - req = tutils.treq() - req.body = "foobar" - req.headers["Content-Type"] = HDR_FORM_MULTIPART - req.get_form() - assert not req.get_form_urlencoded.called - assert req.get_form_multipart.called - def test_get_form_urlencoded(self): - req = tutils.treq(body="foobar") + req = tutils.treq(content="foobar") assert req.get_form_urlencoded() == ODict() req.headers["Content-Type"] = HDR_FORM_URLENCODED assert req.get_form_urlencoded() == ODict(utils.urldecode(req.body)) def test_get_form_multipart(self): - req = tutils.treq(body="foobar") + req = tutils.treq(content="foobar") assert req.get_form_multipart() == ODict() req.headers["Content-Type"] = HDR_FORM_MULTIPART @@ -140,7 +113,7 @@ class TestRequest(object): assert req.get_query().lst == [] req.url = "http://localhost:80/foo?bar=42" - assert req.get_query().lst == [(b"bar", b"42")] + assert req.get_query().lst == [("bar", "42")] def test_set_query(self): req = tutils.treq() @@ -148,31 +121,23 @@ class TestRequest(object): def test_pretty_host(self): r = tutils.treq() - assert r.pretty_host(True) == "address" - assert r.pretty_host(False) == "address" + assert r.pretty_host == "address" + assert r.host == "address" r.headers["host"] = "other" - assert r.pretty_host(True) == "other" - assert r.pretty_host(False) == "address" + assert r.pretty_host == "other" + assert r.host == "address" r.host = None - assert r.pretty_host(True) == "other" - assert r.pretty_host(False) is None - del r.headers["host"] - assert r.pretty_host(True) is None - assert r.pretty_host(False) is None + assert r.pretty_host is None + assert r.host is None # Invalid IDNA r.headers["host"] = ".disqus.com" - assert r.pretty_host(True) == ".disqus.com" + assert r.pretty_host == ".disqus.com" def test_pretty_url(self): - req = tutils.treq() - req.form_out = "authority" - assert req.pretty_url(True) == b"address:22" - assert req.pretty_url(False) == b"address:22" - - req.form_out = "relative" - assert req.pretty_url(True) == b"http://address:22/path" - assert req.pretty_url(False) == b"http://address:22/path" + req = tutils.treq(first_line_format="relative") + assert req.pretty_url == "http://address:22/path" + assert req.url == "http://address:22/path" def test_get_cookies_none(self): headers = Headers() @@ -212,12 +177,12 @@ class TestRequest(object): assert r.get_cookies()["cookiename"] == ["foo"] def test_set_url(self): - r = tutils.treq(form_in="absolute") + r = tutils.treq(first_line_format="absolute") r.url = b"https://otheraddress:42/ORLY" - assert r.scheme == b"https" - assert r.host == b"otheraddress" + assert r.scheme == "https" + assert r.host == "otheraddress" assert r.port == 42 - assert r.path == b"/ORLY" + assert r.path == "/ORLY" try: r.url = "//localhost:80/foo@bar" @@ -230,7 +195,7 @@ class TestRequest(object): # protocol = mock_protocol("OPTIONS * HTTP/1.1") # f.request = HTTPRequest.from_protocol(protocol) # - # assert f.request.form_in == "relative" + # assert f.request.first_line_format == "relative" # f.request.host = f.server_conn.address.host # f.request.port = f.server_conn.address.port # f.request.scheme = "http" @@ -266,7 +231,7 @@ class TestRequest(object): # "CONNECT address:22 HTTP/1.1\r\n" # "Host: address:22\r\n" # "Content-Length: 0\r\n\r\n") - # assert r.pretty_url(False) == "address:22" + # assert r.pretty_url == "address:22" # # def test_absolute_form_in(self): # protocol = mock_protocol("GET oops-no-protocol.com HTTP/1.1") diff --git a/test/http/test_request.py b/test/http/test_request.py new file mode 100644 index 00000000..02fac3df --- /dev/null +++ b/test/http/test_request.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import, print_function, division + +# TODO \ No newline at end of file diff --git a/test/http/test_response.py b/test/http/test_response.py new file mode 100644 index 00000000..02fac3df --- /dev/null +++ b/test/http/test_response.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import, print_function, division + +# TODO \ No newline at end of file diff --git a/test/test_utils.py b/test/test_utils.py index 17636cc4..b096e5bc 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -84,10 +84,10 @@ def test_parse_url(): def test_unparse_url(): - assert utils.unparse_url(b"http", b"foo.com", 99, b"") == b"http://foo.com:99" - assert utils.unparse_url(b"http", b"foo.com", 80, b"/bar") == b"http://foo.com/bar" - assert utils.unparse_url(b"https", b"foo.com", 80, b"") == b"https://foo.com:80" - assert utils.unparse_url(b"https", b"foo.com", 443, b"") == b"https://foo.com" + assert utils.unparse_url("http", "foo.com", 99, "") == "http://foo.com:99" + assert utils.unparse_url("http", "foo.com", 80, "/bar") == "http://foo.com/bar" + assert utils.unparse_url("https", "foo.com", 80, "") == "https://foo.com:80" + assert utils.unparse_url("https", "foo.com", 443, "") == "https://foo.com" def test_urlencode(): diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 4ae4cf45..9a1e5d3d 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -68,7 +68,7 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.write(bytes(headers) + b"\r\n") self.wfile.flush() - resp = read_response(self.rfile, treq(method="GET")) + resp = read_response(self.rfile, treq(method=b"GET")) server_nonce = self.protocol.check_server_handshake(resp.headers) if not server_nonce == self.protocol.create_server_nonce(self.client_nonce): -- cgit v1.2.3 From 49ea8fc0ebcfe4861f099200044a553f092faec7 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 26 Sep 2015 17:39:50 +0200 Subject: refactor response model --- netlib/http/__init__.py | 15 ++-- netlib/http/headers.py | 26 +++---- netlib/http/http1/assemble.py | 16 ++-- netlib/http/http1/read.py | 2 +- netlib/http/http2/connections.py | 4 +- netlib/http/http2/frame.py | 3 - netlib/http/message.py | 64 +++++++++------- netlib/http/models.py | 112 ---------------------------- netlib/http/request.py | 155 +++++++++++++++++++++------------------ netlib/http/response.py | 124 ++++++++++++++++++++++++++++++- netlib/tutils.py | 6 +- netlib/wsgi.py | 6 +- test/http/http1/test_assemble.py | 4 +- test/http/http1/test_read.py | 6 +- test/http/http2/test_protocol.py | 12 +-- test/http/test_models.py | 12 ++- 16 files changed, 293 insertions(+), 274 deletions(-) delete mode 100644 netlib/http/models.py diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index e8c7ba20..fd632cd5 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,17 +1,14 @@ from __future__ import absolute_import, print_function, division -from .headers import Headers -from .message import decoded from .request import Request -from .models import Response -from .models import ALPN_PROTO_HTTP1, ALPN_PROTO_H2 -from .models import HDR_FORM_MULTIPART, HDR_FORM_URLENCODED, CONTENT_MISSING +from .response import Response +from .headers import Headers +from .message import decoded, CONTENT_MISSING from . import http1, http2 __all__ = [ + "Request", + "Response", "Headers", - "decoded", - "Request", "Response", - "ALPN_PROTO_HTTP1", "ALPN_PROTO_H2", - "HDR_FORM_MULTIPART", "HDR_FORM_URLENCODED", "CONTENT_MISSING", + "decoded", "CONTENT_MISSING", "http1", "http2", ] diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 47ea923b..c79c3344 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -36,12 +36,8 @@ class Headers(MutableMapping): .. code-block:: python - # Create header from a list of (header_name, header_value) tuples - >>> h = Headers([ - ["Host","example.com"], - ["Accept","text/html"], - ["accept","application/xml"] - ]) + # Create headers with keyword arguments + >>> h = Headers(host="example.com", content_type="application/xml") # Headers mostly behave like a normal dict. >>> h["Host"] @@ -51,6 +47,13 @@ class Headers(MutableMapping): >>> h["host"] "example.com" + # Headers can also be creatd from a list of raw (header_name, header_value) byte tuples + >>> h = Headers([ + [b"Host",b"example.com"], + [b"Accept",b"text/html"], + [b"accept",b"application/xml"] + ]) + # Multiple headers are folded into a single header as per RFC7230 >>> h["Accept"] "text/html, application/xml" @@ -60,17 +63,14 @@ class Headers(MutableMapping): >>> h["Accept"] "application/text" - # str(h) returns a HTTP1 header block. - >>> print(h) + # bytes(h) returns a HTTP1 header block. + >>> print(bytes(h)) Host: example.com Accept: application/text # For full control, the raw header fields can be accessed >>> h.fields - # Headers can also be crated from keyword arguments - >>> h = Headers(host="example.com", content_type="application/xml") - Caveats: For use with the "Set-Cookie" header, see :py:meth:`get_all`. """ @@ -79,8 +79,8 @@ class Headers(MutableMapping): def __init__(self, fields=None, **headers): """ Args: - fields: (optional) list of ``(name, value)`` header tuples, - e.g. ``[("Host","example.com")]``. All names and values must be bytes. + fields: (optional) list of ``(name, value)`` header byte tuples, + e.g. ``[(b"Host", b"example.com")]``. All names and values must be bytes. **headers: Additional headers to set. Will overwrite existing values from `fields`. For convenience, underscores in header names will be transformed to dashes - this behaviour does not extend to other methods. diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index 864f6017..785ee8d3 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -10,7 +10,7 @@ def assemble_request(request): if request.content == CONTENT_MISSING: raise HttpException("Cannot assemble flow with CONTENT_MISSING") head = assemble_request_head(request) - body = b"".join(assemble_body(request.headers, [request.data.content])) + body = b"".join(assemble_body(request.data.headers, [request.data.content])) return head + body @@ -24,13 +24,13 @@ def assemble_response(response): if response.content == CONTENT_MISSING: raise HttpException("Cannot assemble flow with CONTENT_MISSING") head = assemble_response_head(response) - body = b"".join(assemble_body(response.headers, [response.content])) + body = b"".join(assemble_body(response.data.headers, [response.data.content])) return head + body def assemble_response_head(response): - first_line = _assemble_response_line(response) - headers = _assemble_response_headers(response) + first_line = _assemble_response_line(response.data) + headers = _assemble_response_headers(response.data) return b"%s\r\n%s\r\n" % (first_line, headers) @@ -92,11 +92,11 @@ def _assemble_request_headers(request_data): return bytes(headers) -def _assemble_response_line(response): +def _assemble_response_line(response_data): return b"%s %d %s" % ( - response.http_version, - response.status_code, - response.msg, + response_data.http_version, + response_data.status_code, + response_data.reason, ) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 76721e06..0d5e7f4b 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -50,7 +50,7 @@ def read_request_head(rfile): def read_response(rfile, request, body_size_limit=None): response = read_response_head(rfile) expected_body_size = expected_http_body_size(request, response) - response._body = b"".join(read_body(rfile, expected_body_size, body_size_limit)) + response.data.content = b"".join(read_body(rfile, expected_body_size, body_size_limit)) response.timestamp_end = time.time() return response diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index 5220d5d2..c493abe6 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -4,7 +4,7 @@ import time from hpack.hpack import Encoder, Decoder from ... import utils -from .. import Headers, Response, Request, ALPN_PROTO_H2 +from .. import Headers, Response, Request from . import frame @@ -283,7 +283,7 @@ class HTTP2Protocol(object): def check_alpn(self): alp = self.tcp_handler.get_alpn_proto_negotiated() - if alp != ALPN_PROTO_H2: + if alp != b'h2': raise NotImplementedError( "HTTP2Protocol can not handle unknown ALP: %s" % alp) return True diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py index cb2cde99..188629d4 100644 --- a/netlib/http/http2/frame.py +++ b/netlib/http/http2/frame.py @@ -25,9 +25,6 @@ ERROR_CODES = BiDi( CLIENT_CONNECTION_PREFACE = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" -ALPN_PROTO_H2 = b'h2' - - class Frame(object): """ diff --git a/netlib/http/message.py b/netlib/http/message.py index 20497bd5..ee138746 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -6,11 +6,14 @@ import six from .. import encoding, utils + +CONTENT_MISSING = 0 + if six.PY2: _native = lambda x: x _always_bytes = lambda x: x else: - # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. + # While the HTTP head _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. _native = lambda x: x.decode("utf-8", "surrogateescape") _always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape") @@ -27,17 +30,6 @@ class Message(object): def __ne__(self, other): return not self.__eq__(other) - @property - def http_version(self): - """ - Version string, e.g. "HTTP/1.1" - """ - return _native(self.data.http_version) - - @http_version.setter - def http_version(self, http_version): - self.data.http_version = _always_bytes(http_version) - @property def headers(self): """ @@ -52,6 +44,32 @@ class Message(object): def headers(self, h): self.data.headers = h + @property + def content(self): + """ + The raw (encoded) HTTP message body + + See also: :py:attr:`text` + """ + return self.data.content + + @content.setter + def content(self, content): + self.data.content = content + if isinstance(content, bytes): + self.headers["content-length"] = str(len(content)) + + @property + def http_version(self): + """ + Version string, e.g. "HTTP/1.1" + """ + return _native(self.data.http_version) + + @http_version.setter + def http_version(self, http_version): + self.data.http_version = _always_bytes(http_version) + @property def timestamp_start(self): """ @@ -74,26 +92,14 @@ class Message(object): def timestamp_end(self, timestamp_end): self.data.timestamp_end = timestamp_end - @property - def content(self): - """ - The raw (encoded) HTTP message body - - See also: :py:attr:`text` - """ - return self.data.content - - @content.setter - def content(self, content): - self.data.content = content - if isinstance(content, bytes): - self.headers["content-length"] = str(len(content)) - @property def text(self): """ The decoded HTTP message body. - Decoded contents are not cached, so this method is relatively expensive to call. + Decoded contents are not cached, so accessing this attribute repeatedly is relatively expensive. + + .. note:: + This is not implemented yet. See also: :py:attr:`content`, :py:class:`decoded` """ @@ -104,6 +110,8 @@ class Message(object): def text(self, text): raise NotImplementedError() + # Legacy + @property def body(self): warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) diff --git a/netlib/http/models.py b/netlib/http/models.py deleted file mode 100644 index 40f6e98c..00000000 --- a/netlib/http/models.py +++ /dev/null @@ -1,112 +0,0 @@ - - -from ..odict import ODict -from .. import utils, encoding -from ..utils import always_bytes, native -from . import cookies -from .headers import Headers - -from six.moves import urllib - -# TODO: Move somewhere else? -ALPN_PROTO_HTTP1 = b'http/1.1' -ALPN_PROTO_H2 = b'h2' -HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" -HDR_FORM_MULTIPART = "multipart/form-data" - -CONTENT_MISSING = 0 - - -class Message(object): - def __init__(self, http_version, headers, body, timestamp_start, timestamp_end): - self.http_version = http_version - if not headers: - headers = Headers() - assert isinstance(headers, Headers) - self.headers = headers - - self._body = body - self.timestamp_start = timestamp_start - self.timestamp_end = timestamp_end - - @property - def body(self): - return self._body - - @body.setter - def body(self, body): - self._body = body - if isinstance(body, bytes): - self.headers["content-length"] = str(len(body)).encode() - - content = body - - def __eq__(self, other): - if isinstance(other, Message): - return self.__dict__ == other.__dict__ - return False - - -class Response(Message): - def __init__( - self, - http_version, - status_code, - msg=None, - headers=None, - body=None, - timestamp_start=None, - timestamp_end=None, - ): - super(Response, self).__init__(http_version, headers, body, timestamp_start, timestamp_end) - self.status_code = status_code - self.msg = msg - - def __repr__(self): - # return "Response(%s - %s)" % (self.status_code, self.msg) - - if self.body: - size = utils.pretty_size(len(self.body)) - else: - size = "content missing" - # TODO: Remove "(unknown content type, content missing)" edge-case - return "".format( - status_code=self.status_code, - msg=self.msg, - contenttype=self.headers.get("content-type", "unknown content type"), - size=size) - - def get_cookies(self): - """ - Get the contents of all Set-Cookie headers. - - Returns a possibly empty ODict, where keys are cookie name strings, - and values are [value, attr] lists. Value is a string, and attr is - an ODictCaseless containing cookie attributes. Within attrs, unary - attributes (e.g. HTTPOnly) are indicated by a Null value. - """ - ret = [] - for header in self.headers.get_all("set-cookie"): - v = cookies.parse_set_cookie_header(header) - if v: - name, value, attrs = v - ret.append([name, [value, attrs]]) - return ODict(ret) - - def set_cookies(self, odict): - """ - Set the Set-Cookie headers on this response, over-writing existing - headers. - - Accepts an ODict of the same format as that returned by get_cookies. - """ - values = [] - for i in odict.lst: - values.append( - cookies.format_set_cookie_header( - i[0], - i[1][0], - i[1][1] - ) - ) - self.headers.set_all("set-cookie", values) diff --git a/netlib/http/request.py b/netlib/http/request.py index 6830ca40..f8a3b5b9 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -55,7 +55,7 @@ class Request(Message): else: hostport = "" path = self.path or "" - return "HTTPRequest({} {}{})".format( + return "Request({} {}{})".format( self.method, hostport, path ) @@ -97,7 +97,8 @@ class Request(Message): @property def host(self): """ - Target host for the request. This may be directly taken in the request (e.g. "GET http://example.com/ HTTP/1.1") + Target host. This may be parsed from the raw request + (e.g. from a ``GET http://example.com/ HTTP/1.1`` request line) or inferred from the proxy mode (e.g. an IP in transparent mode). """ @@ -154,6 +155,83 @@ class Request(Message): def path(self, path): self.data.path = _always_bytes(path) + @property + def url(self): + """ + The URL string, constructed from the request's URL components + """ + return utils.unparse_url(self.scheme, self.host, self.port, self.path) + + @url.setter + def url(self, url): + self.scheme, self.host, self.port, self.path = utils.parse_url(url) + + @property + def pretty_host(self): + """ + Similar to :py:attr:`host`, but using the Host headers as an additional preferred data source. + This is useful in transparent mode where :py:attr:`host` is only an IP address, + but may not reflect the actual destination as the Host header could be spoofed. + """ + return self.headers.get("host", self.host) + + @property + def pretty_url(self): + """ + Like :py:attr:`url`, but using :py:attr:`pretty_host` instead of :py:attr:`host`. + """ + if self.first_line_format == "authority": + return "%s:%d" % (self.pretty_host, self.port) + return utils.unparse_url(self.scheme, self.pretty_host, self.port, self.path) + + @property + def query(self): + """ + The request query string as an :py:class:`ODict` object. + None, if there is no query. + """ + _, _, _, _, query, _ = urllib.parse.urlparse(self.url) + if query: + return ODict(utils.urldecode(query)) + return None + + @query.setter + def query(self, odict): + query = utils.urlencode(odict.lst) + scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) + self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]) + + @property + def cookies(self): + """ + The request cookies. + An empty :py:class:`ODict` object if the cookie monster ate them all. + """ + ret = ODict() + for i in self.headers.get_all("Cookie"): + ret.extend(cookies.parse_cookie_header(i)) + return ret + + @cookies.setter + def cookies(self, odict): + self.headers["cookie"] = cookies.format_cookie_header(odict) + + @property + def path_components(self): + """ + The URL's path components as a list of strings. + Components are unquoted. + """ + _, _, path, _, _, _ = urllib.parse.urlparse(self.url) + return [urllib.parse.unquote(i) for i in path.split("/") if i] + + @path_components.setter + def path_components(self, components): + components = map(lambda x: urllib.parse.quote(x, safe=""), components) + path = "/" + "/".join(components) + scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) + self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]) + def anticache(self): """ Modifies this request to remove headers that might produce a cached @@ -191,7 +269,7 @@ class Request(Message): @property def urlencoded_form(self): """ - The URL-encoded form data as an ODict object. + The URL-encoded form data as an :py:class:`ODict` object. None if there is no data or the content-type indicates non-form data. """ is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() @@ -211,7 +289,7 @@ class Request(Message): @property def multipart_form(self): """ - The multipart form data as an ODict object. + The multipart form data as an :py:class:`ODict` object. None if there is no data or the content-type indicates non-form data. """ is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() @@ -223,75 +301,6 @@ class Request(Message): def multipart_form(self): raise NotImplementedError() - @property - def path_components(self): - """ - The URL's path components as a list of strings. - Components are unquoted. - """ - _, _, path, _, _, _ = urllib.parse.urlparse(self.url) - return [urllib.parse.unquote(i) for i in path.split("/") if i] - - @path_components.setter - def path_components(self, components): - components = map(lambda x: urllib.parse.quote(x, safe=""), components) - path = "/" + "/".join(components) - scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) - self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]) - - @property - def query(self): - """ - The request query string as an ODict object. - None, if there is no query. - """ - _, _, _, _, query, _ = urllib.parse.urlparse(self.url) - if query: - return ODict(utils.urldecode(query)) - return None - - @query.setter - def query(self, odict): - query = utils.urlencode(odict.lst) - scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) - self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]) - - @property - def cookies(self): - """ - The request cookies. - An empty ODict object if the cookie monster ate them all. - """ - ret = ODict() - for i in self.headers.get_all("Cookie"): - ret.extend(cookies.parse_cookie_header(i)) - return ret - - @cookies.setter - def cookies(self, odict): - self.headers["cookie"] = cookies.format_cookie_header(odict) - - @property - def url(self): - """ - The URL string, constructed from the request's URL components - """ - return utils.unparse_url(self.scheme, self.host, self.port, self.path) - - @url.setter - def url(self, url): - self.scheme, self.host, self.port, self.path = utils.parse_url(url) - - @property - def pretty_host(self): - return self.headers.get("host", self.host) - - @property - def pretty_url(self): - if self.first_line_format == "authority": - return "%s:%d" % (self.pretty_host, self.port) - return utils.unparse_url(self.scheme, self.pretty_host, self.port, self.path) - # Legacy def get_cookies(self): diff --git a/netlib/http/response.py b/netlib/http/response.py index 02fac3df..7d64243d 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -1,3 +1,125 @@ from __future__ import absolute_import, print_function, division -# TODO \ No newline at end of file +import warnings + +from . import cookies +from .headers import Headers +from .message import Message, _native, _always_bytes +from .. import utils +from ..odict import ODict + + +class ResponseData(object): + def __init__(self, http_version, status_code, reason=None, headers=None, content=None, + timestamp_start=None, timestamp_end=None): + if not headers: + headers = Headers() + assert isinstance(headers, Headers) + + self.http_version = http_version + self.status_code = status_code + self.reason = reason + self.headers = headers + self.content = content + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + + def __eq__(self, other): + if isinstance(other, ResponseData): + return self.__dict__ == other.__dict__ + return False + + def __ne__(self, other): + return not self.__eq__(other) + + +class Response(Message): + """ + An HTTP response. + """ + def __init__(self, *args, **kwargs): + data = ResponseData(*args, **kwargs) + super(Response, self).__init__(data) + + def __repr__(self): + if self.content: + details = "{}, {}".format( + self.headers.get("content-type", "unknown content type"), + utils.pretty_size(len(self.content)) + ) + else: + details = "content missing" + return "Response({status_code} {reason}, {details})".format( + status_code=self.status_code, + reason=self.reason, + details=details + ) + + @property + def status_code(self): + """ + HTTP Status Code, e.g. ``200``. + """ + return self.data.status_code + + @status_code.setter + def status_code(self, status_code): + self.data.status_code = status_code + + @property + def reason(self): + """ + HTTP Reason Phrase, e.g. "Not Found". + This is always :py:obj:`None` for HTTP2 requests, because HTTP2 responses do not contain a reason phrase. + """ + return _native(self.data.reason) + + @reason.setter + def reason(self, reason): + self.data.reason = _always_bytes(reason) + + @property + def cookies(self): + """ + Get the contents of all Set-Cookie headers. + + A possibly empty :py:class:`ODict`, where keys are cookie name strings, + and values are [value, attr] lists. Value is a string, and attr is + an ODictCaseless containing cookie attributes. Within attrs, unary + attributes (e.g. HTTPOnly) are indicated by a Null value. + """ + ret = [] + for header in self.headers.get_all("set-cookie"): + v = cookies.parse_set_cookie_header(header) + if v: + name, value, attrs = v + ret.append([name, [value, attrs]]) + return ODict(ret) + + @cookies.setter + def cookies(self, odict): + values = [] + for i in odict.lst: + header = cookies.format_set_cookie_header(i[0], i[1][0], i[1][1]) + values.append(header) + self.headers.set_all("set-cookie", values) + + # Legacy + + def get_cookies(self): + warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning) + return self.cookies + + def set_cookies(self, odict): + warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning) + self.cookies = odict + + @property + def msg(self): + warnings.warn(".msg is deprecated, use .reason instead.", DeprecationWarning) + return self.reason + + @msg.setter + def msg(self, reason): + warnings.warn(".msg is deprecated, use .reason instead.", DeprecationWarning) + self.reason = reason diff --git a/netlib/tutils.py b/netlib/tutils.py index ff63c33c..e16f1a76 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -120,9 +120,9 @@ def tresp(**kwargs): default = dict( http_version=b"HTTP/1.1", status_code=200, - msg=b"OK", - headers=Headers(header_response=b"svalue"), - body=b"message", + reason=b"OK", + headers=Headers(header_response="svalue"), + content=b"message", timestamp_start=time.time(), timestamp_end=time.time(), ) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 4fcd5178..df248a19 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -25,9 +25,9 @@ class Flow(object): class Request(object): - def __init__(self, scheme, method, path, http_version, headers, body): + def __init__(self, scheme, method, path, http_version, headers, content): self.scheme, self.method, self.path = scheme, method, path - self.headers, self.body = headers, body + self.headers, self.content = headers, content self.http_version = http_version @@ -64,7 +64,7 @@ class WSGIAdaptor(object): environ = { 'wsgi.version': (1, 0), 'wsgi.url_scheme': native(flow.request.scheme, "latin-1"), - 'wsgi.input': BytesIO(flow.request.body or b""), + 'wsgi.input': BytesIO(flow.request.content or b""), 'wsgi.errors': errsoc, 'wsgi.multithread': True, 'wsgi.multiprocess': False, diff --git a/test/http/http1/test_assemble.py b/test/http/http1/test_assemble.py index 47d11d33..460e22c5 100644 --- a/test/http/http1/test_assemble.py +++ b/test/http/http1/test_assemble.py @@ -40,7 +40,7 @@ def test_assemble_response(): ) with raises(HttpException): - assemble_response(tresp(body=CONTENT_MISSING)) + assemble_response(tresp(content=CONTENT_MISSING)) def test_assemble_response_head(): @@ -86,7 +86,7 @@ def test_assemble_request_headers(): def test_assemble_response_headers(): # https://github.com/mitmproxy/mitmproxy/issues/186 - r = tresp(body=b"") + r = tresp(content=b"") r.headers["Transfer-Encoding"] = "chunked" c = _assemble_response_headers(r) assert b"Transfer-Encoding" in c diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py index c3f744bf..fadfe446 100644 --- a/test/http/http1/test_read.py +++ b/test/http/http1/test_read.py @@ -34,7 +34,7 @@ def test_read_request_head(): r = read_request_head(rfile) assert r.method == "GET" assert r.headers["Content-Length"] == "4" - assert r.body is None + assert r.content is None assert rfile.reset_timestamps.called assert r.timestamp_start == 42 assert rfile.read() == b"skip" @@ -45,7 +45,7 @@ def test_read_response(): rfile = BytesIO(b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody") r = read_response(rfile, req) assert r.status_code == 418 - assert r.body == b"body" + assert r.content == b"body" assert r.timestamp_end @@ -61,7 +61,7 @@ def test_read_response_head(): r = read_response_head(rfile) assert r.status_code == 418 assert r.headers["Content-Length"] == "4" - assert r.body is None + assert r.content is None assert rfile.reset_timestamps.called assert r.timestamp_start == 42 assert rfile.read() == b"skip" diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index a55941e0..6bda96f5 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -65,7 +65,7 @@ class TestProtocol: class TestCheckALPNMatch(tservers.ServerTestBase): handler = EchoHandler ssl = dict( - alpn_select=ALPN_PROTO_H2, + alpn_select=b'h2', ) if OpenSSL._util.lib.Cryptography_HAS_ALPN: @@ -73,7 +73,7 @@ class TestCheckALPNMatch(tservers.ServerTestBase): def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=[ALPN_PROTO_H2]) + c.convert_to_ssl(alpn_protos=[b'h2']) protocol = HTTP2Protocol(c) assert protocol.check_alpn() @@ -89,7 +89,7 @@ class TestCheckALPNMismatch(tservers.ServerTestBase): def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=[ALPN_PROTO_H2]) + c.convert_to_ssl(alpn_protos=[b'h2']) protocol = HTTP2Protocol(c) tutils.raises(NotImplementedError, protocol.check_alpn) @@ -311,7 +311,7 @@ class TestReadRequest(tservers.ServerTestBase): assert req.stream_id assert req.headers.fields == [[':method', 'GET'], [':path', '/'], [':scheme', 'https']] - assert req.body == b'foobar' + assert req.content == b'foobar' class TestReadRequestRelative(tservers.ServerTestBase): @@ -417,7 +417,7 @@ class TestReadResponse(tservers.ServerTestBase): assert resp.status_code == 200 assert resp.msg == "" assert resp.headers.fields == [[':status', '200'], ['etag', 'foobar']] - assert resp.body == b'foobar' + assert resp.content == b'foobar' assert resp.timestamp_end @@ -444,7 +444,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): assert resp.status_code == 200 assert resp.msg == "" assert resp.headers.fields == [[':status', '200'], ['etag', 'foobar']] - assert resp.body == b'' + assert resp.content == b'' class TestAssembleRequest(object): diff --git a/test/http/test_models.py b/test/http/test_models.py index 3c196847..aa267944 100644 --- a/test/http/test_models.py +++ b/test/http/test_models.py @@ -3,9 +3,7 @@ import mock from netlib import tutils from netlib import utils from netlib.odict import ODict, ODictCaseless -from netlib.http import Request, Response, Headers, CONTENT_MISSING, HDR_FORM_URLENCODED, \ - HDR_FORM_MULTIPART - +from netlib.http import Request, Response, Headers, CONTENT_MISSING class TestRequest(object): def test_repr(self): @@ -77,14 +75,14 @@ class TestRequest(object): req = tutils.treq(content="foobar") assert req.get_form_urlencoded() == ODict() - req.headers["Content-Type"] = HDR_FORM_URLENCODED + req.headers["Content-Type"] = "application/x-www-form-urlencoded" assert req.get_form_urlencoded() == ODict(utils.urldecode(req.body)) def test_get_form_multipart(self): req = tutils.treq(content="foobar") assert req.get_form_multipart() == ODict() - req.headers["Content-Type"] = HDR_FORM_MULTIPART + req.headers["Content-Type"] = "multipart/form-data" assert req.get_form_multipart() == ODict( utils.multipartdecode( req.headers, @@ -95,7 +93,7 @@ class TestRequest(object): def test_set_form_urlencoded(self): req = tutils.treq() req.set_form_urlencoded(ODict([('foo', 'bar'), ('rab', 'oof')])) - assert req.headers["Content-Type"] == HDR_FORM_URLENCODED + assert req.headers["Content-Type"] == "application/x-www-form-urlencoded" assert req.body def test_get_path_components(self): @@ -298,7 +296,7 @@ class TestResponse(object): assert "unknown content type" in repr(r) r.headers["content-type"] = "foo" assert "foo" in repr(r) - assert repr(tutils.tresp(body=CONTENT_MISSING)) + assert repr(tutils.tresp(content=CONTENT_MISSING)) def test_get_cookies_none(self): resp = tutils.tresp() -- cgit v1.2.3 From 466888b01a361e46fb3d4e66afa2c6a0fd168c8e Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 26 Sep 2015 20:07:11 +0200 Subject: improve request tests, coverage++ --- netlib/encoding.py | 4 + netlib/http/headers.py | 8 +- netlib/http/message.py | 42 ++++++- netlib/http/request.py | 28 ++--- netlib/http/response.py | 8 +- netlib/http/status_codes.py | 4 +- test/http/http1/test_read.py | 17 ++- test/http/test_headers.py | 3 + test/http/test_message.py | 136 +++++++++++++++++++++ test/http/test_models.py | 266 +---------------------------------------- test/http/test_request.py | 229 ++++++++++++++++++++++++++++++++++- test/http/test_status_codes.py | 6 + 12 files changed, 455 insertions(+), 296 deletions(-) create mode 100644 test/http/test_message.py create mode 100644 test/http/test_status_codes.py diff --git a/netlib/encoding.py b/netlib/encoding.py index 4c11273b..14479e00 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -12,6 +12,8 @@ ENCODINGS = {"identity", "gzip", "deflate"} def decode(e, content): + if not isinstance(content, bytes): + return None encoding_map = { "identity": identity, "gzip": decode_gzip, @@ -23,6 +25,8 @@ def decode(e, content): def encode(e, content): + if not isinstance(content, bytes): + return None encoding_map = { "identity": identity, "gzip": encode_gzip, diff --git a/netlib/http/headers.py b/netlib/http/headers.py index c79c3344..f64e6200 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -8,15 +8,15 @@ from __future__ import absolute_import, print_function, division import copy try: from collections.abc import MutableMapping -except ImportError: # Workaround for Python < 3.3 - from collections import MutableMapping +except ImportError: # pragma: nocover + from collections import MutableMapping # Workaround for Python < 3.3 import six from netlib.utils import always_byte_args, always_bytes -if six.PY2: +if six.PY2: # pragma: nocover _native = lambda x: x _always_bytes = lambda x: x _always_byte_args = lambda x: x @@ -106,7 +106,7 @@ class Headers(MutableMapping): else: return b"" - if six.PY2: + if six.PY2: # pragma: nocover __str__ = __bytes__ @_always_byte_args diff --git a/netlib/http/message.py b/netlib/http/message.py index ee138746..7cb18f52 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -9,7 +9,7 @@ from .. import encoding, utils CONTENT_MISSING = 0 -if six.PY2: +if six.PY2: # pragma: nocover _native = lambda x: x _always_bytes = lambda x: x else: @@ -110,15 +110,48 @@ class Message(object): def text(self, text): raise NotImplementedError() + def decode(self): + """ + Decodes body based on the current Content-Encoding header, then + removes the header. If there is no Content-Encoding header, no + action is taken. + + Returns: + True, if decoding succeeded. + False, otherwise. + """ + ce = self.headers.get("content-encoding") + data = encoding.decode(ce, self.content) + if data is None: + return False + self.content = data + self.headers.pop("content-encoding", None) + return True + + def encode(self, e): + """ + Encodes body with the encoding e, where e is "gzip", "deflate" or "identity". + + Returns: + True, if decoding succeeded. + False, otherwise. + """ + data = encoding.encode(e, self.content) + if data is None: + return False + self.content = data + self.headers["content-encoding"] = e + return True + # Legacy @property - def body(self): + def body(self): # pragma: nocover warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) return self.content @body.setter - def body(self, body): + def body(self, body): # pragma: nocover warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) self.content = body @@ -146,8 +179,7 @@ class decoded(object): def __enter__(self): if self.ce: - if not self.message.decode(): - self.ce = None + self.message.decode() def __exit__(self, type, value, tb): if self.ce: diff --git a/netlib/http/request.py b/netlib/http/request.py index f8a3b5b9..325c0080 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -102,7 +102,7 @@ class Request(Message): or inferred from the proxy mode (e.g. an IP in transparent mode). """ - if six.PY2: + if six.PY2: # pragma: nocover return self.data.host if not self.data.host: @@ -303,58 +303,58 @@ class Request(Message): # Legacy - def get_cookies(self): + def get_cookies(self): # pragma: nocover warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning) return self.cookies - def set_cookies(self, odict): + def set_cookies(self, odict): # pragma: nocover warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning) self.cookies = odict - def get_query(self): + def get_query(self): # pragma: nocover warnings.warn(".get_query is deprecated, use .query instead.", DeprecationWarning) return self.query or ODict([]) - def set_query(self, odict): + def set_query(self, odict): # pragma: nocover warnings.warn(".set_query is deprecated, use .query instead.", DeprecationWarning) self.query = odict - def get_path_components(self): + def get_path_components(self): # pragma: nocover warnings.warn(".get_path_components is deprecated, use .path_components instead.", DeprecationWarning) return self.path_components - def set_path_components(self, lst): + def set_path_components(self, lst): # pragma: nocover warnings.warn(".set_path_components is deprecated, use .path_components instead.", DeprecationWarning) self.path_components = lst - def get_form_urlencoded(self): + def get_form_urlencoded(self): # pragma: nocover warnings.warn(".get_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) return self.urlencoded_form or ODict([]) - def set_form_urlencoded(self, odict): + def set_form_urlencoded(self, odict): # pragma: nocover warnings.warn(".set_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) self.urlencoded_form = odict - def get_form_multipart(self): + def get_form_multipart(self): # pragma: nocover warnings.warn(".get_form_multipart is deprecated, use .multipart_form instead.", DeprecationWarning) return self.multipart_form or ODict([]) @property - def form_in(self): + def form_in(self): # pragma: nocover warnings.warn(".form_in is deprecated, use .first_line_format instead.", DeprecationWarning) return self.first_line_format @form_in.setter - def form_in(self, form_in): + def form_in(self, form_in): # pragma: nocover warnings.warn(".form_in is deprecated, use .first_line_format instead.", DeprecationWarning) self.first_line_format = form_in @property - def form_out(self): + def form_out(self): # pragma: nocover warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) return self.first_line_format @form_out.setter - def form_out(self, form_out): + def form_out(self, form_out): # pragma: nocover warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) self.first_line_format = form_out \ No newline at end of file diff --git a/netlib/http/response.py b/netlib/http/response.py index 7d64243d..db31d2b9 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -106,20 +106,20 @@ class Response(Message): # Legacy - def get_cookies(self): + def get_cookies(self): # pragma: nocover warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning) return self.cookies - def set_cookies(self, odict): + def set_cookies(self, odict): # pragma: nocover warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning) self.cookies = odict @property - def msg(self): + def msg(self): # pragma: nocover warnings.warn(".msg is deprecated, use .reason instead.", DeprecationWarning) return self.reason @msg.setter - def msg(self, reason): + def msg(self, reason): # pragma: nocover warnings.warn(".msg is deprecated, use .reason instead.", DeprecationWarning) self.reason = reason diff --git a/netlib/http/status_codes.py b/netlib/http/status_codes.py index dc09f465..8a4dc1f5 100644 --- a/netlib/http/status_codes.py +++ b/netlib/http/status_codes.py @@ -1,4 +1,4 @@ -from __future__ import (absolute_import, print_function, division) +from __future__ import absolute_import, print_function, division CONTINUE = 100 SWITCHING = 101 @@ -37,6 +37,7 @@ REQUEST_URI_TOO_LONG = 414 UNSUPPORTED_MEDIA_TYPE = 415 REQUESTED_RANGE_NOT_SATISFIABLE = 416 EXPECTATION_FAILED = 417 +IM_A_TEAPOT = 418 INTERNAL_SERVER_ERROR = 500 NOT_IMPLEMENTED = 501 @@ -91,6 +92,7 @@ RESPONSES = { UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type", REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range not satisfiable", EXPECTATION_FAILED: "Expectation Failed", + IM_A_TEAPOT: "I'm a teapot", # 500 INTERNAL_SERVER_ERROR: "Internal Server Error", diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py index fadfe446..a0085db9 100644 --- a/test/http/http1/test_read.py +++ b/test/http/http1/test_read.py @@ -2,7 +2,7 @@ from __future__ import absolute_import, print_function, division from io import BytesIO import textwrap from mock import Mock -from netlib.exceptions import HttpException, HttpSyntaxException, HttpReadDisconnect +from netlib.exceptions import HttpException, HttpSyntaxException, HttpReadDisconnect, TcpDisconnect from netlib.http import Headers from netlib.http.http1.read import ( read_request, read_response, read_request_head, @@ -100,6 +100,11 @@ class TestReadBody(object): with raises(HttpException): b"".join(read_body(rfile, -1, 3)) + def test_max_chunk_size(self): + rfile = BytesIO(b"123456") + assert list(read_body(rfile, -1, max_chunk_size=None)) == [b"123456"] + rfile = BytesIO(b"123456") + assert list(read_body(rfile, -1, max_chunk_size=1)) == [b"1", b"2", b"3", b"4", b"5", b"6"] def test_connection_close(): headers = Headers() @@ -169,6 +174,11 @@ def test_get_first_line(): rfile = BytesIO(b"") _get_first_line(rfile) + with raises(HttpReadDisconnect): + rfile = Mock() + rfile.readline.side_effect = TcpDisconnect + _get_first_line(rfile) + with raises(HttpSyntaxException): rfile = BytesIO(b"GET /\xff HTTP/1.1") _get_first_line(rfile) @@ -191,7 +201,8 @@ def test_read_request_line(): t(b"GET / WTF/1.1") with raises(HttpSyntaxException): t(b"this is not http") - + with raises(HttpReadDisconnect): + t(b"") def test_parse_authority_form(): assert _parse_authority_form(b"foo:42") == (b"foo", 42) @@ -218,6 +229,8 @@ def test_read_response_line(): t(b"HTTP/1.1 OK OK") with raises(HttpSyntaxException): t(b"WTF/1.1 200 OK") + with raises(HttpReadDisconnect): + t(b"") def test_check_http_version(): diff --git a/test/http/test_headers.py b/test/http/test_headers.py index f1af1feb..8bddc0b2 100644 --- a/test/http/test_headers.py +++ b/test/http/test_headers.py @@ -38,6 +38,9 @@ class TestHeaders(object): assert headers["Host"] == "example.com" assert headers["Accept"] == "text/plain" + with raises(ValueError): + Headers([[b"Host", u"not-bytes"]]) + def test_getitem(self): headers = Headers(Host="example.com") assert headers["Host"] == "example.com" diff --git a/test/http/test_message.py b/test/http/test_message.py new file mode 100644 index 00000000..b0b7e27f --- /dev/null +++ b/test/http/test_message.py @@ -0,0 +1,136 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, print_function, division + +from netlib.http import decoded +from netlib.tutils import tresp + + +def _test_passthrough_attr(message, attr): + def t(self=None): + assert getattr(message, attr) == getattr(message.data, attr) + setattr(message, attr, "foo") + assert getattr(message.data, attr) == "foo" + return t + + +def _test_decoded_attr(message, attr): + def t(self=None): + assert getattr(message, attr) == getattr(message.data, attr).decode("utf8") + # Set str, get raw bytes + setattr(message, attr, "foo") + assert getattr(message.data, attr) == b"foo" + # Set raw bytes, get decoded + setattr(message.data, attr, b"bar") + assert getattr(message, attr) == "bar" + # Set bytes, get raw bytes + setattr(message, attr, b"baz") + assert getattr(message.data, attr) == b"baz" + + # Set UTF8 + setattr(message, attr, "Non-Autorisé") + assert getattr(message.data, attr) == b"Non-Autoris\xc3\xa9" + # Don't fail on garbage + setattr(message.data, attr, b"foo\xFF\x00bar") + assert getattr(message, attr).startswith("foo") + assert getattr(message, attr).endswith("bar") + # foo.bar = foo.bar should not cause any side effects. + d = getattr(message, attr) + setattr(message, attr, d) + assert getattr(message.data, attr) == b"foo\xFF\x00bar" + return t + + +class TestMessage(object): + + def test_init(self): + resp = tresp() + assert resp.data + + def test_eq_ne(self): + resp = tresp(timestamp_start=42, timestamp_end=42) + same = tresp(timestamp_start=42, timestamp_end=42) + assert resp == same + assert not resp != same + + other = tresp(timestamp_start=0, timestamp_end=0) + assert not resp == other + assert resp != other + + assert resp != 0 + + def test_content_length_update(self): + resp = tresp() + resp.content = b"foo" + assert resp.data.content == b"foo" + assert resp.headers["content-length"] == "3" + resp.content = b"" + assert resp.data.content == b"" + assert resp.headers["content-length"] == "0" + + test_content_basic = _test_passthrough_attr(tresp(), "content") + test_headers = _test_passthrough_attr(tresp(), "headers") + test_timestamp_start = _test_passthrough_attr(tresp(), "timestamp_start") + test_timestamp_end = _test_passthrough_attr(tresp(), "timestamp_end") + + test_http_version = _test_decoded_attr(tresp(), "http_version") + + +class TestDecodedDecorator(object): + + def test_simple(self): + r = tresp() + assert r.content == b"message" + assert "content-encoding" not in r.headers + assert r.encode("gzip") + + assert r.headers["content-encoding"] + assert r.content != b"message" + with decoded(r): + assert "content-encoding" not in r.headers + assert r.content == b"message" + assert r.headers["content-encoding"] + assert r.content != b"message" + + def test_modify(self): + r = tresp() + assert "content-encoding" not in r.headers + assert r.encode("gzip") + + with decoded(r): + r.content = b"foo" + + assert r.content != b"foo" + r.decode() + assert r.content == b"foo" + + def test_unknown_ce(self): + r = tresp() + r.headers["content-encoding"] = "zopfli" + r.content = b"foo" + with decoded(r): + assert r.headers["content-encoding"] + assert r.content == b"foo" + assert r.headers["content-encoding"] + assert r.content == b"foo" + + def test_cannot_decode(self): + r = tresp() + assert r.encode("gzip") + r.content = b"foo" + with decoded(r): + assert r.headers["content-encoding"] + assert r.content == b"foo" + assert r.headers["content-encoding"] + assert r.content != b"foo" + r.decode() + assert r.content == b"foo" + + def test_cannot_encode(self): + r = tresp() + assert r.encode("gzip") + with decoded(r): + r.content = None + + assert "content-encoding" not in r.headers + assert r.content is None + diff --git a/test/http/test_models.py b/test/http/test_models.py index aa267944..76a05446 100644 --- a/test/http/test_models.py +++ b/test/http/test_models.py @@ -1,271 +1,7 @@ -import mock from netlib import tutils -from netlib import utils from netlib.odict import ODict, ODictCaseless -from netlib.http import Request, Response, Headers, CONTENT_MISSING - -class TestRequest(object): - def test_repr(self): - r = tutils.treq() - assert repr(r) - - def test_headers(self): - tutils.raises(AssertionError, Request, - 'form_in', - 'method', - 'scheme', - 'host', - 'port', - 'path', - b"HTTP/1.1", - 'foobar', - ) - - req = Request( - 'form_in', - 'method', - 'scheme', - 'host', - 'port', - 'path', - b"HTTP/1.1", - ) - assert isinstance(req.headers, Headers) - - def test_equal(self): - a = tutils.treq(timestamp_start=42, timestamp_end=43) - b = tutils.treq(timestamp_start=42, timestamp_end=43) - assert a == b - assert not a != b - - assert not a == 'foo' - assert not b == 'foo' - assert not 'foo' == a - assert not 'foo' == b - - - def test_anticache(self): - req = tutils.treq() - req.headers["If-Modified-Since"] = "foo" - req.headers["If-None-Match"] = "bar" - req.anticache() - assert "If-Modified-Since" not in req.headers - assert "If-None-Match" not in req.headers - - def test_anticomp(self): - req = tutils.treq() - req.headers["Accept-Encoding"] = "foobar" - req.anticomp() - assert req.headers["Accept-Encoding"] == "identity" - - def test_constrain_encoding(self): - req = tutils.treq() - req.headers["Accept-Encoding"] = "identity, gzip, foo" - req.constrain_encoding() - assert "foo" not in req.headers["Accept-Encoding"] - - def test_update_host(self): - req = tutils.treq() - req.headers["Host"] = "" - req.host = "foobar" - assert req.headers["Host"] == "foobar" - - def test_get_form_urlencoded(self): - req = tutils.treq(content="foobar") - assert req.get_form_urlencoded() == ODict() - - req.headers["Content-Type"] = "application/x-www-form-urlencoded" - assert req.get_form_urlencoded() == ODict(utils.urldecode(req.body)) - - def test_get_form_multipart(self): - req = tutils.treq(content="foobar") - assert req.get_form_multipart() == ODict() - - req.headers["Content-Type"] = "multipart/form-data" - assert req.get_form_multipart() == ODict( - utils.multipartdecode( - req.headers, - req.body - ) - ) - - def test_set_form_urlencoded(self): - req = tutils.treq() - req.set_form_urlencoded(ODict([('foo', 'bar'), ('rab', 'oof')])) - assert req.headers["Content-Type"] == "application/x-www-form-urlencoded" - assert req.body - - def test_get_path_components(self): - req = tutils.treq() - assert req.get_path_components() - # TODO: add meaningful assertions - - def test_set_path_components(self): - req = tutils.treq() - req.set_path_components([b"foo", b"bar"]) - # TODO: add meaningful assertions - - def test_get_query(self): - req = tutils.treq() - assert req.get_query().lst == [] - - req.url = "http://localhost:80/foo?bar=42" - assert req.get_query().lst == [("bar", "42")] - - def test_set_query(self): - req = tutils.treq() - req.set_query(ODict([])) - - def test_pretty_host(self): - r = tutils.treq() - assert r.pretty_host == "address" - assert r.host == "address" - r.headers["host"] = "other" - assert r.pretty_host == "other" - assert r.host == "address" - r.host = None - assert r.pretty_host is None - assert r.host is None - - # Invalid IDNA - r.headers["host"] = ".disqus.com" - assert r.pretty_host == ".disqus.com" - - def test_pretty_url(self): - req = tutils.treq(first_line_format="relative") - assert req.pretty_url == "http://address:22/path" - assert req.url == "http://address:22/path" - - def test_get_cookies_none(self): - headers = Headers() - r = tutils.treq() - r.headers = headers - assert len(r.get_cookies()) == 0 - - def test_get_cookies_single(self): - r = tutils.treq() - r.headers = Headers(cookie="cookiename=cookievalue") - result = r.get_cookies() - assert len(result) == 1 - assert result['cookiename'] == ['cookievalue'] - - def test_get_cookies_double(self): - r = tutils.treq() - r.headers = Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue") - result = r.get_cookies() - assert len(result) == 2 - assert result['cookiename'] == ['cookievalue'] - assert result['othercookiename'] == ['othercookievalue'] - - def test_get_cookies_withequalsign(self): - r = tutils.treq() - r.headers = Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue") - result = r.get_cookies() - assert len(result) == 2 - assert result['cookiename'] == ['coo=kievalue'] - assert result['othercookiename'] == ['othercookievalue'] - - def test_set_cookies(self): - r = tutils.treq() - r.headers = Headers(cookie="cookiename=cookievalue") - result = r.get_cookies() - result["cookiename"] = ["foo"] - r.set_cookies(result) - assert r.get_cookies()["cookiename"] == ["foo"] - - def test_set_url(self): - r = tutils.treq(first_line_format="absolute") - r.url = b"https://otheraddress:42/ORLY" - assert r.scheme == "https" - assert r.host == "otheraddress" - assert r.port == 42 - assert r.path == "/ORLY" - - try: - r.url = "//localhost:80/foo@bar" - assert False - except: - assert True - - # def test_asterisk_form_in(self): - # f = tutils.tflow(req=None) - # protocol = mock_protocol("OPTIONS * HTTP/1.1") - # f.request = HTTPRequest.from_protocol(protocol) - # - # assert f.request.first_line_format == "relative" - # f.request.host = f.server_conn.address.host - # f.request.port = f.server_conn.address.port - # f.request.scheme = "http" - # assert protocol.assemble(f.request) == ( - # "OPTIONS * HTTP/1.1\r\n" - # "Host: address:22\r\n" - # "Content-Length: 0\r\n\r\n") - # - # def test_relative_form_in(self): - # protocol = mock_protocol("GET /foo\xff HTTP/1.1") - # tutils.raises("Bad HTTP request line", HTTPRequest.from_protocol, protocol) - # - # protocol = mock_protocol("GET /foo HTTP/1.1\r\nConnection: Upgrade\r\nUpgrade: h2c") - # r = HTTPRequest.from_protocol(protocol) - # assert r.headers["Upgrade"] == ["h2c"] - # - # def test_expect_header(self): - # protocol = mock_protocol( - # "GET / HTTP/1.1\r\nContent-Length: 3\r\nExpect: 100-continue\r\n\r\nfoobar") - # r = HTTPRequest.from_protocol(protocol) - # assert protocol.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" - # assert r.content == "foo" - # assert protocol.tcp_handler.rfile.read(3) == "bar" - # - # def test_authority_form_in(self): - # protocol = mock_protocol("CONNECT oops-no-port.com HTTP/1.1") - # tutils.raises("Bad HTTP request line", HTTPRequest.from_protocol, protocol) - # - # protocol = mock_protocol("CONNECT address:22 HTTP/1.1") - # r = HTTPRequest.from_protocol(protocol) - # r.scheme, r.host, r.port = "http", "address", 22 - # assert protocol.assemble(r) == ( - # "CONNECT address:22 HTTP/1.1\r\n" - # "Host: address:22\r\n" - # "Content-Length: 0\r\n\r\n") - # assert r.pretty_url == "address:22" - # - # def test_absolute_form_in(self): - # protocol = mock_protocol("GET oops-no-protocol.com HTTP/1.1") - # tutils.raises("Bad HTTP request line", HTTPRequest.from_protocol, protocol) - # - # protocol = mock_protocol("GET http://address:22/ HTTP/1.1") - # r = HTTPRequest.from_protocol(protocol) - # assert protocol.assemble(r) == ( - # "GET http://address:22/ HTTP/1.1\r\n" - # "Host: address:22\r\n" - # "Content-Length: 0\r\n\r\n") - # - # def test_http_options_relative_form_in(self): - # """ - # Exercises fix for Issue #392. - # """ - # protocol = mock_protocol("OPTIONS /secret/resource HTTP/1.1") - # r = HTTPRequest.from_protocol(protocol) - # r.host = 'address' - # r.port = 80 - # r.scheme = "http" - # assert protocol.assemble(r) == ( - # "OPTIONS /secret/resource HTTP/1.1\r\n" - # "Host: address\r\n" - # "Content-Length: 0\r\n\r\n") - # - # def test_http_options_absolute_form_in(self): - # protocol = mock_protocol("OPTIONS http://address/secret/resource HTTP/1.1") - # r = HTTPRequest.from_protocol(protocol) - # r.host = 'address' - # r.port = 80 - # r.scheme = "http" - # assert protocol.assemble(r) == ( - # "OPTIONS http://address:80/secret/resource HTTP/1.1\r\n" - # "Host: address\r\n" - # "Content-Length: 0\r\n\r\n") +from netlib.http import Response, Headers, CONTENT_MISSING class TestResponse(object): def test_headers(self): diff --git a/test/http/test_request.py b/test/http/test_request.py index 02fac3df..15bdd3e3 100644 --- a/test/http/test_request.py +++ b/test/http/test_request.py @@ -1,3 +1,230 @@ +# -*- coding: utf-8 -*- from __future__ import absolute_import, print_function, division -# TODO \ No newline at end of file +import six + +from netlib import utils +from netlib.http import Headers +from netlib.odict import ODict +from netlib.tutils import treq, raises +from .test_message import _test_decoded_attr, _test_passthrough_attr + + +class TestRequestData(object): + def test_init(self): + with raises(AssertionError): + treq(headers="foobar") + + assert isinstance(treq(headers=None).headers, Headers) + + def test_eq_ne(self): + request_data = treq().data + same = treq().data + assert request_data == same + assert not request_data != same + + other = treq(content=b"foo").data + assert not request_data == other + assert request_data != other + + assert request_data != 0 + + +class TestRequestCore(object): + def test_repr(self): + request = treq() + assert repr(request) == "Request(GET address:22/path)" + request.host = None + assert repr(request) == "Request(GET /path)" + + test_first_line_format = _test_passthrough_attr(treq(), "first_line_format") + test_method = _test_decoded_attr(treq(), "method") + test_scheme = _test_decoded_attr(treq(), "scheme") + test_port = _test_passthrough_attr(treq(), "port") + test_path = _test_decoded_attr(treq(), "path") + + def test_host(self): + if six.PY2: + from unittest import SkipTest + raise SkipTest() + + request = treq() + assert request.host == request.data.host.decode("idna") + + # Test IDNA encoding + # Set str, get raw bytes + request.host = "ídna.example" + assert request.data.host == b"xn--dna-qma.example" + # Set raw bytes, get decoded + request.data.host = b"xn--idn-gla.example" + assert request.host == "idná.example" + # Set bytes, get raw bytes + request.host = b"xn--dn-qia9b.example" + assert request.data.host == b"xn--dn-qia9b.example" + # IDNA encoding is not bijective + request.host = "fußball" + assert request.host == "fussball" + + # Don't fail on garbage + request.data.host = b"foo\xFF\x00bar" + assert request.host.startswith("foo") + assert request.host.endswith("bar") + # foo.bar = foo.bar should not cause any side effects. + d = request.host + request.host = d + assert request.data.host == b"foo\xFF\x00bar" + + def test_host_header_update(self): + request = treq() + assert "host" not in request.headers + request.host = "example.com" + assert "host" not in request.headers + + request.headers["Host"] = "foo" + request.host = "example.org" + assert request.headers["Host"] == "example.org" + + +class TestRequestUtils(object): + def test_url(self): + request = treq() + assert request.url == "http://address:22/path" + + request.url = "https://otheraddress:42/foo" + assert request.scheme == "https" + assert request.host == "otheraddress" + assert request.port == 42 + assert request.path == "/foo" + + with raises(ValueError): + request.url = "not-a-url" + + def test_pretty_host(self): + request = treq() + assert request.pretty_host == "address" + assert request.host == "address" + request.headers["host"] = "other" + assert request.pretty_host == "other" + assert request.host == "address" + request.host = None + assert request.pretty_host is None + assert request.host is None + + # Invalid IDNA + request.headers["host"] = ".disqus.com" + assert request.pretty_host == ".disqus.com" + + def test_pretty_url(self): + request = treq() + assert request.url == "http://address:22/path" + assert request.pretty_url == "http://address:22/path" + request.headers["host"] = "other" + assert request.pretty_url == "http://other:22/path" + + def test_pretty_url_authority(self): + request = treq(first_line_format="authority") + assert request.pretty_url == "address:22" + + def test_get_query(self): + request = treq() + assert request.query is None + + request.url = "http://localhost:80/foo?bar=42" + assert request.query.lst == [("bar", "42")] + + def test_set_query(self): + request = treq() + request.query = ODict([]) + + def test_get_cookies_none(self): + request = treq() + request.headers = Headers() + assert len(request.cookies) == 0 + + def test_get_cookies_single(self): + request = treq() + request.headers = Headers(cookie="cookiename=cookievalue") + result = request.cookies + assert len(result) == 1 + assert result['cookiename'] == ['cookievalue'] + + def test_get_cookies_double(self): + request = treq() + request.headers = Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue") + result = request.cookies + assert len(result) == 2 + assert result['cookiename'] == ['cookievalue'] + assert result['othercookiename'] == ['othercookievalue'] + + def test_get_cookies_withequalsign(self): + request = treq() + request.headers = Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue") + result = request.cookies + assert len(result) == 2 + assert result['cookiename'] == ['coo=kievalue'] + assert result['othercookiename'] == ['othercookievalue'] + + def test_set_cookies(self): + request = treq() + request.headers = Headers(cookie="cookiename=cookievalue") + result = request.cookies + result["cookiename"] = ["foo"] + request.cookies = result + assert request.cookies["cookiename"] == ["foo"] + + def test_get_path_components(self): + request = treq(path=b"/foo/bar") + assert request.path_components == ["foo", "bar"] + + def test_set_path_components(self): + request = treq() + request.path_components = ["foo", "baz"] + assert request.path == "/foo/baz" + request.path_components = [] + assert request.path == "/" + + def test_anticache(self): + request = treq() + request.headers["If-Modified-Since"] = "foo" + request.headers["If-None-Match"] = "bar" + request.anticache() + assert "If-Modified-Since" not in request.headers + assert "If-None-Match" not in request.headers + + def test_anticomp(self): + request = treq() + request.headers["Accept-Encoding"] = "foobar" + request.anticomp() + assert request.headers["Accept-Encoding"] == "identity" + + def test_constrain_encoding(self): + request = treq() + request.headers["Accept-Encoding"] = "identity, gzip, foo" + request.constrain_encoding() + assert "foo" not in request.headers["Accept-Encoding"] + assert "gzip" in request.headers["Accept-Encoding"] + + def test_get_urlencoded_form(self): + request = treq(content="foobar") + assert request.urlencoded_form is None + + request.headers["Content-Type"] = "application/x-www-form-urlencoded" + assert request.urlencoded_form == ODict(utils.urldecode(request.content)) + + def test_set_urlencoded_form(self): + request = treq() + request.urlencoded_form = ODict([('foo', 'bar'), ('rab', 'oof')]) + assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" + assert request.content + + def test_get_multipart_form(self): + request = treq(content="foobar") + assert request.multipart_form is None + + request.headers["Content-Type"] = "multipart/form-data" + assert request.multipart_form == ODict( + utils.multipartdecode( + request.headers, + request.content + ) + ) diff --git a/test/http/test_status_codes.py b/test/http/test_status_codes.py new file mode 100644 index 00000000..9fea6b70 --- /dev/null +++ b/test/http/test_status_codes.py @@ -0,0 +1,6 @@ +from netlib.http import status_codes + + +def test_simple(): + assert status_codes.IM_A_TEAPOT == 418 + assert status_codes.RESPONSES[418] == "I'm a teapot" -- cgit v1.2.3 From 23d13e4c1282bc46c54222479c3b83032dad3335 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 27 Sep 2015 00:49:41 +0200 Subject: test response model, push coverage to 100% branch cov --- netlib/http/cookies.py | 1 + netlib/http/message.py | 10 ++++ netlib/http/request.py | 12 +---- netlib/http/response.py | 14 ++---- test/http/http1/test_assemble.py | 13 +++++- test/http/http1/test_read.py | 3 ++ test/http/test_cookies.py | 1 + test/http/test_message.py | 91 +++++++++++++++++++++--------------- test/http/test_models.py | 94 -------------------------------------- test/http/test_request.py | 42 ++++++++++------- test/http/test_response.py | 99 +++++++++++++++++++++++++++++++++++++++- 11 files changed, 208 insertions(+), 172 deletions(-) delete mode 100644 test/http/test_models.py diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py index 78b03a83..18544b5e 100644 --- a/netlib/http/cookies.py +++ b/netlib/http/cookies.py @@ -58,6 +58,7 @@ def _read_quoted_string(s, start): escaping = False ret = [] # Skip the first quote + i = start # initialize in case the loop doesn't run. for i in range(start + 1, len(s)): if escaping: ret.append(s[i]) diff --git a/netlib/http/message.py b/netlib/http/message.py index 7cb18f52..e4e799ca 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -18,6 +18,16 @@ else: _always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape") +class MessageData(object): + def __eq__(self, other): + if isinstance(other, MessageData): + return self.__dict__ == other.__dict__ + return False + + def __ne__(self, other): + return not self.__eq__(other) + + class Message(object): def __init__(self, data): self.data = data diff --git a/netlib/http/request.py b/netlib/http/request.py index 325c0080..095b5945 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -10,10 +10,10 @@ from netlib.http import cookies from netlib.odict import ODict from .. import encoding from .headers import Headers -from .message import Message, _native, _always_bytes +from .message import Message, _native, _always_bytes, MessageData -class RequestData(object): +class RequestData(MessageData): def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None, timestamp_start=None, timestamp_end=None): if not headers: @@ -32,14 +32,6 @@ class RequestData(object): self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end - def __eq__(self, other): - if isinstance(other, RequestData): - return self.__dict__ == other.__dict__ - return False - - def __ne__(self, other): - return not self.__eq__(other) - class Request(Message): """ diff --git a/netlib/http/response.py b/netlib/http/response.py index db31d2b9..66e5ded6 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -4,12 +4,12 @@ import warnings from . import cookies from .headers import Headers -from .message import Message, _native, _always_bytes +from .message import Message, _native, _always_bytes, MessageData from .. import utils from ..odict import ODict -class ResponseData(object): +class ResponseData(MessageData): def __init__(self, http_version, status_code, reason=None, headers=None, content=None, timestamp_start=None, timestamp_end=None): if not headers: @@ -24,14 +24,6 @@ class ResponseData(object): self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end - def __eq__(self, other): - if isinstance(other, ResponseData): - return self.__dict__ == other.__dict__ - return False - - def __ne__(self, other): - return not self.__eq__(other) - class Response(Message): """ @@ -48,7 +40,7 @@ class Response(Message): utils.pretty_size(len(self.content)) ) else: - details = "content missing" + details = "no content" return "Response({status_code} {reason}, {details})".format( status_code=self.status_code, reason=self.reason, diff --git a/test/http/http1/test_assemble.py b/test/http/http1/test_assemble.py index 460e22c5..ed94292d 100644 --- a/test/http/http1/test_assemble.py +++ b/test/http/http1/test_assemble.py @@ -78,10 +78,19 @@ def test_assemble_request_headers(): # https://github.com/mitmproxy/mitmproxy/issues/186 r = treq(content=b"") r.headers["Transfer-Encoding"] = "chunked" - c = _assemble_request_headers(r) + c = _assemble_request_headers(r.data) assert b"Transfer-Encoding" in c - assert b"host" in _assemble_request_headers(treq(headers=Headers())) + +def test_assemble_request_headers_host_header(): + r = treq() + r.headers = Headers() + c = _assemble_request_headers(r.data) + assert b"host" in c + + r.host = None + c = _assemble_request_headers(r.data) + assert b"host" not in c def test_assemble_response_headers(): diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py index a0085db9..84a43f8b 100644 --- a/test/http/http1/test_read.py +++ b/test/http/http1/test_read.py @@ -117,6 +117,9 @@ def test_connection_close(): headers["connection"] = "close" assert connection_close(b"HTTP/1.1", headers) + headers["connection"] = "foobar" + assert connection_close(b"HTTP/1.0", headers) + assert not connection_close(b"HTTP/1.1", headers) def test_expected_http_body_size(): # Expect: 100-continue diff --git a/test/http/test_cookies.py b/test/http/test_cookies.py index 413b6241..34bb64f2 100644 --- a/test/http/test_cookies.py +++ b/test/http/test_cookies.py @@ -21,6 +21,7 @@ def test_read_quoted_string(): [(r'"f\\o" x', 0), (r"f\o", 6)], [(r'"f\\" x', 0), (r"f" + '\\', 5)], [('"fo\\\"" x', 0), ("fo\"", 6)], + [('"foo" x', 7), ("", 8)], ] for q, a in tokens: assert cookies._read_quoted_string(*q) == a diff --git a/test/http/test_message.py b/test/http/test_message.py index b0b7e27f..2c37dc3e 100644 --- a/test/http/test_message.py +++ b/test/http/test_message.py @@ -1,43 +1,53 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, print_function, division -from netlib.http import decoded -from netlib.tutils import tresp +from netlib.http import decoded, Headers +from netlib.tutils import tresp, raises def _test_passthrough_attr(message, attr): - def t(self=None): - assert getattr(message, attr) == getattr(message.data, attr) - setattr(message, attr, "foo") - assert getattr(message.data, attr) == "foo" - return t + assert getattr(message, attr) == getattr(message.data, attr) + setattr(message, attr, "foo") + assert getattr(message.data, attr) == "foo" def _test_decoded_attr(message, attr): - def t(self=None): - assert getattr(message, attr) == getattr(message.data, attr).decode("utf8") - # Set str, get raw bytes - setattr(message, attr, "foo") - assert getattr(message.data, attr) == b"foo" - # Set raw bytes, get decoded - setattr(message.data, attr, b"bar") - assert getattr(message, attr) == "bar" - # Set bytes, get raw bytes - setattr(message, attr, b"baz") - assert getattr(message.data, attr) == b"baz" - - # Set UTF8 - setattr(message, attr, "Non-Autorisé") - assert getattr(message.data, attr) == b"Non-Autoris\xc3\xa9" - # Don't fail on garbage - setattr(message.data, attr, b"foo\xFF\x00bar") - assert getattr(message, attr).startswith("foo") - assert getattr(message, attr).endswith("bar") - # foo.bar = foo.bar should not cause any side effects. - d = getattr(message, attr) - setattr(message, attr, d) - assert getattr(message.data, attr) == b"foo\xFF\x00bar" - return t + assert getattr(message, attr) == getattr(message.data, attr).decode("utf8") + # Set str, get raw bytes + setattr(message, attr, "foo") + assert getattr(message.data, attr) == b"foo" + # Set raw bytes, get decoded + setattr(message.data, attr, b"bar") + assert getattr(message, attr) == "bar" + # Set bytes, get raw bytes + setattr(message, attr, b"baz") + assert getattr(message.data, attr) == b"baz" + + # Set UTF8 + setattr(message, attr, "Non-Autorisé") + assert getattr(message.data, attr) == b"Non-Autoris\xc3\xa9" + # Don't fail on garbage + setattr(message.data, attr, b"foo\xFF\x00bar") + assert getattr(message, attr).startswith("foo") + assert getattr(message, attr).endswith("bar") + # foo.bar = foo.bar should not cause any side effects. + d = getattr(message, attr) + setattr(message, attr, d) + assert getattr(message.data, attr) == b"foo\xFF\x00bar" + + +class TestMessageData(object): + def test_eq_ne(self): + data = tresp(timestamp_start=42, timestamp_end=42).data + same = tresp(timestamp_start=42, timestamp_end=42).data + assert data == same + assert not data != same + + other = tresp(content=b"foo").data + assert not data == other + assert data != other + + assert data != 0 class TestMessage(object): @@ -67,12 +77,20 @@ class TestMessage(object): assert resp.data.content == b"" assert resp.headers["content-length"] == "0" - test_content_basic = _test_passthrough_attr(tresp(), "content") - test_headers = _test_passthrough_attr(tresp(), "headers") - test_timestamp_start = _test_passthrough_attr(tresp(), "timestamp_start") - test_timestamp_end = _test_passthrough_attr(tresp(), "timestamp_end") + def test_content_basic(self): + _test_passthrough_attr(tresp(), "content") + + def test_headers(self): + _test_passthrough_attr(tresp(), "headers") - test_http_version = _test_decoded_attr(tresp(), "http_version") + def test_timestamp_start(self): + _test_passthrough_attr(tresp(), "timestamp_start") + + def test_timestamp_end(self): + _test_passthrough_attr(tresp(), "timestamp_end") + + def teste_http_version(self): + _test_decoded_attr(tresp(), "http_version") class TestDecodedDecorator(object): @@ -133,4 +151,3 @@ class TestDecodedDecorator(object): assert "content-encoding" not in r.headers assert r.content is None - diff --git a/test/http/test_models.py b/test/http/test_models.py deleted file mode 100644 index 76a05446..00000000 --- a/test/http/test_models.py +++ /dev/null @@ -1,94 +0,0 @@ - -from netlib import tutils -from netlib.odict import ODict, ODictCaseless -from netlib.http import Response, Headers, CONTENT_MISSING - -class TestResponse(object): - def test_headers(self): - tutils.raises(AssertionError, Response, - b"HTTP/1.1", - 200, - headers='foobar', - ) - - resp = Response( - b"HTTP/1.1", - 200, - ) - assert isinstance(resp.headers, Headers) - - def test_equal(self): - a = tutils.tresp(timestamp_start=42, timestamp_end=43) - b = tutils.tresp(timestamp_start=42, timestamp_end=43) - assert a == b - - assert not a == 'foo' - assert not b == 'foo' - assert not 'foo' == a - assert not 'foo' == b - - def test_repr(self): - r = tutils.tresp() - assert "unknown content type" in repr(r) - r.headers["content-type"] = "foo" - assert "foo" in repr(r) - assert repr(tutils.tresp(content=CONTENT_MISSING)) - - def test_get_cookies_none(self): - resp = tutils.tresp() - resp.headers = Headers() - assert not resp.get_cookies() - - def test_get_cookies_simple(self): - resp = tutils.tresp() - resp.headers = Headers(set_cookie="cookiename=cookievalue") - result = resp.get_cookies() - assert len(result) == 1 - assert "cookiename" in result - assert result["cookiename"][0] == ["cookievalue", ODict()] - - def test_get_cookies_with_parameters(self): - resp = tutils.tresp() - resp.headers = Headers(set_cookie="cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly") - result = resp.get_cookies() - assert len(result) == 1 - assert "cookiename" in result - assert result["cookiename"][0][0] == "cookievalue" - attrs = result["cookiename"][0][1] - assert len(attrs) == 4 - assert attrs["domain"] == ["example.com"] - assert attrs["expires"] == ["Wed Oct 21 16:29:41 2015"] - assert attrs["path"] == ["/"] - assert attrs["httponly"] == [None] - - def test_get_cookies_no_value(self): - resp = tutils.tresp() - resp.headers = Headers(set_cookie="cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/") - result = resp.get_cookies() - assert len(result) == 1 - assert "cookiename" in result - assert result["cookiename"][0][0] == "" - assert len(result["cookiename"][0][1]) == 2 - - def test_get_cookies_twocookies(self): - resp = tutils.tresp() - resp.headers = Headers([ - [b"Set-Cookie", b"cookiename=cookievalue"], - [b"Set-Cookie", b"othercookie=othervalue"] - ]) - result = resp.get_cookies() - assert len(result) == 2 - assert "cookiename" in result - assert result["cookiename"][0] == ["cookievalue", ODict()] - assert "othercookie" in result - assert result["othercookie"][0] == ["othervalue", ODict()] - - def test_set_cookies(self): - resp = tutils.tresp() - v = resp.get_cookies() - v.add("foo", ["bar", ODictCaseless()]) - resp.set_cookies(v) - - v = resp.get_cookies() - assert len(v) == 1 - assert v["foo"] == [["bar", ODictCaseless()]] diff --git a/test/http/test_request.py b/test/http/test_request.py index 15bdd3e3..8cf69ffe 100644 --- a/test/http/test_request.py +++ b/test/http/test_request.py @@ -17,31 +17,31 @@ class TestRequestData(object): assert isinstance(treq(headers=None).headers, Headers) - def test_eq_ne(self): - request_data = treq().data - same = treq().data - assert request_data == same - assert not request_data != same - - other = treq(content=b"foo").data - assert not request_data == other - assert request_data != other - - assert request_data != 0 - class TestRequestCore(object): + """ + Tests for builtins and the attributes that are directly proxied from the data structure + """ def test_repr(self): request = treq() assert repr(request) == "Request(GET address:22/path)" request.host = None assert repr(request) == "Request(GET /path)" - test_first_line_format = _test_passthrough_attr(treq(), "first_line_format") - test_method = _test_decoded_attr(treq(), "method") - test_scheme = _test_decoded_attr(treq(), "scheme") - test_port = _test_passthrough_attr(treq(), "port") - test_path = _test_decoded_attr(treq(), "path") + def test_first_line_format(self): + _test_passthrough_attr(treq(), "first_line_format") + + def test_method(self): + _test_decoded_attr(treq(), "method") + + def test_scheme(self): + _test_decoded_attr(treq(), "scheme") + + def test_port(self): + _test_passthrough_attr(treq(), "port") + + def test_path(self): + _test_decoded_attr(treq(), "path") def test_host(self): if six.PY2: @@ -86,6 +86,9 @@ class TestRequestCore(object): class TestRequestUtils(object): + """ + Tests for additional convenience methods. + """ def test_url(self): request = treq() assert request.url == "http://address:22/path" @@ -199,6 +202,11 @@ class TestRequestUtils(object): def test_constrain_encoding(self): request = treq() + + h = request.headers.copy() + request.constrain_encoding() # no-op if there is no accept_encoding header. + assert request.headers == h + request.headers["Accept-Encoding"] = "identity, gzip, foo" request.constrain_encoding() assert "foo" not in request.headers["Accept-Encoding"] diff --git a/test/http/test_response.py b/test/http/test_response.py index 02fac3df..a1f4abd7 100644 --- a/test/http/test_response.py +++ b/test/http/test_response.py @@ -1,3 +1,100 @@ from __future__ import absolute_import, print_function, division -# TODO \ No newline at end of file +from netlib.http import Headers +from netlib.odict import ODict, ODictCaseless +from netlib.tutils import raises, tresp +from .test_message import _test_passthrough_attr, _test_decoded_attr + + +class TestResponseData(object): + def test_init(self): + with raises(AssertionError): + tresp(headers="foobar") + + assert isinstance(tresp(headers=None).headers, Headers) + + +class TestResponseCore(object): + """ + Tests for builtins and the attributes that are directly proxied from the data structure + """ + def test_repr(self): + response = tresp() + assert repr(response) == "Response(200 OK, unknown content type, 7B)" + response.content = None + assert repr(response) == "Response(200 OK, no content)" + + def test_status_code(self): + _test_passthrough_attr(tresp(), "status_code") + + def test_reason(self): + _test_decoded_attr(tresp(), "reason") + + +class TestResponseUtils(object): + """ + Tests for additional convenience methods. + """ + def test_get_cookies_none(self): + resp = tresp() + resp.headers = Headers() + assert not resp.cookies + + def test_get_cookies_empty(self): + resp = tresp() + resp.headers = Headers(set_cookie="") + assert not resp.cookies + + def test_get_cookies_simple(self): + resp = tresp() + resp.headers = Headers(set_cookie="cookiename=cookievalue") + result = resp.cookies + assert len(result) == 1 + assert "cookiename" in result + assert result["cookiename"][0] == ["cookievalue", ODict()] + + def test_get_cookies_with_parameters(self): + resp = tresp() + resp.headers = Headers(set_cookie="cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly") + result = resp.cookies + assert len(result) == 1 + assert "cookiename" in result + assert result["cookiename"][0][0] == "cookievalue" + attrs = result["cookiename"][0][1] + assert len(attrs) == 4 + assert attrs["domain"] == ["example.com"] + assert attrs["expires"] == ["Wed Oct 21 16:29:41 2015"] + assert attrs["path"] == ["/"] + assert attrs["httponly"] == [None] + + def test_get_cookies_no_value(self): + resp = tresp() + resp.headers = Headers(set_cookie="cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/") + result = resp.cookies + assert len(result) == 1 + assert "cookiename" in result + assert result["cookiename"][0][0] == "" + assert len(result["cookiename"][0][1]) == 2 + + def test_get_cookies_twocookies(self): + resp = tresp() + resp.headers = Headers([ + [b"Set-Cookie", b"cookiename=cookievalue"], + [b"Set-Cookie", b"othercookie=othervalue"] + ]) + result = resp.cookies + assert len(result) == 2 + assert "cookiename" in result + assert result["cookiename"][0] == ["cookievalue", ODict()] + assert "othercookie" in result + assert result["othercookie"][0] == ["othervalue", ODict()] + + def test_set_cookies(self): + resp = tresp() + v = resp.cookies + v.add("foo", ["bar", ODictCaseless()]) + resp.set_cookies(v) + + v = resp.cookies + assert len(v) == 1 + assert v["foo"] == [["bar", ODictCaseless()]] -- cgit v1.2.3 From 87566da3babcc827e9dae0f2e9ab9154c353aa11 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 28 Sep 2015 11:18:00 +0200 Subject: fix mitmproxy/mitmproxy#784 --- netlib/http/http1/read.py | 5 ----- netlib/utils.py | 5 +++-- test/http/http1/test_read.py | 4 ++++ 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 4c898348..73c7deed 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -218,11 +218,6 @@ def _get_first_line(rfile): raise HttpReadDisconnect("Remote disconnected") if not line: raise HttpReadDisconnect("Remote disconnected") - line = line.strip() - try: - line.decode("ascii") - except ValueError: - raise HttpSyntaxException("Non-ascii characters in first line: {}".format(line)) return line.strip() diff --git a/netlib/utils.py b/netlib/utils.py index 6f6d1ea0..8b9548ed 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -237,8 +237,9 @@ def parse_url(url): if isinstance(url, six.binary_type): host = parsed.hostname - # this should not raise a ValueError - decode_parse_result(parsed, "ascii") + # this should not raise a ValueError, + # but we try to be very forgiving here and accept just everything. + # decode_parse_result(parsed, "ascii") else: host = parsed.hostname.encode("idna") parsed = encode_parse_result(parsed, "ascii") diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py index 9eb02a24..36cf7e1d 100644 --- a/test/http/http1/test_read.py +++ b/test/http/http1/test_read.py @@ -211,6 +211,10 @@ def test_read_response_line(): assert t(b"HTTP/1.1 200 OK") == (b"HTTP/1.1", 200, b"OK") assert t(b"HTTP/1.1 200") == (b"HTTP/1.1", 200, b"") + + # https://github.com/mitmproxy/mitmproxy/issues/784 + assert t(b"HTTP/1.1 200") == (b"HTTP/1.1 Non-Autoris\xc3\xa9", 200, b"") + with raises(HttpSyntaxException): assert t(b"HTTP/1.1") -- cgit v1.2.3 From 5261bcdf4b0976b8db3295292143282b34f10c51 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 28 Sep 2015 11:46:18 +0200 Subject: properly adjust tests for 87566da3ba --- test/http/http1/test_read.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py index 36cf7e1d..98d31bc2 100644 --- a/test/http/http1/test_read.py +++ b/test/http/http1/test_read.py @@ -169,10 +169,6 @@ def test_get_first_line(): rfile = BytesIO(b"") _get_first_line(rfile) - with raises(HttpSyntaxException): - rfile = BytesIO(b"GET /\xff HTTP/1.1") - _get_first_line(rfile) - def test_read_request_line(): def t(b): @@ -213,7 +209,7 @@ def test_read_response_line(): assert t(b"HTTP/1.1 200") == (b"HTTP/1.1", 200, b"") # https://github.com/mitmproxy/mitmproxy/issues/784 - assert t(b"HTTP/1.1 200") == (b"HTTP/1.1 Non-Autoris\xc3\xa9", 200, b"") + assert t(b"HTTP/1.1 200 Non-Autoris\xc3\xa9") == (b"HTTP/1.1", 200, b"Non-Autoris\xc3\xa9") with raises(HttpSyntaxException): assert t(b"HTTP/1.1") -- cgit v1.2.3 From 2e1f7ecd558659191abb2cd300fddb82c53c31a3 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 28 Sep 2015 14:04:25 +0200 Subject: fix tests --- test/http/test_message.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/http/test_message.py b/test/http/test_message.py index 2c37dc3e..4b1f4630 100644 --- a/test/http/test_message.py +++ b/test/http/test_message.py @@ -17,8 +17,8 @@ def _test_decoded_attr(message, attr): setattr(message, attr, "foo") assert getattr(message.data, attr) == b"foo" # Set raw bytes, get decoded - setattr(message.data, attr, b"bar") - assert getattr(message, attr) == "bar" + setattr(message.data, attr, b"BAR") # use uppercase so that we can also cover request.method + assert getattr(message, attr) == "BAR" # Set bytes, get raw bytes setattr(message, attr, b"baz") assert getattr(message.data, attr) == b"baz" @@ -27,13 +27,13 @@ def _test_decoded_attr(message, attr): setattr(message, attr, "Non-Autorisé") assert getattr(message.data, attr) == b"Non-Autoris\xc3\xa9" # Don't fail on garbage - setattr(message.data, attr, b"foo\xFF\x00bar") - assert getattr(message, attr).startswith("foo") - assert getattr(message, attr).endswith("bar") + setattr(message.data, attr, b"FOO\xFF\x00BAR") + assert getattr(message, attr).startswith("FOO") + assert getattr(message, attr).endswith("BAR") # foo.bar = foo.bar should not cause any side effects. d = getattr(message, attr) setattr(message, attr, d) - assert getattr(message.data, attr) == b"foo\xFF\x00bar" + assert getattr(message.data, attr) == b"FOO\xFF\x00BAR" class TestMessageData(object): -- cgit v1.2.3 From 267837f441dca41af495bca61140fd9d657bd02e Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 16 Oct 2015 18:12:36 +0200 Subject: add test certificate generator --- test/data/verificationcerts/8117bdb9.0 | 15 ----- test/data/verificationcerts/9d45e6a9.0 | 15 ----- test/data/verificationcerts/9da13359.0 | 21 +++++++ test/data/verificationcerts/generate.py | 72 ++++++++++++++++++++++ test/data/verificationcerts/interm.key | 16 ----- test/data/verificationcerts/self-signed.crt | 18 ++++++ test/data/verificationcerts/self-signed.key | 27 ++++++++ test/data/verificationcerts/trusted-chain.crt | 35 ----------- test/data/verificationcerts/trusted-interm.crt | 19 ------ .../verificationcerts/trusted-leaf-bad-host.crt | 18 ++++++ .../verificationcerts/trusted-leaf-bad-host.key | 27 ++++++++ test/data/verificationcerts/trusted-leaf.crt | 18 ++++++ test/data/verificationcerts/trusted-leaf.key | 27 ++++++++ test/data/verificationcerts/trusted-root.crt | 21 +++++++ test/data/verificationcerts/trusted-root.key | 27 ++++++++ test/data/verificationcerts/trusted-root.srl | 1 + test/data/verificationcerts/trusted.key | 15 ----- test/data/verificationcerts/trusted.pem | 15 ----- test/data/verificationcerts/untrusted-chain.crt | 33 ---------- test/data/verificationcerts/untrusted-interm.crt | 17 ----- test/data/verificationcerts/untrusted.crt | 16 ----- .../data/verificationcerts/verification-server.key | 16 ----- 22 files changed, 277 insertions(+), 212 deletions(-) delete mode 100644 test/data/verificationcerts/8117bdb9.0 delete mode 100644 test/data/verificationcerts/9d45e6a9.0 create mode 100644 test/data/verificationcerts/9da13359.0 create mode 100644 test/data/verificationcerts/generate.py delete mode 100644 test/data/verificationcerts/interm.key create mode 100644 test/data/verificationcerts/self-signed.crt create mode 100644 test/data/verificationcerts/self-signed.key delete mode 100644 test/data/verificationcerts/trusted-chain.crt delete mode 100644 test/data/verificationcerts/trusted-interm.crt create mode 100644 test/data/verificationcerts/trusted-leaf-bad-host.crt create mode 100644 test/data/verificationcerts/trusted-leaf-bad-host.key create mode 100644 test/data/verificationcerts/trusted-leaf.crt create mode 100644 test/data/verificationcerts/trusted-leaf.key create mode 100644 test/data/verificationcerts/trusted-root.crt create mode 100644 test/data/verificationcerts/trusted-root.key create mode 100644 test/data/verificationcerts/trusted-root.srl delete mode 100644 test/data/verificationcerts/trusted.key delete mode 100644 test/data/verificationcerts/trusted.pem delete mode 100644 test/data/verificationcerts/untrusted-chain.crt delete mode 100644 test/data/verificationcerts/untrusted-interm.crt delete mode 100644 test/data/verificationcerts/untrusted.crt delete mode 100644 test/data/verificationcerts/verification-server.key diff --git a/test/data/verificationcerts/8117bdb9.0 b/test/data/verificationcerts/8117bdb9.0 deleted file mode 100644 index 8ebc0e5c..00000000 --- a/test/data/verificationcerts/8117bdb9.0 +++ /dev/null @@ -1,15 +0,0 @@ -# Self signed ------BEGIN CERTIFICATE----- -MIICJzCCAZACCQCo1BdopddN/TANBgkqhkiG9w0BAQUFADBXMQswCQYDVQQGEwJB -VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0 -cyBQdHkgTHRkMRAwDgYDVQQDEwdUUlVTVEVEMCAXDTE1MDYxOTE4MDEzMVoYDzIx -MTUwNTI2MTgwMTMxWjBXMQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0 -ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMRAwDgYDVQQDEwdU -UlVTVEVEMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC00Jf3KrBAmLQWl+Dz -8Qrig8ActB94kv0/Lu03P/2DwOR8kH2h3w4OC3b3CFKX31h7hm/H1PPHq7cIX6IR -fwrYCtBE77UbxklSlrwn06j6YSotz0/dwLEQEFDXWITJq7AyntaiafDHazbbXESN -m/+I/YEl2wKemEHE//qWbeM9kwIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAF0NREP3 -X+fTebzJGttzrFkDhGVFKRNyLXblXRVanlGOYF+q8grgZY2ufC/55gqf+ub6FRT5 -gKPhL4V2rqL8UAvCE7jq8ujpVfTB8kRAKC675W2DBZk2EJX9mjlr89t7qXGsI5nF -onpfJ1UtiJshNoV7h/NFHeoag91kx628807n ------END CERTIFICATE----- diff --git a/test/data/verificationcerts/9d45e6a9.0 b/test/data/verificationcerts/9d45e6a9.0 deleted file mode 100644 index 8ebc0e5c..00000000 --- a/test/data/verificationcerts/9d45e6a9.0 +++ /dev/null @@ -1,15 +0,0 @@ -# Self signed ------BEGIN CERTIFICATE----- -MIICJzCCAZACCQCo1BdopddN/TANBgkqhkiG9w0BAQUFADBXMQswCQYDVQQGEwJB -VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0 -cyBQdHkgTHRkMRAwDgYDVQQDEwdUUlVTVEVEMCAXDTE1MDYxOTE4MDEzMVoYDzIx -MTUwNTI2MTgwMTMxWjBXMQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0 -ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMRAwDgYDVQQDEwdU -UlVTVEVEMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC00Jf3KrBAmLQWl+Dz -8Qrig8ActB94kv0/Lu03P/2DwOR8kH2h3w4OC3b3CFKX31h7hm/H1PPHq7cIX6IR -fwrYCtBE77UbxklSlrwn06j6YSotz0/dwLEQEFDXWITJq7AyntaiafDHazbbXESN -m/+I/YEl2wKemEHE//qWbeM9kwIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAF0NREP3 -X+fTebzJGttzrFkDhGVFKRNyLXblXRVanlGOYF+q8grgZY2ufC/55gqf+ub6FRT5 -gKPhL4V2rqL8UAvCE7jq8ujpVfTB8kRAKC675W2DBZk2EJX9mjlr89t7qXGsI5nF -onpfJ1UtiJshNoV7h/NFHeoag91kx628807n ------END CERTIFICATE----- diff --git a/test/data/verificationcerts/9da13359.0 b/test/data/verificationcerts/9da13359.0 new file mode 100644 index 00000000..7d91e288 --- /dev/null +++ b/test/data/verificationcerts/9da13359.0 @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDXTCCAkWgAwIBAgIJAPJ/OeIFZUrJMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTUxMDE2MTUwMjU4WhcNMTgwODA1MTUwMjU4WjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEAs8EFHXjYTdmPf9J37wUuLcx5hi9HFmXfFEbJ0tSm/U8lajfsnr869LiO +2App2JHgntreemHe/OZaaa/fPykDnDDiQBVb74H55YGHYCGphIPeyT78KEvInPOs +m/CaYFxlXB/ao81SXeGKkKagcFq/D4FjFYjmjxDxzUJVxX67knjr5WwPK60NfJSq +JzRIvFFXUtkByRv2VZmEAj56KRQx1W0+Ant51j52ryuD7pvCZ6P5TU4CdGlu34bu +1DJ/7uRBCIGYffZs7vE2wMhCvbwQAPl0q+Kq9yZdPXY+sgoGgmkydB/INuXSv/Ce +IgpBW+EjjeYD32YbnOTQ0Fi5yvxEjwIDAQABo1AwTjAdBgNVHQ4EFgQU8X+ohuC4 +QOemuutP/xX6ZCddKqowHwYDVR0jBBgwFoAU8X+ohuC4QOemuutP/xX6ZCddKqow +DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAnD4fPo/ztU6g77BSf88o +TtsQ1x2Cu3I7DNFEDBpBubooQpQZpwuLspMSgQfTnlPT4V7iBE/+3x7gJm8BcWEi +QxjJhoiWVWDpDe0GdcgNvScPq+3kupzxEJrTGSY/SJjiftlTvI1oGRmto9VXhNlU +6TeFEwieDWfm2waqJCYlGI86go47piqjh3E8ODPAT1SBRLvrfU6b3nvSPl3r1JvF +iurGxMPUk3DHb/Y19MdkFiaUqu/P+c/rO6BDxhPfuJxhmw4OdMuPA7cY0H3bbXHE +yoXqEvQ43ItEiYXVRoc9CCT1l9+ExC8cUsOTUqFi5Fwyr7K3ZDpAOqCpzaLZnss7 +mw== +-----END CERTIFICATE----- diff --git a/test/data/verificationcerts/generate.py b/test/data/verificationcerts/generate.py new file mode 100644 index 00000000..922cb95d --- /dev/null +++ b/test/data/verificationcerts/generate.py @@ -0,0 +1,72 @@ +""" +Generate SSL test certificates. +""" +import subprocess +import shlex +import os +import shutil + + +ROOT_CA = "trusted-root" +SUBJECT = "/CN=127.0.0.1/" + + +def do(args): + print("> %s" % args) + args = shlex.split(args) + output = subprocess.check_output(args) + print(output) + return output + + +def genrsa(cert): + do("openssl genrsa -out {cert}.key 2048".format(cert=cert)) + + +def sign(cert): + do("openssl x509 -req -in {cert}.csr " + "-CA {root_ca}.crt " + "-CAkey {root_ca}.key " + "-CAcreateserial " + "-days 1024 " + "-out {cert}.crt".format(root_ca=ROOT_CA, cert=cert) + ) + + +def mkcert(cert, args): + genrsa(cert) + do("openssl req -new -nodes -batch " + "-key {cert}.key " + "{args} " + "-out {cert}.csr".format(cert=cert, args=args) + ) + sign(cert) + os.remove("{cert}.csr".format(cert=cert)) + + +# create trusted root CA +genrsa("trusted-root") +do("openssl req -x509 -new -nodes -batch " + "-key trusted-root.key " + "-days 1024 " + "-out trusted-root.crt" + ) +h = do("openssl x509 -hash -noout -in trusted-root.crt").strip() +shutil.copyfile("trusted-root.crt", "{}.0".format(h)) + +# create trusted leaf cert. +mkcert("trusted-leaf", "-subj {}".format(SUBJECT)) + +# create wrong host leaf cert. +mkcert("trusted-leaf-bad-host", "-subj /CN=wrong.host/") + +# create self-signed cert +genrsa("self-signed") +do("openssl req -x509 -new -nodes -batch " + "-key self-signed.key " + "-subj {} " + "-days 1024 " + "-out self-signed.crt".format(SUBJECT) + ) + + diff --git a/test/data/verificationcerts/interm.key b/test/data/verificationcerts/interm.key deleted file mode 100644 index 76c05cf4..00000000 --- a/test/data/verificationcerts/interm.key +++ /dev/null @@ -1,16 +0,0 @@ -# Key used to sign trusted-interm.crt and untrusted-interm.crt ------BEGIN RSA PRIVATE KEY----- -MIICXAIBAAKBgQC1E80qCHhZ1gaZTYB7pN/Yxt3ehpEj+5hCbpop5iTWLuDjULS9 -WjA1wP+p02kZQ2dqL8pqT1qcc5jKmk2jvMeB/cQ7zNDg1NCmQMqx0KptRByMZ+GN -Zcqc7D4jl6vhGP4zAzV/lxvBvxtgeJI+ZdrHN0vT9I1cYADKz9SzCDCRTwIDAQAB -AoGAfKHocKnrzEmXuSSy7meI+vfF9kfA1ndxUSg3S+dwK0uQ1mTSQhI1ZIo2bnlo -uU6/e0Lxm0KLJ2wZGjoifjSNTC8pcxIfAQY4kM9fqoUcXVSBVSS2kByTunhNSVZQ -yQyc+UTq9g1zBnJsZAltn7/PaihU4heWgP/++lposuShqmECQQDaG+7l0qul1xak -9kuZgc88BSTfn9iMK2zIQRcVKuidK4dT3QEp0wmWR5Ue8jq8lvTmVTGNGZbHcheh -KhoZfLgLAkEA1IjwAw/8z02yV3lbc2QUjIl9m9lvjHBoE2sGuSfq/cZskLKrGat+ -CVj3spqVAg22tpQwVBuHiipBziWVnEtiTQJAB9FKfchQSLBt6lm9mfHyKJeSm8VR -8Kw5yO+0URjpn4CI6DOasBIVXOKR8LsD6fCLNJpHHWSWZ+2p9SfaKaGzwwJBAM31 -Scld89qca4fzNZkT0goCrvOZeUy6HVE79Q72zPVSFSD/02kT1BaQ3bB5to5/5aD2 -6AKJjwZoPs7bgykrsD0CQBzU8U/8x2dNQnG0QeqaKQu5kKhZSZ9bsawvrCkxSl6b -WAjl/Jehi5bbQ07zQo3cge6qeR38FCWVCHQ/5wNbc54= ------END RSA PRIVATE KEY----- diff --git a/test/data/verificationcerts/self-signed.crt b/test/data/verificationcerts/self-signed.crt new file mode 100644 index 00000000..d7f07214 --- /dev/null +++ b/test/data/verificationcerts/self-signed.crt @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC+zCCAeOgAwIBAgIJAMLvc0tz5r3vMA0GCSqGSIb3DQEBCwUAMBQxEjAQBgNV +BAMMCTEyNy4wLjAuMTAeFw0xNTEwMTYxNTAzMDJaFw0xODA4MDUxNTAzMDJaMBQx +EjAQBgNVBAMMCTEyNy4wLjAuMTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoC +ggEBALCzDuJl7g55J+ZNKnir0cekd48JnjPFk7sbJOPudsQ6pj/HXFrAXTPVix2n +eKtj2nADUds1C1fEgsJnYqYp9DtwesJEnnc0i2ykQmQZFygd7/0P7Z+YtUtup3F6 +jtUGEcCJ3dOOXJNyhESeyBcQwNvLgHYXAHFyN4svxueQ4fW7+d44fm0JaqZjHEtX +Q8tcVadIDsp65s+WWVP6gC0sMO2DikoF2g/98p1U0CeUCmueYJsmKpm+53smWrOp +cqwUXoxAdg03pbgC10aeWDvxm3aBC/Et9EDbaKuzHhBkOJ8E7CkyqLT/Vs7DQ9xl +WFF/Ebs1vsVniBFl3QpObxqhbM0CAwEAAaNQME4wHQYDVR0OBBYEFOTCuMxDnuup +hNAT1/gxdU9DIs82MB8GA1UdIwQYMBaAFOTCuMxDnuuphNAT1/gxdU9DIs82MAwG +A1UdEwQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBADOlFmM1fryPDFIP6mM7O4df +0GfMC9XWODf2NdJ9VWa8P7HrMbiPZy26ORkLpcWc+fuGbcd1ejf8TGbCz4f9aQ82 +P33s5jtGKRRAoB8rPmyALPSt9xrMUHYLYzN97sqY7ZHdHsc4NfzcbMVLOF+3aG4X +LIQiPIp6sLncBwvu0mHSjlcDcTM4n/Sqov4eeCNTGlVzTzsJQ6/lAwq9LIggRZA1 +RKWd+u7IQUcEMTKP0gvaWtfbxJH76RFPJX3wg7YSm97ArU9ZGna0rPORoIORrucL +aBncUwIXEPH4rtP1zy7Rg4ZeHyzoFcgR2W46ONTds+5aZDx98OyWv+gT9HSLgEo= +-----END CERTIFICATE----- diff --git a/test/data/verificationcerts/self-signed.key b/test/data/verificationcerts/self-signed.key new file mode 100644 index 00000000..54111eca --- /dev/null +++ b/test/data/verificationcerts/self-signed.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAsLMO4mXuDnkn5k0qeKvRx6R3jwmeM8WTuxsk4+52xDqmP8dc +WsBdM9WLHad4q2PacANR2zULV8SCwmdipin0O3B6wkSedzSLbKRCZBkXKB3v/Q/t +n5i1S26ncXqO1QYRwInd045ck3KERJ7IFxDA28uAdhcAcXI3iy/G55Dh9bv53jh+ +bQlqpmMcS1dDy1xVp0gOynrmz5ZZU/qALSww7YOKSgXaD/3ynVTQJ5QKa55gmyYq +mb7neyZas6lyrBRejEB2DTeluALXRp5YO/GbdoEL8S30QNtoq7MeEGQ4nwTsKTKo +tP9WzsND3GVYUX8RuzW+xWeIEWXdCk5vGqFszQIDAQABAoIBAQCdc5DJ0IYmQ3N4 +Vj6INKLDwRwAS1O7Uk1nprJioLUX+iL2JhF3lH34mEpUbEysfFfDBFJGgKfQ13yk ++jb/VdcZuArLXRXPpvSuJFg8ldb6mmKlHzJgylSSGNH/3nO0AqqC5NbTksGPabXO +56XoV7dio52enLR6Yop37mTRJ1sR+ahLFUDZ8K0pEXn0pdZVEp+LVksJ6txtklGo +x6oDyQW/AOu2QWIhrneyvSO9XzFCqOnN9KPQDhWdqRmdPjiX+sbLevX7Tf5PhiEH +nNuPxUv19+4xmu7s2tZLY6C19noRSCo4835i25smmItU9hHJ9VvHKID0oLJCMtdD +4HSErPLJAoGBAOa8Hz927R1y124geYfl0+IG+yfF0Spe7HqYk7wyHlY5EGQAncoA +n0UclagRVNQzC4Y3s+QOLIV5HGw2ENMz7flCLe3f8SPRvFu6nqWKQLAnF1U5eO8Q +YVgaWadr8PT/iOPp4PHHfhXsNx3p6RPbDyntqG9xpGYpoy97iEMkWm+jAoGBAMQM +PBIIJ+5dgPQLE42KDK3iyNQLahVFDRXozVdGm3NERsZFAB1NjfaS+HMZRr+/WID7 +tVIxrgumY8iI8SO5nD51EaPYfppjmE55hIB2eN7GqL32JwwL4fQiT3WZ0aU0mY3m +3av+RKunXCNc7LBWPzQfAAf21D4Y8N36H6i57LjPAoGBAOX2vRYdy7m8Ceaiyz2c +3I678nnzeMLIFN0jUKsTMJUzDpj83EbGU/cnxCjcDTXpIiVFQy+ayNjGmoNnZ2F4 +skfpo6kft1DB6v9pglDu+AYZD/JK87MhGkQbDxwEQwWL4b12DlIrSAlFgrF3vmuh +uv1I9sUL+JQyD4h1kJuKkfANAoGANoJoWWMnJyGcbz59K0eNCvQZfsvFrTBL2SGn +pnKdWklLnGknBP7BUCPBLM+EWmArjYFvAvGJQPf8mo9o7NP422zVgMb7PJYgjQFA +lC9coCSAWoEMjk7nfmfjzAD+x35+i3P7gozqLwgTmEmIDeeNH0LXUV+R18o7fpzD +HLjFVwUCgYEAjcv9BwK+qMhRxFcxYKsb5HkPp5LaFa3PKgitF8jsGKd+pLDyIkDD +ih2Hohf9LjR/EqlPT/w5JLmgrF6zWAKtNzWMHKP4hae322/Xh5jTJQY3rbEf0k8D +aB3XoleKD0+5erl6tDRNAPlc8qJcgBv+UzZVBmf0n3aJD3mwoS06dvQ= +-----END RSA PRIVATE KEY----- diff --git a/test/data/verificationcerts/trusted-chain.crt b/test/data/verificationcerts/trusted-chain.crt deleted file mode 100644 index dd30bff3..00000000 --- a/test/data/verificationcerts/trusted-chain.crt +++ /dev/null @@ -1,35 +0,0 @@ -# untrusted.crt, signed by trusted-interm.crt ------BEGIN CERTIFICATE----- -MIICYzCCAcwCAhAIMA0GCSqGSIb3DQEBBQUAMH4xCzAJBgNVBAYTAkFVMRMwEQYD -VQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0eSBM -dGQxFDASBgNVBAsTC0lOVEVSTSBVTklUMSEwHwYDVQQDExhPUkcgV0lUSCBJTlRF -Uk1FRElBVEUgQ0EwIBcNMTUwNjIwMDEyMDI1WhgPMjExNTA1MjcwMTIwMjVaMHMx -CzAJBgNVBAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRl -cm5ldCBXaWRnaXRzIFB0eSBMdGQxEjAQBgNVBAsTCUxFQUYgVU5JVDEYMBYGA1UE -AxMPTk9UIFRSVVNURUQgT1JHMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDf -NZx/tugICrWGcpP8sa+EBX9WhazCsYIm8YgQrQO9B19dK7cHsWB+vIdFuDKHxfS2 -JBIeVSaZ6H4onWGnZRAMpi5xnitVhBQKCZP1yOewtrg2umZIbcTz8A+BwAcvmmQN -7RZMfpxN9PMccWDfgtAXsjZ2E47o9EfhpGvxfcFc0wIDAQABMA0GCSqGSIb3DQEB -BQUAA4GBABtmc8zn5efVi3iVIgODadKkTv43elIwNZBqEJ6IaoVXvi5Mp1m4VxML -LQGPTNG1lpuVDz2z/Ml78942316ailCTOx48oDnb/yy4jI6hsp+N8p6T28/Wvkbm -cCgohk6/Cwat5gf+HwoIe5Z3B3HRJaIcB0OteluuLsHAvverBjc4 ------END CERTIFICATE----- -# trusted-interm.crt, signed by trusted.pem ------BEGIN CERTIFICATE----- -MIIC8jCCAlugAwIBAgICEAcwDQYJKoZIhvcNAQEFBQAwVzELMAkGA1UEBhMCQVUx -EzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNVBAoTGEludGVybmV0IFdpZGdpdHMg -UHR5IEx0ZDEQMA4GA1UEAxMHVFJVU1RFRDAgFw0xNTA2MjAwMTE4MjdaGA8yMTE1 -MDUyNzAxMTgyN1owfjELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUx -ITAfBgNVBAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEUMBIGA1UECxMLSU5U -RVJNIFVOSVQxITAfBgNVBAMTGE9SRyBXSVRIIElOVEVSTUVESUFURSBDQTCBnzAN -BgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAtRPNKgh4WdYGmU2Ae6Tf2Mbd3oaRI/uY -Qm6aKeYk1i7g41C0vVowNcD/qdNpGUNnai/Kak9anHOYyppNo7zHgf3EO8zQ4NTQ -pkDKsdCqbUQcjGfhjWXKnOw+I5er4Rj+MwM1f5cbwb8bYHiSPmXaxzdL0/SNXGAA -ys/UswgwkU8CAwEAAaOBozCBoDAMBgNVHRMEBTADAQH/MB0GA1UdDgQWBBTPkPQW -DAPOIy8mipuEsZcP1694EDBxBgNVHSMEajBooVukWTBXMQswCQYDVQQGEwJBVTET -MBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0cyBQ -dHkgTHRkMRAwDgYDVQQDEwdUUlVTVEVEggkAqNQXaKXXTf0wDQYJKoZIhvcNAQEF -BQADgYEApaPbwonY8l+zSxlY2Fw4WNKfl5nwcTW4fuv/0tZLzvsS6P4hTXxbYJNa -k3hQ1qlrr8DiWJewF85hYvEI2F/7eqS5dhhPTEUFPpsjhbgiqnASvW+WKQIgoY2r -aHgOXi7RNFtTcCgk0UZISWOY7ORLy8Xu6vKrLRjDhyfIbGlqnAs= ------END CERTIFICATE----- diff --git a/test/data/verificationcerts/trusted-interm.crt b/test/data/verificationcerts/trusted-interm.crt deleted file mode 100644 index d577db7d..00000000 --- a/test/data/verificationcerts/trusted-interm.crt +++ /dev/null @@ -1,19 +0,0 @@ -# trusted-interm.crt, signed by trusted.pem ------BEGIN CERTIFICATE----- -MIIC8jCCAlugAwIBAgICEAcwDQYJKoZIhvcNAQEFBQAwVzELMAkGA1UEBhMCQVUx -EzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNVBAoTGEludGVybmV0IFdpZGdpdHMg -UHR5IEx0ZDEQMA4GA1UEAxMHVFJVU1RFRDAgFw0xNTA2MjAwMTE4MjdaGA8yMTE1 -MDUyNzAxMTgyN1owfjELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUx -ITAfBgNVBAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEUMBIGA1UECxMLSU5U -RVJNIFVOSVQxITAfBgNVBAMTGE9SRyBXSVRIIElOVEVSTUVESUFURSBDQTCBnzAN -BgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAtRPNKgh4WdYGmU2Ae6Tf2Mbd3oaRI/uY -Qm6aKeYk1i7g41C0vVowNcD/qdNpGUNnai/Kak9anHOYyppNo7zHgf3EO8zQ4NTQ -pkDKsdCqbUQcjGfhjWXKnOw+I5er4Rj+MwM1f5cbwb8bYHiSPmXaxzdL0/SNXGAA -ys/UswgwkU8CAwEAAaOBozCBoDAMBgNVHRMEBTADAQH/MB0GA1UdDgQWBBTPkPQW -DAPOIy8mipuEsZcP1694EDBxBgNVHSMEajBooVukWTBXMQswCQYDVQQGEwJBVTET -MBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0cyBQ -dHkgTHRkMRAwDgYDVQQDEwdUUlVTVEVEggkAqNQXaKXXTf0wDQYJKoZIhvcNAQEF -BQADgYEApaPbwonY8l+zSxlY2Fw4WNKfl5nwcTW4fuv/0tZLzvsS6P4hTXxbYJNa -k3hQ1qlrr8DiWJewF85hYvEI2F/7eqS5dhhPTEUFPpsjhbgiqnASvW+WKQIgoY2r -aHgOXi7RNFtTcCgk0UZISWOY7ORLy8Xu6vKrLRjDhyfIbGlqnAs= ------END CERTIFICATE----- diff --git a/test/data/verificationcerts/trusted-leaf-bad-host.crt b/test/data/verificationcerts/trusted-leaf-bad-host.crt new file mode 100644 index 00000000..bbf2fb0a --- /dev/null +++ b/test/data/verificationcerts/trusted-leaf-bad-host.crt @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC1jCCAb4CCQCzDwzVB+KILzANBgkqhkiG9w0BAQsFADBFMQswCQYDVQQGEwJB +VTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0 +cyBQdHkgTHRkMB4XDTE1MTAxNjE1MDMwMVoXDTE4MDgwNTE1MDMwMVowFTETMBEG +A1UEAwwKd3JvbmcuaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +AMnJESjt6YT6x2z4SBvsrZyhlwCZ0GwYdSpfJLSaQXmzDG60i9qeqrLKDHGSUfak +W6RTl/Hh+EoJtVaVQirJyApkLOGkrMpS3HabWI/nFtShrCK5kcTDmbP52bfvhago +YZiXWoYV1WzSWKK+WiAMsGc6cUmfaoWego7dc+E9BzCP8PJniEBctWNt1wBZwxAv +G657CaHvlkEAIc6jIFIE0jL/Gi2T8J8jCAsboXYyP5AXIn+aEu/VJDGys7DnftU0 +uyK7l/qFwjTvkgs52ZqyUyoWVoM/7miXVe2D2HSzhLwXeVv+w3CtnwZ2BZA8WUIc +KhGr2sjjOIwY9xguBwi1k8kCAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAktS8+9Jz +c0WatiFpo1DbHVBpN9VjjWF6uQyCOFu6uKiJgXAgCc/YekPHy9auu+DtDBVlpncV +NS/+aZlLYF7dGpbkh5Qx1q2zSf5kH1tzbH3+qJpmJcRgKXNasu5aPRFqJLRHu5Lu +V7K9Q/vRTbRNdu0Axn6yZEK+3/2bO5x5nFfUmAV2HLxFFIa6DbQhaBQjLnVyYFxD +I6+G00MAZ47rj4m+PrxsXTOq050mg519FK0t5X7ifaG56R96EKvUkfifQzZmpmgX +gs/ZaFzRkRLdqvsxyYHICL8BEKfwZUQiyAAb6Shf09/xO05a3LHl3ZXm87UxJlwW +9qWySdIdCc41RA== +-----END CERTIFICATE----- diff --git a/test/data/verificationcerts/trusted-leaf-bad-host.key b/test/data/verificationcerts/trusted-leaf-bad-host.key new file mode 100644 index 00000000..30711ece --- /dev/null +++ b/test/data/verificationcerts/trusted-leaf-bad-host.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAyckRKO3phPrHbPhIG+ytnKGXAJnQbBh1Kl8ktJpBebMMbrSL +2p6qssoMcZJR9qRbpFOX8eH4Sgm1VpVCKsnICmQs4aSsylLcdptYj+cW1KGsIrmR +xMOZs/nZt++FqChhmJdahhXVbNJYor5aIAywZzpxSZ9qhZ6Cjt1z4T0HMI/w8meI +QFy1Y23XAFnDEC8brnsJoe+WQQAhzqMgUgTSMv8aLZPwnyMICxuhdjI/kBcif5oS +79UkMbKzsOd+1TS7IruX+oXCNO+SCznZmrJTKhZWgz/uaJdV7YPYdLOEvBd5W/7D +cK2fBnYFkDxZQhwqEavayOM4jBj3GC4HCLWTyQIDAQABAoIBAQCJzs3vW/w9m1+T +ZkUo/Qzciecsu9+B03pBQ9U3mpnY2ZVGDfvthKsji6XP8pQTk9AafBSrVx5Qwiyc +Qzd7LW922U9lkyeGzexO/G0RaktHUFrVJFMPRF62cY5ldimb3Gg65DMom8S0mzt5 +efLnLINVHK6+DyeatdSIaWl4jEtat9tsxp8UNtm0rnpa+jEy13wUsTcPe9f/pLXS +KqFXdyq263R2FkKC7FaT2HHYDJmiDPwta/hHPGzc3A8/CfPDAr0SrFEuWmYRj5mW +0QrcDh+BTIavs5I5cD+95lLtWnJvak03o5eQvvWw5K4PqWidZk6OQlSoQe82uQXw +AWLVH1thAoGBAPrLy39ACxFTA7dOQnwnJJBPN1MVV4ZnUAjE49iCnZfzyr5mZWRZ +nNGJLSekwOqBbBa1dfh8n5cnv2aBNXv1m1NFMsPsnwmcm7ugrr3UPiIT81ZnJgR0 +5SzBfHTQRcegzaWq2Je79BYsa4SB6mAwPkjmOlnn03aMQICsbeFRYy59AoGBAM34 +7qCxZkz6vGxx7L6jtxP6q96Jd0S22eZqB3cccai9EfPgpywAzbYcoXfhz07RtGEU +JBf1975tKHtwxzE1YTFKtvDjkRtikI/sw8TpDVfy9fDts7RF4nNmlhQJwAXTtWAk +3Ui5u25WFi2don4XvcIexmaQviz/sguvtx3vOYA9AoGBANnQIR6VKoeTR4jt6QQW +osTKZ8w6ntdV5saW6SNi3SfZTd3q5GgxA+dfcd4aUonYeV2Hn7t90MTgenS2BxNv +jcTWNm6+lKkuYHql5N1s9cF2/kGuN/Bq7ZbfPA3fzJrB55jYNmAhlq2jSoW8pyd+ +/rklaswmcRtmV6bpGk0z+CWpAoGBAI7aAD6I6uem2rnnxYduqlH7/+mGs6Z/nt60 +WNseaiHah7H59FeLcyDD+KTZgtsqjAzsWCAaIqn6sSHz1OLnH7J9HCYz3nb8xEBd +uGVAMVX3FuXzJjh4Y5cf5iSdooUoENpOlv6SelEK+bTHaGRFeQFCMN3/szYoXMbI +JptnSB0NAoGAT9vfD/GhokSCvNO99XQOSR/r9wkrv/AzXjFUkAyEFv00+DXoNoEw +eT62HaUjdnBoIiwCc7whrYzk94BOoHGkMF5qUCNWEN9G3kOT4VWhsEYqXepdFy5Y +x/Jt5UIXABtFMKS4ZE8VbjFTodpmdqUcdCP6Zb+ASDJCzbGlqHZofJA= +-----END RSA PRIVATE KEY----- diff --git a/test/data/verificationcerts/trusted-leaf.crt b/test/data/verificationcerts/trusted-leaf.crt new file mode 100644 index 00000000..10432db8 --- /dev/null +++ b/test/data/verificationcerts/trusted-leaf.crt @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC1TCCAb0CCQCzDwzVB+KILjANBgkqhkiG9w0BAQsFADBFMQswCQYDVQQGEwJB +VTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0 +cyBQdHkgTHRkMB4XDTE1MTAxNjE1MDMwMFoXDTE4MDgwNTE1MDMwMFowFDESMBAG +A1UEAwwJMTI3LjAuMC4xMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA +uDZ6dL1IxfAY8uxpzUScwTg1jHpktLk1Uwj7m+dsDGHCvEhmbLhllYdCiqpcEjxm +2IGtzzKU5iAtN46WQsWApw5H6n+Ozw3k3Wf7KqwcSxn0pemGglj7lRg8PNTBWHLe +aw3qQcZUKgmAwVly0ILYNmKbTBiRh1IpCjI1lqHM9gLY9GqrQ5N6D4iZPX3Snxq3 +IKVcpAvGShxjrRXyXwrVdk5vHdMRiNiMOLE+drpoK9ShmIz8OCCA2D+PvClaYGjz +2GbvNzrHMcSiJzggeT8aRNv8HT5JEj8NOgt9NM2yzQSZsOCGc9r6scKV0A0n8PQH +KiwBMulH9ums73MfM/NstQIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQA1pWmPnijh +szj3hf7hBrk27jEO8VO87Y9rVlPlII2JAdSGU/5AbrgsiWIMKeLJM5eqftWjuELG +ZGAJHNk/J74x16I/YlyZZZ1pzklGmVp3VYbHeabRCP77a+qLzBhhirdqPaZuFK3U +3GTm/fsyAypHxDM5xsDJVqLolLgrasFgUxEoNuI3LRbMKhcGURAOKiJJpJIwBqGo +xiZVdC5ZAOK70jU+8jNpNFrgo7gN1tinuQYFoZZ5fGIQObo5rgbqkF7U/fCknkLj +N7ykvCkMqeax3gj7htkpfXYTvG0zRiX59D11hhRGoTs3XZS52+jFHAJau6netga2 +fT7jVsojtgw8 +-----END CERTIFICATE----- diff --git a/test/data/verificationcerts/trusted-leaf.key b/test/data/verificationcerts/trusted-leaf.key new file mode 100644 index 00000000..a6aba170 --- /dev/null +++ b/test/data/verificationcerts/trusted-leaf.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAuDZ6dL1IxfAY8uxpzUScwTg1jHpktLk1Uwj7m+dsDGHCvEhm +bLhllYdCiqpcEjxm2IGtzzKU5iAtN46WQsWApw5H6n+Ozw3k3Wf7KqwcSxn0pemG +glj7lRg8PNTBWHLeaw3qQcZUKgmAwVly0ILYNmKbTBiRh1IpCjI1lqHM9gLY9Gqr +Q5N6D4iZPX3Snxq3IKVcpAvGShxjrRXyXwrVdk5vHdMRiNiMOLE+drpoK9ShmIz8 +OCCA2D+PvClaYGjz2GbvNzrHMcSiJzggeT8aRNv8HT5JEj8NOgt9NM2yzQSZsOCG +c9r6scKV0A0n8PQHKiwBMulH9ums73MfM/NstQIDAQABAoIBAE6hQmfuG9ARijS2 +4PpzXQ3Etma+H5pcq/xDi1Ki16X5XKwNo1qo4wOOdsLFsQM/sQ6dW9ljV9dayLI8 +NLtPnniwSdY4mHadEaHILpeqW3FbJOhk47tjzA96BsxYbCca8QF1MRbeVzKSV9kw +GygRkcS1FmDG4+eFFGt7vxALBHfFSdI2qkqrRRv+PHtBcV5/06CR/7CFF2GC5ZSs +X1DaceoEW9Qjl+0+EP/XoecOA6W/3zgjqv9hbPwJZuMDYYQ574BtWAi0LC5Q5G5s +L+Zbl1pMeUIcZ2nHThMSo6xM/2SU+5KNcGwYa2jfx5Q4TzP0BB7b8M2IB+bYIim/ +D7fGsUkCgYEA4r3XF5FI5eSXswrkqQWMjkjI0IJIuhAld8tRILT2c2yADTaQ4dB9 +v/SfdoZBmXxatOKrC4KJwSkdoy3YOIz1knVd27amYcTQnwpux1XQWl9Cxxc/C85h +hy5LWxsBaxy4JzFn98N9PUdX9jeXgm6yBzGcBdNU1sSXOYul3T5hzvcCgYEAz/u/ +QslIlffYiLGZ+vu9CBiVVERrix0Uj9K+I3wM7T6WtP471bMjcsQ4IKVzKeiHu/0S +bpqktdxIIbQEnziwIS15Vz165HXR9lfqurk1Vi5x4O94MB9A2pB5qTPo6IS3aClB +gyA5gUw5dUxI4iSu5nxOBaVg9jsgFzyfbr0CerMCgYEAoNnD6PgsGsqbw2wK4s0I +9Tc1HpYOOdCSg/U8TFOUMjXacYUwKsHZM3+6UD7V8qiBQKk8ZiHoz5r3Z3dyWEvH +OmsAdomQZvNUfD7Ob6K0+Cd0HAClvR5fmaKB2tPBodbx3PvzoZSRGBOwlv7BAMq+ +iNPst0VAfktgbHZg6B8FC+kCgYEAzjgKQxk7HF+b1qVqTL5QhveBEQWqMExMN/K4 +TozQcGfPnHQ8Nb6iVkgScuQ5lQMXmqDqJrq0uBFLgAdzUcAueycQmhy+fkoIPh6c +AjpjlSkGBwbJ/8TtVAlOaCOtOudkxyWo7HAGNJq0mgZieb/vn17/KX/57Qtg3Ulh +t7Y3ABsCgYEA0ehcAnI14nW5pPfmXiVg4MSaSRNHtVKsDrQ+g5SUGLuXFqyYiNxC +/yzRhtknLe1rwHjp4+bNpFj+OzAad8MXh5FClIsa7w3cc2S/9ixO6vA8BMshHNBL +GOMxdqBTzSKNf0kE/M7YYznCi4kodxy8wWwsAbQYswKCU0jCn/GBCOw= +-----END RSA PRIVATE KEY----- diff --git a/test/data/verificationcerts/trusted-root.crt b/test/data/verificationcerts/trusted-root.crt new file mode 100644 index 00000000..7d91e288 --- /dev/null +++ b/test/data/verificationcerts/trusted-root.crt @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDXTCCAkWgAwIBAgIJAPJ/OeIFZUrJMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTUxMDE2MTUwMjU4WhcNMTgwODA1MTUwMjU4WjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEAs8EFHXjYTdmPf9J37wUuLcx5hi9HFmXfFEbJ0tSm/U8lajfsnr869LiO +2App2JHgntreemHe/OZaaa/fPykDnDDiQBVb74H55YGHYCGphIPeyT78KEvInPOs +m/CaYFxlXB/ao81SXeGKkKagcFq/D4FjFYjmjxDxzUJVxX67knjr5WwPK60NfJSq +JzRIvFFXUtkByRv2VZmEAj56KRQx1W0+Ant51j52ryuD7pvCZ6P5TU4CdGlu34bu +1DJ/7uRBCIGYffZs7vE2wMhCvbwQAPl0q+Kq9yZdPXY+sgoGgmkydB/INuXSv/Ce +IgpBW+EjjeYD32YbnOTQ0Fi5yvxEjwIDAQABo1AwTjAdBgNVHQ4EFgQU8X+ohuC4 +QOemuutP/xX6ZCddKqowHwYDVR0jBBgwFoAU8X+ohuC4QOemuutP/xX6ZCddKqow +DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAnD4fPo/ztU6g77BSf88o +TtsQ1x2Cu3I7DNFEDBpBubooQpQZpwuLspMSgQfTnlPT4V7iBE/+3x7gJm8BcWEi +QxjJhoiWVWDpDe0GdcgNvScPq+3kupzxEJrTGSY/SJjiftlTvI1oGRmto9VXhNlU +6TeFEwieDWfm2waqJCYlGI86go47piqjh3E8ODPAT1SBRLvrfU6b3nvSPl3r1JvF +iurGxMPUk3DHb/Y19MdkFiaUqu/P+c/rO6BDxhPfuJxhmw4OdMuPA7cY0H3bbXHE +yoXqEvQ43ItEiYXVRoc9CCT1l9+ExC8cUsOTUqFi5Fwyr7K3ZDpAOqCpzaLZnss7 +mw== +-----END CERTIFICATE----- diff --git a/test/data/verificationcerts/trusted-root.key b/test/data/verificationcerts/trusted-root.key new file mode 100644 index 00000000..298e8bd9 --- /dev/null +++ b/test/data/verificationcerts/trusted-root.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAs8EFHXjYTdmPf9J37wUuLcx5hi9HFmXfFEbJ0tSm/U8lajfs +nr869LiO2App2JHgntreemHe/OZaaa/fPykDnDDiQBVb74H55YGHYCGphIPeyT78 +KEvInPOsm/CaYFxlXB/ao81SXeGKkKagcFq/D4FjFYjmjxDxzUJVxX67knjr5WwP +K60NfJSqJzRIvFFXUtkByRv2VZmEAj56KRQx1W0+Ant51j52ryuD7pvCZ6P5TU4C +dGlu34bu1DJ/7uRBCIGYffZs7vE2wMhCvbwQAPl0q+Kq9yZdPXY+sgoGgmkydB/I +NuXSv/CeIgpBW+EjjeYD32YbnOTQ0Fi5yvxEjwIDAQABAoIBAGwOu6l07OsXv0kC +1/BK9/C2O9OrdhuDz+/bghYImf4q5v4McmUX5jQZAl3jHLABObulLRr63NbBD1b7 +T8QjPrVVOZ12eZboVrZeAGiMs+AiefoWr/T1YbrgTUJNCDCnOpN/3qqbkkk3fVnp +oQcJtlN2336hlqAoeoN+vhsETXQF1L598oOkb54O9fuZWKR/WrVgd1492oktHA0L +19RmoZ6J7a3ojZh4IN838jlc3istqywbuB77dHSXYA9ZUjg+ejkZlUG1mGiv1OQg +HUbRIqW+OOieMVvXUTGjAQDqXw/oz2d684rwXr3x7G9MKtBciVjUrNk1+JdM9M+9 +531xmNkCgYEA4+nHbNyEju33X2DB9W7xQYyrVvRu/+3+eCjCuOP+tppXs3WvE7kP +TH2kjRpptPUm0FlpmYS0Uj1ty3WApNLEP9rdPyAtRBSzwPZOZ7Wg+GiNozarJwro +FtUfJVYkO/vEOntMBOJQbih2eZudxSmKi1tC5eKKZXQJ6y56+nznrUsCgYEAyefp +Kv1qCZ3cgqxoOdt1rSpPjkmB6JOtVii/BcoKx3NiqFTI11qPKdV1dllqgNnDEber +fH9FA0POtNJrvulbw2YBq6DqySYVKZvxsDQ+Z9Ho0K3dicv4UwU5wbIQiOoNDQYC +Xb+hqBp6ZMTaK4BnBfPQld6IkKN8yU9Fw9uZT00CgYEAvd/64+fHi+gW6eALVvUJ +i2mtKTFU9GULVoHmz/AqOWjWXc1SgaTwaPJXz7JMlJSUtIl5H4veSpGg0htfhHGP +S/+DyV5+N7TjmIPbCC3aIHnCXlJiPpGoj7UYUJu2bj6u2WX1DDCbf1q4cVHDHAoi +wTzTu/+C+0i0Jrm/fMXooYcCgYEAj1igWY47d4JlaT0AbntaK8xLWUjk+3vFZ9Nb +879DMeHA3KP9R7AazmenkpPfIoX4kZ6mGKi/FZdRrV1rc8p4BN1qODDyIEdyZO07 +hY9B8zG7qlSWYdu3fTHLlLJYPOx2wZVPnsGMAy5xURPVlWb/PeGhaJXqvU3lLYOj +k29YhE0CgYEAnsoVVhuZZ2SIyPOGE3E2O/y0475lGjp83cUoEaRTRBS1DJS7LVdN +QD1PCq0owNFKUZzDcbOD1x+an4X6gxTKd0GVDjNkmVCVwkKPQZRPFBNzVRN8do5e +WFJLqY/3shJvSRmvut+SLnN5U5iYzHP1MOIatJSGfK/DXshOLlrQPqs= +-----END RSA PRIVATE KEY----- diff --git a/test/data/verificationcerts/trusted-root.srl b/test/data/verificationcerts/trusted-root.srl new file mode 100644 index 00000000..22f0855d --- /dev/null +++ b/test/data/verificationcerts/trusted-root.srl @@ -0,0 +1 @@ +B30F0CD507E2882F diff --git a/test/data/verificationcerts/trusted.key b/test/data/verificationcerts/trusted.key deleted file mode 100644 index 3c26edf6..00000000 --- a/test/data/verificationcerts/trusted.key +++ /dev/null @@ -1,15 +0,0 @@ ------BEGIN RSA PRIVATE KEY----- -MIICXAIBAAKBgQC00Jf3KrBAmLQWl+Dz8Qrig8ActB94kv0/Lu03P/2DwOR8kH2h -3w4OC3b3CFKX31h7hm/H1PPHq7cIX6IRfwrYCtBE77UbxklSlrwn06j6YSotz0/d -wLEQEFDXWITJq7AyntaiafDHazbbXESNm/+I/YEl2wKemEHE//qWbeM9kwIDAQAB -AoGAVs2FBs1hi8FDQ01qWvGuzgt94MnACfxWw0xd6RY5OFUT25DqHxmb/7YVSIag -T/SS38osQ3zCA2s2FTkD7u5UX5AzJyqYJwmJhe6ZmaVly6IpebMxkX5w/hy15/N4 -uy+kzdtEBUUTNLL3DM7THkDYUxmeDzCBrHsMvYUqFgsBLOECQQDeNc1pDC++ovg5 -d9sKqMnEykBfvuvR6ra/343tYxy9zNFBvYjU3BA83MITIbEa/KtlSkIppz/K/jk5 -IRwSrwsJAkEA0E9aZfjDZbC9Z4oL7T8gtj2ftSh2g37KE5AWW2OxMJwrzoJ/6wjB -nG26ATlHEFP9bRzL2O1iovFLalqEjQo+uwJAMjtZXvjZRjATCvK0Onmjeu/5k2tW -ZdK4UzGXJOW11pYZa9ILv4qrxQZmfOqt3Zrmp/QcdswPGLVVfDum2/Zj+QJABJO5 -yMPOh0162+uMl4nrjhWMjM52zCzdA9EGrLtkCU1lKQR1CxUGLAm9LIm1pgYya1NW -p02P/USQA6Y5g1/WQQJBAIwl42Bebgaxl7dUbQX/vF+TryoCkM3B3eSM+P4XKB4f -kKSkNxvp59uq+b40gkoqEowhdq97y+pmrCxJHK43NJM= ------END RSA PRIVATE KEY----- diff --git a/test/data/verificationcerts/trusted.pem b/test/data/verificationcerts/trusted.pem deleted file mode 100644 index 8ebc0e5c..00000000 --- a/test/data/verificationcerts/trusted.pem +++ /dev/null @@ -1,15 +0,0 @@ -# Self signed ------BEGIN CERTIFICATE----- -MIICJzCCAZACCQCo1BdopddN/TANBgkqhkiG9w0BAQUFADBXMQswCQYDVQQGEwJB -VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0 -cyBQdHkgTHRkMRAwDgYDVQQDEwdUUlVTVEVEMCAXDTE1MDYxOTE4MDEzMVoYDzIx -MTUwNTI2MTgwMTMxWjBXMQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0 -ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMRAwDgYDVQQDEwdU -UlVTVEVEMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQC00Jf3KrBAmLQWl+Dz -8Qrig8ActB94kv0/Lu03P/2DwOR8kH2h3w4OC3b3CFKX31h7hm/H1PPHq7cIX6IR -fwrYCtBE77UbxklSlrwn06j6YSotz0/dwLEQEFDXWITJq7AyntaiafDHazbbXESN -m/+I/YEl2wKemEHE//qWbeM9kwIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAF0NREP3 -X+fTebzJGttzrFkDhGVFKRNyLXblXRVanlGOYF+q8grgZY2ufC/55gqf+ub6FRT5 -gKPhL4V2rqL8UAvCE7jq8ujpVfTB8kRAKC675W2DBZk2EJX9mjlr89t7qXGsI5nF -onpfJ1UtiJshNoV7h/NFHeoag91kx628807n ------END CERTIFICATE----- diff --git a/test/data/verificationcerts/untrusted-chain.crt b/test/data/verificationcerts/untrusted-chain.crt deleted file mode 100644 index 272779d8..00000000 --- a/test/data/verificationcerts/untrusted-chain.crt +++ /dev/null @@ -1,33 +0,0 @@ -# untrusted.crt, signed by trusted-interm.crt ------BEGIN CERTIFICATE----- -MIICYzCCAcwCAhAIMA0GCSqGSIb3DQEBBQUAMH4xCzAJBgNVBAYTAkFVMRMwEQYD -VQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0eSBM -dGQxFDASBgNVBAsTC0lOVEVSTSBVTklUMSEwHwYDVQQDExhPUkcgV0lUSCBJTlRF -Uk1FRElBVEUgQ0EwIBcNMTUwNjIwMDEyMDI1WhgPMjExNTA1MjcwMTIwMjVaMHMx -CzAJBgNVBAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRl -cm5ldCBXaWRnaXRzIFB0eSBMdGQxEjAQBgNVBAsTCUxFQUYgVU5JVDEYMBYGA1UE -AxMPTk9UIFRSVVNURUQgT1JHMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDf -NZx/tugICrWGcpP8sa+EBX9WhazCsYIm8YgQrQO9B19dK7cHsWB+vIdFuDKHxfS2 -JBIeVSaZ6H4onWGnZRAMpi5xnitVhBQKCZP1yOewtrg2umZIbcTz8A+BwAcvmmQN -7RZMfpxN9PMccWDfgtAXsjZ2E47o9EfhpGvxfcFc0wIDAQABMA0GCSqGSIb3DQEB -BQUAA4GBABtmc8zn5efVi3iVIgODadKkTv43elIwNZBqEJ6IaoVXvi5Mp1m4VxML -LQGPTNG1lpuVDz2z/Ml78942316ailCTOx48oDnb/yy4jI6hsp+N8p6T28/Wvkbm -cCgohk6/Cwat5gf+HwoIe5Z3B3HRJaIcB0OteluuLsHAvverBjc4 ------END CERTIFICATE----- -# untrusted-interm.crt, self-signed ------BEGIN CERTIFICATE----- -MIICdTCCAd4CCQDRSKOnIMbTgDANBgkqhkiG9w0BAQUFADB+MQswCQYDVQQGEwJB -VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0 -cyBQdHkgTHRkMRQwEgYDVQQLEwtJTlRFUk0gVU5JVDEhMB8GA1UEAxMYT1JHIFdJ -VEggSU5URVJNRURJQVRFIENBMCAXDTE1MDYyMDAxMzY0M1oYDzIxMTUwNTI3MDEz -NjQzWjB+MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UE -ChMYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMRQwEgYDVQQLEwtJTlRFUk0gVU5J -VDEhMB8GA1UEAxMYT1JHIFdJVEggSU5URVJNRURJQVRFIENBMIGfMA0GCSqGSIb3 -DQEBAQUAA4GNADCBiQKBgQC1E80qCHhZ1gaZTYB7pN/Yxt3ehpEj+5hCbpop5iTW -LuDjULS9WjA1wP+p02kZQ2dqL8pqT1qcc5jKmk2jvMeB/cQ7zNDg1NCmQMqx0Kpt -RByMZ+GNZcqc7D4jl6vhGP4zAzV/lxvBvxtgeJI+ZdrHN0vT9I1cYADKz9SzCDCR -TwIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAGbObAMEajCz4kj7OP2/DB5SRy2+H/G3 -8Qvc43xlMMNQyYxsDuLOFL0UMRzoKgntrrm2nni8jND+tuMt+hv3ZlBcJlYJ6ynR -sC1ITTC/1SwwwO0AFIyduUEIJYr/B3sgcVYPLcEfeDZgmEQc9Tnc01aEu3lx2+l9 -0JTSPL2L9LdA ------END CERTIFICATE----- diff --git a/test/data/verificationcerts/untrusted-interm.crt b/test/data/verificationcerts/untrusted-interm.crt deleted file mode 100644 index 875cdcd6..00000000 --- a/test/data/verificationcerts/untrusted-interm.crt +++ /dev/null @@ -1,17 +0,0 @@ -# untrusted-interm.crt, self-signed ------BEGIN CERTIFICATE----- -MIICdTCCAd4CCQDRSKOnIMbTgDANBgkqhkiG9w0BAQUFADB+MQswCQYDVQQGEwJB -VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0 -cyBQdHkgTHRkMRQwEgYDVQQLEwtJTlRFUk0gVU5JVDEhMB8GA1UEAxMYT1JHIFdJ -VEggSU5URVJNRURJQVRFIENBMCAXDTE1MDYyMDAxMzY0M1oYDzIxMTUwNTI3MDEz -NjQzWjB+MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UE -ChMYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMRQwEgYDVQQLEwtJTlRFUk0gVU5J -VDEhMB8GA1UEAxMYT1JHIFdJVEggSU5URVJNRURJQVRFIENBMIGfMA0GCSqGSIb3 -DQEBAQUAA4GNADCBiQKBgQC1E80qCHhZ1gaZTYB7pN/Yxt3ehpEj+5hCbpop5iTW -LuDjULS9WjA1wP+p02kZQ2dqL8pqT1qcc5jKmk2jvMeB/cQ7zNDg1NCmQMqx0Kpt -RByMZ+GNZcqc7D4jl6vhGP4zAzV/lxvBvxtgeJI+ZdrHN0vT9I1cYADKz9SzCDCR -TwIDAQABMA0GCSqGSIb3DQEBBQUAA4GBAGbObAMEajCz4kj7OP2/DB5SRy2+H/G3 -8Qvc43xlMMNQyYxsDuLOFL0UMRzoKgntrrm2nni8jND+tuMt+hv3ZlBcJlYJ6ynR -sC1ITTC/1SwwwO0AFIyduUEIJYr/B3sgcVYPLcEfeDZgmEQc9Tnc01aEu3lx2+l9 -0JTSPL2L9LdA ------END CERTIFICATE----- diff --git a/test/data/verificationcerts/untrusted.crt b/test/data/verificationcerts/untrusted.crt deleted file mode 100644 index 2dab470b..00000000 --- a/test/data/verificationcerts/untrusted.crt +++ /dev/null @@ -1,16 +0,0 @@ -# untrusted.crt, signed by trusted-interm.crt ------BEGIN CERTIFICATE----- -MIICYzCCAcwCAhAIMA0GCSqGSIb3DQEBBQUAMH4xCzAJBgNVBAYTAkFVMRMwEQYD -VQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0eSBM -dGQxFDASBgNVBAsTC0lOVEVSTSBVTklUMSEwHwYDVQQDExhPUkcgV0lUSCBJTlRF -Uk1FRElBVEUgQ0EwIBcNMTUwNjIwMDEyMDI1WhgPMjExNTA1MjcwMTIwMjVaMHMx -CzAJBgNVBAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRl -cm5ldCBXaWRnaXRzIFB0eSBMdGQxEjAQBgNVBAsTCUxFQUYgVU5JVDEYMBYGA1UE -AxMPTk9UIFRSVVNURUQgT1JHMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDf -NZx/tugICrWGcpP8sa+EBX9WhazCsYIm8YgQrQO9B19dK7cHsWB+vIdFuDKHxfS2 -JBIeVSaZ6H4onWGnZRAMpi5xnitVhBQKCZP1yOewtrg2umZIbcTz8A+BwAcvmmQN -7RZMfpxN9PMccWDfgtAXsjZ2E47o9EfhpGvxfcFc0wIDAQABMA0GCSqGSIb3DQEB -BQUAA4GBABtmc8zn5efVi3iVIgODadKkTv43elIwNZBqEJ6IaoVXvi5Mp1m4VxML -LQGPTNG1lpuVDz2z/Ml78942316ailCTOx48oDnb/yy4jI6hsp+N8p6T28/Wvkbm -cCgohk6/Cwat5gf+HwoIe5Z3B3HRJaIcB0OteluuLsHAvverBjc4 ------END CERTIFICATE----- diff --git a/test/data/verificationcerts/verification-server.key b/test/data/verificationcerts/verification-server.key deleted file mode 100644 index c527b09f..00000000 --- a/test/data/verificationcerts/verification-server.key +++ /dev/null @@ -1,16 +0,0 @@ -# Key used for untrusted.crt, untrusted-chain.crt and trusted-chain.crt ------BEGIN RSA PRIVATE KEY----- -MIICXQIBAAKBgQDfNZx/tugICrWGcpP8sa+EBX9WhazCsYIm8YgQrQO9B19dK7cH -sWB+vIdFuDKHxfS2JBIeVSaZ6H4onWGnZRAMpi5xnitVhBQKCZP1yOewtrg2umZI -bcTz8A+BwAcvmmQN7RZMfpxN9PMccWDfgtAXsjZ2E47o9EfhpGvxfcFc0wIDAQAB -AoGAE4B9ofL7Jui4n3yXTXbA3QoV7BtV0tTriDeGKd7T+soQHPXa0gM/aRNTxlWn -pJE5JkjUhG3wJ3ZWv3mwtI1x718y0yL9uEgQJYsrNN+VJQwbGxXPio5SaG39gs+y -/8xklytMIgvuCXxmcfljemW9+PGT8otYlHeIU3wvHQennDECQQD2vWAEU9k02R9w -EkCM7mZEaW+WwrzyAD1NqatsVWErbNeXFPcHwU6y+DiDg2s5iEk89+xN2rX5mW2S -PF/2RpaNAkEA55YpZN5nN4P8yCYNz5mWN0kuSPytSgJ3fQY3BY2GkdIft/KcAuDV -1pf6jxubwP4vlamnZpqLfylbGdlRBoMY3wJBALQVE3cVG3qO3XsWVzaE6O8VZPRL -vUuDETsVkp/G0Ny428DQ9FscoyvMLrMNv7yF065D5JwN/LLnYClTF1bPviECQQCo -1BavO1eh6C3DN8K/wmb5PPdqLBKkrrGvSnWYLbmZ2sZW0p4blw8tVzRJWcYtZuEH -yVuJeEcT1/FbIcto5O+fAkASbZXZka3nm41wWNYg479Sl8I+qvtScfJgpyByYhCx -QaUAtZ791U+WNNHLqfZhSzP9lFZNRI0WNBSAy3SBR2Ur ------END RSA PRIVATE KEY----- -- cgit v1.2.3 From b4eb4eab92aa7fee0fb1c3aaaedad0d08d1c6c3b Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 1 Nov 2015 17:48:34 +0100 Subject: adjust test certificate generation --- test/data/verificationcerts/9da13359.0 | 32 +++++++------- test/data/verificationcerts/generate.py | 8 +--- test/data/verificationcerts/self-signed.crt | 33 +++++++------- test/data/verificationcerts/self-signed.key | 50 +++++++++++----------- .../verificationcerts/trusted-leaf-bad-host.crt | 18 -------- .../verificationcerts/trusted-leaf-bad-host.key | 27 ------------ test/data/verificationcerts/trusted-leaf.crt | 30 ++++++------- test/data/verificationcerts/trusted-leaf.key | 50 +++++++++++----------- test/data/verificationcerts/trusted-root.crt | 32 +++++++------- test/data/verificationcerts/trusted-root.key | 50 +++++++++++----------- test/data/verificationcerts/trusted-root.srl | 2 +- 11 files changed, 142 insertions(+), 190 deletions(-) delete mode 100644 test/data/verificationcerts/trusted-leaf-bad-host.crt delete mode 100644 test/data/verificationcerts/trusted-leaf-bad-host.key diff --git a/test/data/verificationcerts/9da13359.0 b/test/data/verificationcerts/9da13359.0 index 7d91e288..b22e4d20 100644 --- a/test/data/verificationcerts/9da13359.0 +++ b/test/data/verificationcerts/9da13359.0 @@ -1,21 +1,21 @@ -----BEGIN CERTIFICATE----- -MIIDXTCCAkWgAwIBAgIJAPJ/OeIFZUrJMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV +MIIDXTCCAkWgAwIBAgIJAPAfPQGCV/Z4MA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX -aWRnaXRzIFB0eSBMdGQwHhcNMTUxMDE2MTUwMjU4WhcNMTgwODA1MTUwMjU4WjBF +aWRnaXRzIFB0eSBMdGQwHhcNMTUxMTAxMTY0ODAxWhcNMTgwODIxMTY0ODAxWjBF MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB -CgKCAQEAs8EFHXjYTdmPf9J37wUuLcx5hi9HFmXfFEbJ0tSm/U8lajfsnr869LiO -2App2JHgntreemHe/OZaaa/fPykDnDDiQBVb74H55YGHYCGphIPeyT78KEvInPOs -m/CaYFxlXB/ao81SXeGKkKagcFq/D4FjFYjmjxDxzUJVxX67knjr5WwPK60NfJSq -JzRIvFFXUtkByRv2VZmEAj56KRQx1W0+Ant51j52ryuD7pvCZ6P5TU4CdGlu34bu -1DJ/7uRBCIGYffZs7vE2wMhCvbwQAPl0q+Kq9yZdPXY+sgoGgmkydB/INuXSv/Ce -IgpBW+EjjeYD32YbnOTQ0Fi5yvxEjwIDAQABo1AwTjAdBgNVHQ4EFgQU8X+ohuC4 -QOemuutP/xX6ZCddKqowHwYDVR0jBBgwFoAU8X+ohuC4QOemuutP/xX6ZCddKqow -DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAnD4fPo/ztU6g77BSf88o -TtsQ1x2Cu3I7DNFEDBpBubooQpQZpwuLspMSgQfTnlPT4V7iBE/+3x7gJm8BcWEi -QxjJhoiWVWDpDe0GdcgNvScPq+3kupzxEJrTGSY/SJjiftlTvI1oGRmto9VXhNlU -6TeFEwieDWfm2waqJCYlGI86go47piqjh3E8ODPAT1SBRLvrfU6b3nvSPl3r1JvF -iurGxMPUk3DHb/Y19MdkFiaUqu/P+c/rO6BDxhPfuJxhmw4OdMuPA7cY0H3bbXHE -yoXqEvQ43ItEiYXVRoc9CCT1l9+ExC8cUsOTUqFi5Fwyr7K3ZDpAOqCpzaLZnss7 -mw== +CgKCAQEArp8LD34JhKCwcQbwIYQMg4+eCgLVN8fwB7+/qOfJbArPs0djFBN+F7c6 +HGvMr24BKUk5u8pn4dPtNurm/vPC8ovNGmcXz62BQJpcMX2veVdRsF7yNwhNacNJ +Arq+70zNMwYBznx0XUxMF6j6nVFf3AW6SU04ylT4Mp3SY/BUUDAdfl1eRo0mPLNS +8rpsN+8YBw1Q7SCuBRVqpOgVIsL88svgQUSOlzvMZPBpG/cmB3BNKNrltwb5iFEI +1jAV7uSj5IcIuNO/246kfsDVPTFMJIzav/CUoidd5UNw+SoFDlzh8sA7L1Bm7D1/ +3KHYSKswGsSR3kynAl10w/SJKDtn8wIDAQABo1AwTjAdBgNVHQ4EFgQUgOcrtxBX +LxbpnOT65d+vpfyWUkgwHwYDVR0jBBgwFoAUgOcrtxBXLxbpnOT65d+vpfyWUkgw +DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAEE9bFmUCA+6cvESKPoi2 +TGSpV652d0xd2U66LpEXeiWRJFLz8YGgoJCx3QFGBscJDXxrLxrBBBV/tCpEqypo +pYIqsawH7M66jpOr83Us3M8JC2eFBZJocMpXxdytWqHik5VKZNx6VQFT8bS7+yVC +VoUKePhlgcg+pmo41qjqieBNKRMh/1tXS77DI1lgO5wZLVrLXcdqWuDpmaQOKJeq +G/nxytCW/YJA7bFn/8Gjy8DYypJSeeaKu7o3P3+ONJHdIMHb+MdcheDBS9AOFSeo +xI0D5EbO9F873O77l7nbD7B0X34HFN0nGczC4poexIpbDFG3hAPekwZ5KC6VwJLc +1Q== -----END CERTIFICATE----- diff --git a/test/data/verificationcerts/generate.py b/test/data/verificationcerts/generate.py index 922cb95d..9203abbb 100644 --- a/test/data/verificationcerts/generate.py +++ b/test/data/verificationcerts/generate.py @@ -8,14 +8,13 @@ import shutil ROOT_CA = "trusted-root" -SUBJECT = "/CN=127.0.0.1/" +SUBJECT = "/CN=example.mitmproxy.org/" def do(args): print("> %s" % args) args = shlex.split(args) output = subprocess.check_output(args) - print(output) return output @@ -51,15 +50,12 @@ do("openssl req -x509 -new -nodes -batch " "-days 1024 " "-out trusted-root.crt" ) -h = do("openssl x509 -hash -noout -in trusted-root.crt").strip() +h = do("openssl x509 -hash -noout -in trusted-root.crt").decode("ascii").strip() shutil.copyfile("trusted-root.crt", "{}.0".format(h)) # create trusted leaf cert. mkcert("trusted-leaf", "-subj {}".format(SUBJECT)) -# create wrong host leaf cert. -mkcert("trusted-leaf-bad-host", "-subj /CN=wrong.host/") - # create self-signed cert genrsa("self-signed") do("openssl req -x509 -new -nodes -batch " diff --git a/test/data/verificationcerts/self-signed.crt b/test/data/verificationcerts/self-signed.crt index d7f07214..dce2a7e0 100644 --- a/test/data/verificationcerts/self-signed.crt +++ b/test/data/verificationcerts/self-signed.crt @@ -1,18 +1,19 @@ -----BEGIN CERTIFICATE----- -MIIC+zCCAeOgAwIBAgIJAMLvc0tz5r3vMA0GCSqGSIb3DQEBCwUAMBQxEjAQBgNV -BAMMCTEyNy4wLjAuMTAeFw0xNTEwMTYxNTAzMDJaFw0xODA4MDUxNTAzMDJaMBQx -EjAQBgNVBAMMCTEyNy4wLjAuMTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoC -ggEBALCzDuJl7g55J+ZNKnir0cekd48JnjPFk7sbJOPudsQ6pj/HXFrAXTPVix2n -eKtj2nADUds1C1fEgsJnYqYp9DtwesJEnnc0i2ykQmQZFygd7/0P7Z+YtUtup3F6 -jtUGEcCJ3dOOXJNyhESeyBcQwNvLgHYXAHFyN4svxueQ4fW7+d44fm0JaqZjHEtX -Q8tcVadIDsp65s+WWVP6gC0sMO2DikoF2g/98p1U0CeUCmueYJsmKpm+53smWrOp -cqwUXoxAdg03pbgC10aeWDvxm3aBC/Et9EDbaKuzHhBkOJ8E7CkyqLT/Vs7DQ9xl -WFF/Ebs1vsVniBFl3QpObxqhbM0CAwEAAaNQME4wHQYDVR0OBBYEFOTCuMxDnuup -hNAT1/gxdU9DIs82MB8GA1UdIwQYMBaAFOTCuMxDnuuphNAT1/gxdU9DIs82MAwG -A1UdEwQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBADOlFmM1fryPDFIP6mM7O4df -0GfMC9XWODf2NdJ9VWa8P7HrMbiPZy26ORkLpcWc+fuGbcd1ejf8TGbCz4f9aQ82 -P33s5jtGKRRAoB8rPmyALPSt9xrMUHYLYzN97sqY7ZHdHsc4NfzcbMVLOF+3aG4X -LIQiPIp6sLncBwvu0mHSjlcDcTM4n/Sqov4eeCNTGlVzTzsJQ6/lAwq9LIggRZA1 -RKWd+u7IQUcEMTKP0gvaWtfbxJH76RFPJX3wg7YSm97ArU9ZGna0rPORoIORrucL -aBncUwIXEPH4rtP1zy7Rg4ZeHyzoFcgR2W46ONTds+5aZDx98OyWv+gT9HSLgEo= +MIIDEzCCAfugAwIBAgIJAJ945xt1FRsfMA0GCSqGSIb3DQEBCwUAMCAxHjAcBgNV +BAMMFWV4YW1wbGUubWl0bXByb3h5Lm9yZzAeFw0xNTExMDExNjQ4MDJaFw0xODA4 +MjExNjQ4MDJaMCAxHjAcBgNVBAMMFWV4YW1wbGUubWl0bXByb3h5Lm9yZzCCASIw +DQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBALFxyzPfjgIghOMMnJlW80yB84xC +nJtko3tuyOdozgTCyha2W+NdIKPNZJtWrzN4P0B5PlozCDwfcSYffLs0WZs8LRWv +BfZX8+oX+14qQjKFsiqgO65cTLP3qlPySYPJQQ37vOP1Y5Yf8nQq2mwQdC18hLtT +QOANG6OFoSplpBLsYF+QeoMgqCTa6hrl/5GLmQoDRTjXkv3Sj379AUDMybuBqccm +q5EIqCrE4+xJ8JywJclAVn2YP14baiFrrYCsYYg4sS1Od6xFj+xtpLe7My3AYjB9 +/aeHd8vDiob0cqOW1TFwhqgJKuErfFyg8lZ2hJmStJKyfofWuY/gl/vnvX0CAwEA +AaNQME4wHQYDVR0OBBYEFB8d32zK8eqZIoKw4jXzYzhw4amPMB8GA1UdIwQYMBaA +FB8d32zK8eqZIoKw4jXzYzhw4amPMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEL +BQADggEBAJmo2oKv1OEjZ0Q4yELO6BAnHAkmBKpW+zmLyQa8idxtLVkI9uXk3iqY +GWugkmcUZCTVFRWv/QXQQSex+00IY3x2rdHbtuZwcyKiz2u8WEmfW1rOIwBaFJ1i +v7+SA2aZs6vepN2sE56X54c/YbwQooaKZtOb+djWXYMJrc/Ezj0J7oQIJTptYV8v +/3216yCHRp/KCL7yTLtiw25xKuXNu/gkcd8wZOY9rS2qMUD897MJF0MvgJoauRBd +d4XEYCNKkrIRmfqrkiRQfAZpvpoutH6NCk7KuQYcI0BlOHlsnHHcs/w72EEqHwFq +x6476tW/t8GJDZVD74+pNBcLifXxArE= -----END CERTIFICATE----- diff --git a/test/data/verificationcerts/self-signed.key b/test/data/verificationcerts/self-signed.key index 54111eca..71a6ad6a 100644 --- a/test/data/verificationcerts/self-signed.key +++ b/test/data/verificationcerts/self-signed.key @@ -1,27 +1,27 @@ -----BEGIN RSA PRIVATE KEY----- -MIIEpQIBAAKCAQEAsLMO4mXuDnkn5k0qeKvRx6R3jwmeM8WTuxsk4+52xDqmP8dc -WsBdM9WLHad4q2PacANR2zULV8SCwmdipin0O3B6wkSedzSLbKRCZBkXKB3v/Q/t -n5i1S26ncXqO1QYRwInd045ck3KERJ7IFxDA28uAdhcAcXI3iy/G55Dh9bv53jh+ -bQlqpmMcS1dDy1xVp0gOynrmz5ZZU/qALSww7YOKSgXaD/3ynVTQJ5QKa55gmyYq -mb7neyZas6lyrBRejEB2DTeluALXRp5YO/GbdoEL8S30QNtoq7MeEGQ4nwTsKTKo -tP9WzsND3GVYUX8RuzW+xWeIEWXdCk5vGqFszQIDAQABAoIBAQCdc5DJ0IYmQ3N4 -Vj6INKLDwRwAS1O7Uk1nprJioLUX+iL2JhF3lH34mEpUbEysfFfDBFJGgKfQ13yk -+jb/VdcZuArLXRXPpvSuJFg8ldb6mmKlHzJgylSSGNH/3nO0AqqC5NbTksGPabXO -56XoV7dio52enLR6Yop37mTRJ1sR+ahLFUDZ8K0pEXn0pdZVEp+LVksJ6txtklGo -x6oDyQW/AOu2QWIhrneyvSO9XzFCqOnN9KPQDhWdqRmdPjiX+sbLevX7Tf5PhiEH -nNuPxUv19+4xmu7s2tZLY6C19noRSCo4835i25smmItU9hHJ9VvHKID0oLJCMtdD -4HSErPLJAoGBAOa8Hz927R1y124geYfl0+IG+yfF0Spe7HqYk7wyHlY5EGQAncoA -n0UclagRVNQzC4Y3s+QOLIV5HGw2ENMz7flCLe3f8SPRvFu6nqWKQLAnF1U5eO8Q -YVgaWadr8PT/iOPp4PHHfhXsNx3p6RPbDyntqG9xpGYpoy97iEMkWm+jAoGBAMQM -PBIIJ+5dgPQLE42KDK3iyNQLahVFDRXozVdGm3NERsZFAB1NjfaS+HMZRr+/WID7 -tVIxrgumY8iI8SO5nD51EaPYfppjmE55hIB2eN7GqL32JwwL4fQiT3WZ0aU0mY3m -3av+RKunXCNc7LBWPzQfAAf21D4Y8N36H6i57LjPAoGBAOX2vRYdy7m8Ceaiyz2c -3I678nnzeMLIFN0jUKsTMJUzDpj83EbGU/cnxCjcDTXpIiVFQy+ayNjGmoNnZ2F4 -skfpo6kft1DB6v9pglDu+AYZD/JK87MhGkQbDxwEQwWL4b12DlIrSAlFgrF3vmuh -uv1I9sUL+JQyD4h1kJuKkfANAoGANoJoWWMnJyGcbz59K0eNCvQZfsvFrTBL2SGn -pnKdWklLnGknBP7BUCPBLM+EWmArjYFvAvGJQPf8mo9o7NP422zVgMb7PJYgjQFA -lC9coCSAWoEMjk7nfmfjzAD+x35+i3P7gozqLwgTmEmIDeeNH0LXUV+R18o7fpzD -HLjFVwUCgYEAjcv9BwK+qMhRxFcxYKsb5HkPp5LaFa3PKgitF8jsGKd+pLDyIkDD -ih2Hohf9LjR/EqlPT/w5JLmgrF6zWAKtNzWMHKP4hae322/Xh5jTJQY3rbEf0k8D -aB3XoleKD0+5erl6tDRNAPlc8qJcgBv+UzZVBmf0n3aJD3mwoS06dvQ= +MIIEowIBAAKCAQEAsXHLM9+OAiCE4wycmVbzTIHzjEKcm2Sje27I52jOBMLKFrZb +410go81km1avM3g/QHk+WjMIPB9xJh98uzRZmzwtFa8F9lfz6hf7XipCMoWyKqA7 +rlxMs/eqU/JJg8lBDfu84/Vjlh/ydCrabBB0LXyEu1NA4A0bo4WhKmWkEuxgX5B6 +gyCoJNrqGuX/kYuZCgNFONeS/dKPfv0BQMzJu4GpxyarkQioKsTj7EnwnLAlyUBW +fZg/XhtqIWutgKxhiDixLU53rEWP7G2kt7szLcBiMH39p4d3y8OKhvRyo5bVMXCG +qAkq4St8XKDyVnaEmZK0krJ+h9a5j+CX++e9fQIDAQABAoIBAQCT+FvGbych2PJX +0D2KlXqgE0IAdc/YuYymstSwPLKIP9N8KyfnKtK8Jdw+uYOyfRTp8/EuEJ5OXL3j +V6CRD++lRwIlseVb7y5EySjh9oVrUhgn+aSrGucPsHkGNeZeEmbAfWugARLBrvRl +MRMhyHrJL6wT9jIEZInmy9mA3G99IuFW3rS8UR1Yu7zyvhtjvop1xg/wfEUu24Ty +PvMfnwaDcZHCz2tmu2KJvaxSBAG3FKmAqeMvk1Gt5m2keKgw03M+EX0LrM8ybWqn +VwB8tnSyMBLVFLIXMpIiSfpji10+p9fdKFMRF++D6qVwyoxPiIq+yEJapxXiqLea +mkhtJW91AoGBAOvIb7bZvH4wYvi6txs2pygF3ZMjqg/fycnplrmYMrjeeDeeN4v1 +h/5tkN9TeTkHRaN3L7v49NEUDhDyuopLTNfWpYdv63U/BVzvgMm/guacTYkx9whB +OvQ2YekR/WKg7kuyrTZidTDz+mjU+1b8JaWGjiDc6vFwxZA7uWicaGGHAoGBAMCo +y/2AwFGwCR+5bET1nTTyxok6iKo4k6R/7DJe4Bq8VLifoyX3zDlGG/33KN3xVqBU +xnT9gkii1lfX2U+4iM+GOSPl0nG0hOEqEH+vFHszpHybDeNez3FEyIbgOzg6u7sV +NOy+P94L5EMQVEmWp5g6Vm3k9kr92Bd9UacKQPnbAoGAMN8KyMu41i8RVJze9zUM +0K7mjmkGBuRL3x4br7xsRwVVxbF1sfzig0oSjTewGLH5LTi3HC8uD2gowjqNj7yr +4NEM3lXEaDj305uRBkA70bD0IUvJ+FwM7DGZecXQz3Cr8+TFIlCmGc94R+Jddlot +M3IAY69mw0SsroiylYxV1mECgYAcSGtx8rXJCDO+sYTgdsI2ZLGasbogax/ZlWIC +XwU9R4qUc/MKft8/RTiUxvT76BMUhH2B7Tl0GlunF6vyVR/Yf1biGzoSsTKUr40u +gXBbSdCK7mRSjbecZEGf80keTxkCNPHJE4DiwxImej41c2V1JpNLnMI/bhaMFDyp +bgrt4wKBgHFzZgAgM1v07F038tAkIBGrYLukY1ZFBaZoGZ9xHfy/EmLJM3HCHLO5 +8wszMGhMTe2+39EeChwgj0kFaq1YnDiucU74BC57KR1tD59y7l6UnsQXTm4/32j8 +Or6i8GekBibCb97DzzOU0ZK//fNhHTXpDDXsYt5lJUWSmgW+S9Qp -----END RSA PRIVATE KEY----- diff --git a/test/data/verificationcerts/trusted-leaf-bad-host.crt b/test/data/verificationcerts/trusted-leaf-bad-host.crt deleted file mode 100644 index bbf2fb0a..00000000 --- a/test/data/verificationcerts/trusted-leaf-bad-host.crt +++ /dev/null @@ -1,18 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIC1jCCAb4CCQCzDwzVB+KILzANBgkqhkiG9w0BAQsFADBFMQswCQYDVQQGEwJB -VTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0 -cyBQdHkgTHRkMB4XDTE1MTAxNjE1MDMwMVoXDTE4MDgwNTE1MDMwMVowFTETMBEG -A1UEAwwKd3JvbmcuaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB -AMnJESjt6YT6x2z4SBvsrZyhlwCZ0GwYdSpfJLSaQXmzDG60i9qeqrLKDHGSUfak -W6RTl/Hh+EoJtVaVQirJyApkLOGkrMpS3HabWI/nFtShrCK5kcTDmbP52bfvhago -YZiXWoYV1WzSWKK+WiAMsGc6cUmfaoWego7dc+E9BzCP8PJniEBctWNt1wBZwxAv -G657CaHvlkEAIc6jIFIE0jL/Gi2T8J8jCAsboXYyP5AXIn+aEu/VJDGys7DnftU0 -uyK7l/qFwjTvkgs52ZqyUyoWVoM/7miXVe2D2HSzhLwXeVv+w3CtnwZ2BZA8WUIc -KhGr2sjjOIwY9xguBwi1k8kCAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAktS8+9Jz -c0WatiFpo1DbHVBpN9VjjWF6uQyCOFu6uKiJgXAgCc/YekPHy9auu+DtDBVlpncV -NS/+aZlLYF7dGpbkh5Qx1q2zSf5kH1tzbH3+qJpmJcRgKXNasu5aPRFqJLRHu5Lu -V7K9Q/vRTbRNdu0Axn6yZEK+3/2bO5x5nFfUmAV2HLxFFIa6DbQhaBQjLnVyYFxD -I6+G00MAZ47rj4m+PrxsXTOq050mg519FK0t5X7ifaG56R96EKvUkfifQzZmpmgX -gs/ZaFzRkRLdqvsxyYHICL8BEKfwZUQiyAAb6Shf09/xO05a3LHl3ZXm87UxJlwW -9qWySdIdCc41RA== ------END CERTIFICATE----- diff --git a/test/data/verificationcerts/trusted-leaf-bad-host.key b/test/data/verificationcerts/trusted-leaf-bad-host.key deleted file mode 100644 index 30711ece..00000000 --- a/test/data/verificationcerts/trusted-leaf-bad-host.key +++ /dev/null @@ -1,27 +0,0 @@ ------BEGIN RSA PRIVATE KEY----- -MIIEpQIBAAKCAQEAyckRKO3phPrHbPhIG+ytnKGXAJnQbBh1Kl8ktJpBebMMbrSL -2p6qssoMcZJR9qRbpFOX8eH4Sgm1VpVCKsnICmQs4aSsylLcdptYj+cW1KGsIrmR -xMOZs/nZt++FqChhmJdahhXVbNJYor5aIAywZzpxSZ9qhZ6Cjt1z4T0HMI/w8meI -QFy1Y23XAFnDEC8brnsJoe+WQQAhzqMgUgTSMv8aLZPwnyMICxuhdjI/kBcif5oS -79UkMbKzsOd+1TS7IruX+oXCNO+SCznZmrJTKhZWgz/uaJdV7YPYdLOEvBd5W/7D -cK2fBnYFkDxZQhwqEavayOM4jBj3GC4HCLWTyQIDAQABAoIBAQCJzs3vW/w9m1+T -ZkUo/Qzciecsu9+B03pBQ9U3mpnY2ZVGDfvthKsji6XP8pQTk9AafBSrVx5Qwiyc -Qzd7LW922U9lkyeGzexO/G0RaktHUFrVJFMPRF62cY5ldimb3Gg65DMom8S0mzt5 -efLnLINVHK6+DyeatdSIaWl4jEtat9tsxp8UNtm0rnpa+jEy13wUsTcPe9f/pLXS -KqFXdyq263R2FkKC7FaT2HHYDJmiDPwta/hHPGzc3A8/CfPDAr0SrFEuWmYRj5mW -0QrcDh+BTIavs5I5cD+95lLtWnJvak03o5eQvvWw5K4PqWidZk6OQlSoQe82uQXw -AWLVH1thAoGBAPrLy39ACxFTA7dOQnwnJJBPN1MVV4ZnUAjE49iCnZfzyr5mZWRZ -nNGJLSekwOqBbBa1dfh8n5cnv2aBNXv1m1NFMsPsnwmcm7ugrr3UPiIT81ZnJgR0 -5SzBfHTQRcegzaWq2Je79BYsa4SB6mAwPkjmOlnn03aMQICsbeFRYy59AoGBAM34 -7qCxZkz6vGxx7L6jtxP6q96Jd0S22eZqB3cccai9EfPgpywAzbYcoXfhz07RtGEU -JBf1975tKHtwxzE1YTFKtvDjkRtikI/sw8TpDVfy9fDts7RF4nNmlhQJwAXTtWAk -3Ui5u25WFi2don4XvcIexmaQviz/sguvtx3vOYA9AoGBANnQIR6VKoeTR4jt6QQW -osTKZ8w6ntdV5saW6SNi3SfZTd3q5GgxA+dfcd4aUonYeV2Hn7t90MTgenS2BxNv -jcTWNm6+lKkuYHql5N1s9cF2/kGuN/Bq7ZbfPA3fzJrB55jYNmAhlq2jSoW8pyd+ -/rklaswmcRtmV6bpGk0z+CWpAoGBAI7aAD6I6uem2rnnxYduqlH7/+mGs6Z/nt60 -WNseaiHah7H59FeLcyDD+KTZgtsqjAzsWCAaIqn6sSHz1OLnH7J9HCYz3nb8xEBd -uGVAMVX3FuXzJjh4Y5cf5iSdooUoENpOlv6SelEK+bTHaGRFeQFCMN3/szYoXMbI -JptnSB0NAoGAT9vfD/GhokSCvNO99XQOSR/r9wkrv/AzXjFUkAyEFv00+DXoNoEw -eT62HaUjdnBoIiwCc7whrYzk94BOoHGkMF5qUCNWEN9G3kOT4VWhsEYqXepdFy5Y -x/Jt5UIXABtFMKS4ZE8VbjFTodpmdqUcdCP6Zb+ASDJCzbGlqHZofJA= ------END RSA PRIVATE KEY----- diff --git a/test/data/verificationcerts/trusted-leaf.crt b/test/data/verificationcerts/trusted-leaf.crt index 10432db8..6a92de92 100644 --- a/test/data/verificationcerts/trusted-leaf.crt +++ b/test/data/verificationcerts/trusted-leaf.crt @@ -1,18 +1,18 @@ -----BEGIN CERTIFICATE----- -MIIC1TCCAb0CCQCzDwzVB+KILjANBgkqhkiG9w0BAQsFADBFMQswCQYDVQQGEwJB +MIIC4TCCAckCCQCj6D9oVylb8jANBgkqhkiG9w0BAQsFADBFMQswCQYDVQQGEwJB VTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0 -cyBQdHkgTHRkMB4XDTE1MTAxNjE1MDMwMFoXDTE4MDgwNTE1MDMwMFowFDESMBAG -A1UEAwwJMTI3LjAuMC4xMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA -uDZ6dL1IxfAY8uxpzUScwTg1jHpktLk1Uwj7m+dsDGHCvEhmbLhllYdCiqpcEjxm -2IGtzzKU5iAtN46WQsWApw5H6n+Ozw3k3Wf7KqwcSxn0pemGglj7lRg8PNTBWHLe -aw3qQcZUKgmAwVly0ILYNmKbTBiRh1IpCjI1lqHM9gLY9GqrQ5N6D4iZPX3Snxq3 -IKVcpAvGShxjrRXyXwrVdk5vHdMRiNiMOLE+drpoK9ShmIz8OCCA2D+PvClaYGjz -2GbvNzrHMcSiJzggeT8aRNv8HT5JEj8NOgt9NM2yzQSZsOCGc9r6scKV0A0n8PQH -KiwBMulH9ums73MfM/NstQIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQA1pWmPnijh -szj3hf7hBrk27jEO8VO87Y9rVlPlII2JAdSGU/5AbrgsiWIMKeLJM5eqftWjuELG -ZGAJHNk/J74x16I/YlyZZZ1pzklGmVp3VYbHeabRCP77a+qLzBhhirdqPaZuFK3U -3GTm/fsyAypHxDM5xsDJVqLolLgrasFgUxEoNuI3LRbMKhcGURAOKiJJpJIwBqGo -xiZVdC5ZAOK70jU+8jNpNFrgo7gN1tinuQYFoZZ5fGIQObo5rgbqkF7U/fCknkLj -N7ykvCkMqeax3gj7htkpfXYTvG0zRiX59D11hhRGoTs3XZS52+jFHAJau6netga2 -fT7jVsojtgw8 +cyBQdHkgTHRkMB4XDTE1MTEwMTE2NDgwMloXDTE4MDgyMTE2NDgwMlowIDEeMBwG +A1UEAwwVZXhhbXBsZS5taXRtcHJveHkub3JnMIIBIjANBgkqhkiG9w0BAQEFAAOC +AQ8AMIIBCgKCAQEAy/L5JYHS7QFhSIsjmd6bJTgs2rdqEn6tsmPBVZKZ7SqCAVjW +hPpEu7Q23akmU6Zm9Fp/vENc3jzxQLlEKhrv7eWmFYSOrCYtbJOz3RQorlwjjfdY +LlNQh1wYUXQX3PN3r3dyYtt5vTtXKc8+aP4M4vX7qlbW+4j4LrQfmPjS0XOdYpu3 +wh+i1ZMIhZye3hpCjwnpjTf7/ff45ZFxtkoi1uzEC/+swr1RSvamY8Foe12Re17Z +5ij8ZB0NIdoSk1tDkY3sJ8iNi35+qartl0UYeG9IUXRwDRrPsEKpF4RxY1+X2bdZ +r6PKb/E4CA5JlMvS5SVmrvxjCVqTQBmTjXfxqwIDAQABMA0GCSqGSIb3DQEBCwUA +A4IBAQBmpSZJrTDvzSlo6P7P7x1LoETzHyVjwgPeqGYw6ndGXeJMN9rhhsFvRsiB +I/aHh58MIlSjti7paikDAoFHB3dBvFHR+JUa/ailWEbcZReWRSE3lV6wFiN3G3lU +OyofR7MKnPW7bv8hSqOLqP1mbupXuQFB5M6vPLRwg5VgiCHI/XBiTvzMamzvNAR3 +UHHZtsJkRqzogYm6K9YJaga7jteSx2nNo+ujLwrxeXsLChTyFMJGnVkp5IyKeNfc +qwlzNncb3y+4KnUdNkPEtuydgAxAfuyXufiFBYRcUWbQ5/9ycgF7131ySaj9f/Y2 +kMsv2jg+soKvwwVYCABsk1KSHtfz -----END CERTIFICATE----- diff --git a/test/data/verificationcerts/trusted-leaf.key b/test/data/verificationcerts/trusted-leaf.key index a6aba170..783ebf1c 100644 --- a/test/data/verificationcerts/trusted-leaf.key +++ b/test/data/verificationcerts/trusted-leaf.key @@ -1,27 +1,27 @@ -----BEGIN RSA PRIVATE KEY----- -MIIEpQIBAAKCAQEAuDZ6dL1IxfAY8uxpzUScwTg1jHpktLk1Uwj7m+dsDGHCvEhm -bLhllYdCiqpcEjxm2IGtzzKU5iAtN46WQsWApw5H6n+Ozw3k3Wf7KqwcSxn0pemG -glj7lRg8PNTBWHLeaw3qQcZUKgmAwVly0ILYNmKbTBiRh1IpCjI1lqHM9gLY9Gqr -Q5N6D4iZPX3Snxq3IKVcpAvGShxjrRXyXwrVdk5vHdMRiNiMOLE+drpoK9ShmIz8 -OCCA2D+PvClaYGjz2GbvNzrHMcSiJzggeT8aRNv8HT5JEj8NOgt9NM2yzQSZsOCG -c9r6scKV0A0n8PQHKiwBMulH9ums73MfM/NstQIDAQABAoIBAE6hQmfuG9ARijS2 -4PpzXQ3Etma+H5pcq/xDi1Ki16X5XKwNo1qo4wOOdsLFsQM/sQ6dW9ljV9dayLI8 -NLtPnniwSdY4mHadEaHILpeqW3FbJOhk47tjzA96BsxYbCca8QF1MRbeVzKSV9kw -GygRkcS1FmDG4+eFFGt7vxALBHfFSdI2qkqrRRv+PHtBcV5/06CR/7CFF2GC5ZSs -X1DaceoEW9Qjl+0+EP/XoecOA6W/3zgjqv9hbPwJZuMDYYQ574BtWAi0LC5Q5G5s -L+Zbl1pMeUIcZ2nHThMSo6xM/2SU+5KNcGwYa2jfx5Q4TzP0BB7b8M2IB+bYIim/ -D7fGsUkCgYEA4r3XF5FI5eSXswrkqQWMjkjI0IJIuhAld8tRILT2c2yADTaQ4dB9 -v/SfdoZBmXxatOKrC4KJwSkdoy3YOIz1knVd27amYcTQnwpux1XQWl9Cxxc/C85h -hy5LWxsBaxy4JzFn98N9PUdX9jeXgm6yBzGcBdNU1sSXOYul3T5hzvcCgYEAz/u/ -QslIlffYiLGZ+vu9CBiVVERrix0Uj9K+I3wM7T6WtP471bMjcsQ4IKVzKeiHu/0S -bpqktdxIIbQEnziwIS15Vz165HXR9lfqurk1Vi5x4O94MB9A2pB5qTPo6IS3aClB -gyA5gUw5dUxI4iSu5nxOBaVg9jsgFzyfbr0CerMCgYEAoNnD6PgsGsqbw2wK4s0I -9Tc1HpYOOdCSg/U8TFOUMjXacYUwKsHZM3+6UD7V8qiBQKk8ZiHoz5r3Z3dyWEvH -OmsAdomQZvNUfD7Ob6K0+Cd0HAClvR5fmaKB2tPBodbx3PvzoZSRGBOwlv7BAMq+ -iNPst0VAfktgbHZg6B8FC+kCgYEAzjgKQxk7HF+b1qVqTL5QhveBEQWqMExMN/K4 -TozQcGfPnHQ8Nb6iVkgScuQ5lQMXmqDqJrq0uBFLgAdzUcAueycQmhy+fkoIPh6c -AjpjlSkGBwbJ/8TtVAlOaCOtOudkxyWo7HAGNJq0mgZieb/vn17/KX/57Qtg3Ulh -t7Y3ABsCgYEA0ehcAnI14nW5pPfmXiVg4MSaSRNHtVKsDrQ+g5SUGLuXFqyYiNxC -/yzRhtknLe1rwHjp4+bNpFj+OzAad8MXh5FClIsa7w3cc2S/9ixO6vA8BMshHNBL -GOMxdqBTzSKNf0kE/M7YYznCi4kodxy8wWwsAbQYswKCU0jCn/GBCOw= +MIIEpAIBAAKCAQEAy/L5JYHS7QFhSIsjmd6bJTgs2rdqEn6tsmPBVZKZ7SqCAVjW +hPpEu7Q23akmU6Zm9Fp/vENc3jzxQLlEKhrv7eWmFYSOrCYtbJOz3RQorlwjjfdY +LlNQh1wYUXQX3PN3r3dyYtt5vTtXKc8+aP4M4vX7qlbW+4j4LrQfmPjS0XOdYpu3 +wh+i1ZMIhZye3hpCjwnpjTf7/ff45ZFxtkoi1uzEC/+swr1RSvamY8Foe12Re17Z +5ij8ZB0NIdoSk1tDkY3sJ8iNi35+qartl0UYeG9IUXRwDRrPsEKpF4RxY1+X2bdZ +r6PKb/E4CA5JlMvS5SVmrvxjCVqTQBmTjXfxqwIDAQABAoIBAQC956DWq+wbhA1x +3x1nSUBth8E8Z0z9q7dRRFHhvIBXth0X5ADcEa2umj/8ZmSpv2heX2ZRhugSh+yc +t+YgzrRacFwV7ThsU6A4WdBBK2Q19tWke4xAlpOFdtut/Mu7kXkAidiY9ISHD5o5 +9B/I48ZcD3AnTHUiAogV9OL3LbogDD4HasLt4mWkbq8U2thdjxMIvxdg36olJEuo +iAZrAUCPZEXuU89BtvPLUYioe9n90nzkyneGNS0SHxotlEc9ZYK9VTsivtXJb4wB +ptDMCp+TH3tjo8BTGnbnoZEybgyyOEd0UTzxK4DlxnvRVWexFY6NXwPFhIxKlB0Y +Bg8NkAkBAoGBAOiRnmbC5QkqrKrTkLx3fghIHPqgEXPPYgHLSuY3UjTlMb3APXpq +vzQnlCn3QuSse/1fWnQj+9vLVbx1XNgKjzk7dQhn5IUY+mGN4lLmoSnTebxvSQ43 +VAgTYjST9JFmJ3wK4KkWDsEsVao8LAx0h5JEQXUTT5xZpFA2MLztYbgfAoGBAOB/ +MvhLMAwlx8+m/zXMEPLk/KOd2dVZ4q5se8bAT/GiGsi8JUcPnCk140ZZabJqryAp +JFzUHIjfVsS9ejAfocDk1JeIm7Uus4um6fQEKIPMBxI/M/UAwYCXAG9ULXqilbO3 +pTdeeuraVKrTu1Z4ea6x4du1JWKcyDfYfsHepcT1AoGBAM2fskV5G7e3G2MOG3IG +1E/OMpEE5WlXenfLnjVdxDkwS4JRbgnGR7d9JurTyzkTp6ylmfwFtLDoXq15ttTs +wSUBBMCh2tIy+201XV2eu++XIpMQca84C/v352RFTH8hqtdpZqkY74KsCDGzcd6x +SQxxfM5efIzoVPb2crEX0MZRAoGAQ2EqFSfL9flo7UQ8GRN0itJ7mUgJV2WxCZT5 +2X9i/y0eSN1feuKOhjfsTPMNLEWk5kwy48GuBs6xpj8Qa10zGUgVHp4bzdeEgAfK +9DhDSLt1694YZBKkAUpRERj8xXAC6nvWFLZAwjhhbRw7gAqMywgMt/q4i85usYRD +F0ESE/kCgYBbc083PcLmlHbkn/d1i4IcLI6wFk+tZYIEVYDid7xDOgZOBcOTTyYB +BrDzNqbKNexKRt7QHVlwR+VOGMdN5P0hf7oH3SMW23OxBKoQe8pUSGF9a4DjCS1v +vCXMekifb9kIhhUWaG71L8+MaOzNBVAmk1+3NzPZgV/YxHjAWWhGHQ== -----END RSA PRIVATE KEY----- diff --git a/test/data/verificationcerts/trusted-root.crt b/test/data/verificationcerts/trusted-root.crt index 7d91e288..b22e4d20 100644 --- a/test/data/verificationcerts/trusted-root.crt +++ b/test/data/verificationcerts/trusted-root.crt @@ -1,21 +1,21 @@ -----BEGIN CERTIFICATE----- -MIIDXTCCAkWgAwIBAgIJAPJ/OeIFZUrJMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV +MIIDXTCCAkWgAwIBAgIJAPAfPQGCV/Z4MA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX -aWRnaXRzIFB0eSBMdGQwHhcNMTUxMDE2MTUwMjU4WhcNMTgwODA1MTUwMjU4WjBF +aWRnaXRzIFB0eSBMdGQwHhcNMTUxMTAxMTY0ODAxWhcNMTgwODIxMTY0ODAxWjBF MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB -CgKCAQEAs8EFHXjYTdmPf9J37wUuLcx5hi9HFmXfFEbJ0tSm/U8lajfsnr869LiO -2App2JHgntreemHe/OZaaa/fPykDnDDiQBVb74H55YGHYCGphIPeyT78KEvInPOs -m/CaYFxlXB/ao81SXeGKkKagcFq/D4FjFYjmjxDxzUJVxX67knjr5WwPK60NfJSq -JzRIvFFXUtkByRv2VZmEAj56KRQx1W0+Ant51j52ryuD7pvCZ6P5TU4CdGlu34bu -1DJ/7uRBCIGYffZs7vE2wMhCvbwQAPl0q+Kq9yZdPXY+sgoGgmkydB/INuXSv/Ce -IgpBW+EjjeYD32YbnOTQ0Fi5yvxEjwIDAQABo1AwTjAdBgNVHQ4EFgQU8X+ohuC4 -QOemuutP/xX6ZCddKqowHwYDVR0jBBgwFoAU8X+ohuC4QOemuutP/xX6ZCddKqow -DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAnD4fPo/ztU6g77BSf88o -TtsQ1x2Cu3I7DNFEDBpBubooQpQZpwuLspMSgQfTnlPT4V7iBE/+3x7gJm8BcWEi -QxjJhoiWVWDpDe0GdcgNvScPq+3kupzxEJrTGSY/SJjiftlTvI1oGRmto9VXhNlU -6TeFEwieDWfm2waqJCYlGI86go47piqjh3E8ODPAT1SBRLvrfU6b3nvSPl3r1JvF -iurGxMPUk3DHb/Y19MdkFiaUqu/P+c/rO6BDxhPfuJxhmw4OdMuPA7cY0H3bbXHE -yoXqEvQ43ItEiYXVRoc9CCT1l9+ExC8cUsOTUqFi5Fwyr7K3ZDpAOqCpzaLZnss7 -mw== +CgKCAQEArp8LD34JhKCwcQbwIYQMg4+eCgLVN8fwB7+/qOfJbArPs0djFBN+F7c6 +HGvMr24BKUk5u8pn4dPtNurm/vPC8ovNGmcXz62BQJpcMX2veVdRsF7yNwhNacNJ +Arq+70zNMwYBznx0XUxMF6j6nVFf3AW6SU04ylT4Mp3SY/BUUDAdfl1eRo0mPLNS +8rpsN+8YBw1Q7SCuBRVqpOgVIsL88svgQUSOlzvMZPBpG/cmB3BNKNrltwb5iFEI +1jAV7uSj5IcIuNO/246kfsDVPTFMJIzav/CUoidd5UNw+SoFDlzh8sA7L1Bm7D1/ +3KHYSKswGsSR3kynAl10w/SJKDtn8wIDAQABo1AwTjAdBgNVHQ4EFgQUgOcrtxBX +LxbpnOT65d+vpfyWUkgwHwYDVR0jBBgwFoAUgOcrtxBXLxbpnOT65d+vpfyWUkgw +DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAEE9bFmUCA+6cvESKPoi2 +TGSpV652d0xd2U66LpEXeiWRJFLz8YGgoJCx3QFGBscJDXxrLxrBBBV/tCpEqypo +pYIqsawH7M66jpOr83Us3M8JC2eFBZJocMpXxdytWqHik5VKZNx6VQFT8bS7+yVC +VoUKePhlgcg+pmo41qjqieBNKRMh/1tXS77DI1lgO5wZLVrLXcdqWuDpmaQOKJeq +G/nxytCW/YJA7bFn/8Gjy8DYypJSeeaKu7o3P3+ONJHdIMHb+MdcheDBS9AOFSeo +xI0D5EbO9F873O77l7nbD7B0X34HFN0nGczC4poexIpbDFG3hAPekwZ5KC6VwJLc +1Q== -----END CERTIFICATE----- diff --git a/test/data/verificationcerts/trusted-root.key b/test/data/verificationcerts/trusted-root.key index 298e8bd9..05483f77 100644 --- a/test/data/verificationcerts/trusted-root.key +++ b/test/data/verificationcerts/trusted-root.key @@ -1,27 +1,27 @@ -----BEGIN RSA PRIVATE KEY----- -MIIEpQIBAAKCAQEAs8EFHXjYTdmPf9J37wUuLcx5hi9HFmXfFEbJ0tSm/U8lajfs -nr869LiO2App2JHgntreemHe/OZaaa/fPykDnDDiQBVb74H55YGHYCGphIPeyT78 -KEvInPOsm/CaYFxlXB/ao81SXeGKkKagcFq/D4FjFYjmjxDxzUJVxX67knjr5WwP -K60NfJSqJzRIvFFXUtkByRv2VZmEAj56KRQx1W0+Ant51j52ryuD7pvCZ6P5TU4C -dGlu34bu1DJ/7uRBCIGYffZs7vE2wMhCvbwQAPl0q+Kq9yZdPXY+sgoGgmkydB/I -NuXSv/CeIgpBW+EjjeYD32YbnOTQ0Fi5yvxEjwIDAQABAoIBAGwOu6l07OsXv0kC -1/BK9/C2O9OrdhuDz+/bghYImf4q5v4McmUX5jQZAl3jHLABObulLRr63NbBD1b7 -T8QjPrVVOZ12eZboVrZeAGiMs+AiefoWr/T1YbrgTUJNCDCnOpN/3qqbkkk3fVnp -oQcJtlN2336hlqAoeoN+vhsETXQF1L598oOkb54O9fuZWKR/WrVgd1492oktHA0L -19RmoZ6J7a3ojZh4IN838jlc3istqywbuB77dHSXYA9ZUjg+ejkZlUG1mGiv1OQg -HUbRIqW+OOieMVvXUTGjAQDqXw/oz2d684rwXr3x7G9MKtBciVjUrNk1+JdM9M+9 -531xmNkCgYEA4+nHbNyEju33X2DB9W7xQYyrVvRu/+3+eCjCuOP+tppXs3WvE7kP -TH2kjRpptPUm0FlpmYS0Uj1ty3WApNLEP9rdPyAtRBSzwPZOZ7Wg+GiNozarJwro -FtUfJVYkO/vEOntMBOJQbih2eZudxSmKi1tC5eKKZXQJ6y56+nznrUsCgYEAyefp -Kv1qCZ3cgqxoOdt1rSpPjkmB6JOtVii/BcoKx3NiqFTI11qPKdV1dllqgNnDEber -fH9FA0POtNJrvulbw2YBq6DqySYVKZvxsDQ+Z9Ho0K3dicv4UwU5wbIQiOoNDQYC -Xb+hqBp6ZMTaK4BnBfPQld6IkKN8yU9Fw9uZT00CgYEAvd/64+fHi+gW6eALVvUJ -i2mtKTFU9GULVoHmz/AqOWjWXc1SgaTwaPJXz7JMlJSUtIl5H4veSpGg0htfhHGP -S/+DyV5+N7TjmIPbCC3aIHnCXlJiPpGoj7UYUJu2bj6u2WX1DDCbf1q4cVHDHAoi -wTzTu/+C+0i0Jrm/fMXooYcCgYEAj1igWY47d4JlaT0AbntaK8xLWUjk+3vFZ9Nb -879DMeHA3KP9R7AazmenkpPfIoX4kZ6mGKi/FZdRrV1rc8p4BN1qODDyIEdyZO07 -hY9B8zG7qlSWYdu3fTHLlLJYPOx2wZVPnsGMAy5xURPVlWb/PeGhaJXqvU3lLYOj -k29YhE0CgYEAnsoVVhuZZ2SIyPOGE3E2O/y0475lGjp83cUoEaRTRBS1DJS7LVdN -QD1PCq0owNFKUZzDcbOD1x+an4X6gxTKd0GVDjNkmVCVwkKPQZRPFBNzVRN8do5e -WFJLqY/3shJvSRmvut+SLnN5U5iYzHP1MOIatJSGfK/DXshOLlrQPqs= +MIIEowIBAAKCAQEArp8LD34JhKCwcQbwIYQMg4+eCgLVN8fwB7+/qOfJbArPs0dj +FBN+F7c6HGvMr24BKUk5u8pn4dPtNurm/vPC8ovNGmcXz62BQJpcMX2veVdRsF7y +NwhNacNJArq+70zNMwYBznx0XUxMF6j6nVFf3AW6SU04ylT4Mp3SY/BUUDAdfl1e +Ro0mPLNS8rpsN+8YBw1Q7SCuBRVqpOgVIsL88svgQUSOlzvMZPBpG/cmB3BNKNrl +twb5iFEI1jAV7uSj5IcIuNO/246kfsDVPTFMJIzav/CUoidd5UNw+SoFDlzh8sA7 +L1Bm7D1/3KHYSKswGsSR3kynAl10w/SJKDtn8wIDAQABAoIBAFgMzjDzpqz/sbhs +fS0JPp4gDtqRbx3/bSMbJvNuXPxjvzNxLZ5z7cLbmyu1l7Jlz6QXzkrI1vTiPdzR +OcUY+RYANF252iHYJTKEIzS5YX/X7dL3LT9eqlpIJEqCC8Dygw3VW5fY3Xwl+sB7 +blNhMuro4HQRwi8UBUrQlcPa7Ui5BBi323Q6en+VjYctkqpJHzNKPSqPTbsdLaK+ +B0XuXxFatM09rmeRKZCL71Lk1T8N/l0hqEzej7zxgVD7vG/x1kMFN4T3yCmXCbPa +izGHYr1EBHglm4qMNWveXCZiVJ+wmwCjdjqvggyHiZFXE2N0OCrWPhxQPdqFf5y7 +bUO9U2ECgYEA6GM1UzRnbVpjb20ezFy7dU7rlWM0nHBfG27M3bcXh4HnPpnvKp0/ +8a1WFi4kkRywrNXx8hFEd43vTbdObLpVXScXRKiY3MHmFk4k4hbWuTpmumCubQZO +AWlX6TE0HRKn1wQahgpQcxcWaDN2xJJmRQ1zVmlnNkT48/4kFgRxyykCgYEAwF08 +ngrF35oYoU/x+KKq2NXGeNUzoZMj568dE1oWW0ZFpqCi+DGT+hAbG3yUOBSaPqy9 +zn1obGo0YRlrayvtebz118kG7a/rzY02VcAPlT/GpEhvkZlXTwEK17zRJc1nJrfP +39QAZWZsaOru9NRIg/8HcdG3JPR2MhRD/De9GbsCgYAaiZnBUq6s8jGAu/lUZRKT +JtwIRzfu1XZG77Q9bXcmZlM99t41A5gVxTGbftF2MMyMMDJc7lPfQzocqd4u1GiD +Jr+le4tZSls4GNxlZS5IIL8ycW/5y0qFJr5/RrsoxsSb7UAKJothWTWZ2Karc/xx +zkNpjsfWjrHPSypbyU4lYQKBgFh1R5/BgnatjO/5LGNSok/uFkOQfxqo6BTtYOh6 +P9efO/5A1lBdtBeE+oIsSphzWO7DTtE6uB9Kw2V3Y/83hw+5RjABoG8Cu+OdMURD +eqb+WeFH8g45Pn31E8Bbcq34g5u5YR0jhz8Z13ZzuojZabNRPmIntxmGVSf4S78a +/plrAoGBANMHNng2lyr03nqnHrOM6NXD+60af0YR/YJ+2d/H40RnXxGJ4DXn7F00 +a4vJFPa97uq+xpd0HE+TE+NIrOdVDXPePD2qzBzMTsctGtj30vLzojMOT+Yf/nvO +WxTL5Q8GruJz2Dn0awSZO2z/3A8S1rmpuVZ/jT5NtRrvOSY6hmxF -----END RSA PRIVATE KEY----- diff --git a/test/data/verificationcerts/trusted-root.srl b/test/data/verificationcerts/trusted-root.srl index 22f0855d..4ad962ba 100644 --- a/test/data/verificationcerts/trusted-root.srl +++ b/test/data/verificationcerts/trusted-root.srl @@ -1 +1 @@ -B30F0CD507E2882F +A3E83F6857295BF2 -- cgit v1.2.3 From 5af9df326aef1cf72be7fd5390df239fb6b906c7 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 1 Nov 2015 18:15:30 +0100 Subject: fix certificate verification This commit fixes netlib's optional (turned off by default) certificate verification, which previously did not validate the cert's host name. As it turns out, verifying the connection's host name on an intercepting proxy is not really straightforward - if we receive a connection in transparent mode without SNI, we have no clue which hosts the client intends to connect to. There are two basic approaches to solve this problem: 1. Exactly mirror the host names presented by the server in the spoofed certificate presented to the client. 2. Require the client to send the TLS Server Name Indication extension. While this does not work with older clients, we can validate the hostname on the proxy. Approach 1 is problematic in mitmproxy's use case, as we may want to deliberately divert connections without the client's knowledge. As a consequence, we opt for approach 2. While mitmproxy does now require a SNI value to be sent by the client if certificate verification is turned on, we retain our ability to present certificates to the client which are accepted with a maximum likelihood. --- netlib/certutils.py | 5 +++++ netlib/tcp.py | 37 ++++++++++++++++++++++++------ setup.py | 1 + test/test_tcp.py | 65 ++++++++++++++++++++++++++++++++--------------------- 4 files changed, 75 insertions(+), 33 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index b3ddcbe4..93366a99 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -437,6 +437,11 @@ class SSLCert(object): @property def altnames(self): + """ + Returns: + All DNS altnames. + """ + # tcp.TCPClient.convert_to_ssl assumes that this property only contains DNS altnames for hostname verification. altnames = [] for i in range(self.x509.get_extension_count()): ext = self.x509.get_extension(i) diff --git a/netlib/tcp.py b/netlib/tcp.py index b751d71f..33776fc4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -11,6 +11,7 @@ import binascii from six.moves import range import certifi +from backports import ssl_match_hostname import six import OpenSSL from OpenSSL import SSL @@ -597,9 +598,14 @@ class TCPClient(_Connection): ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool ca_pemfile: Path to a PEM formatted trusted CA certificate """ + verification_mode = sslctx_kwargs.get('verify_options', None) + if verification_mode == SSL.VERIFY_PEER and not sni: + raise TlsException("Cannot validate certificate hostname without SNI") + context = self.create_ssl_context( alpn_protos=alpn_protos, - **sslctx_kwargs) + **sslctx_kwargs + ) self.connection = SSL.Connection(context, self.connection) if sni: self.sni = sni @@ -612,15 +618,32 @@ class TCPClient(_Connection): raise InvalidCertificateException("SSL handshake error: %s" % repr(v)) else: raise TlsException("SSL handshake error: %s" % repr(v)) + else: + # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on + # certificate validation failure + if verification_mode == SSL.VERIFY_PEER and self.ssl_verification_error is not None: + raise InvalidCertificateException("SSL handshake error: certificate verify failed") - # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on - # certificate validation failure - verification_mode = sslctx_kwargs.get('verify_options', None) - if self.ssl_verification_error is not None and verification_mode == SSL.VERIFY_PEER: - raise InvalidCertificateException("SSL handshake error: certificate verify failed") + self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) + + # Validate TLS Hostname + try: + crt = dict( + subjectAltName=[("DNS", x.decode("ascii", "strict")) for x in self.cert.altnames] + ) + if self.cert.cn: + crt["subject"] = [[["commonName", self.cert.cn.decode("ascii", "strict")]]] + if sni: + hostname = sni.decode("ascii", "strict") + else: + hostname = "no-hostname" + ssl_match_hostname.match_hostname(crt, hostname) + except (ValueError, ssl_match_hostname.CertificateError) as e: + self.ssl_verification_error = dict(depth=0, errno="Invalid Hostname") + if verification_mode == SSL.VERIFY_PEER: + raise InvalidCertificateException("Presented certificate for {} is not valid: {}".format(sni, str(e))) self.ssl_established = True - self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) diff --git a/setup.py b/setup.py index 30c80f5b..729910f8 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ deps = { "hpack>=1.0.1", "six>=1.9.0", "certifi>=2015.9.6.2", + "backports.ssl_match_hostname>=3.4.0.2", } if sys.version_info < (3, 0): deps.add("ipaddress>=1.0.14") diff --git a/test/test_tcp.py b/test/test_tcp.py index c87bebb3..68d54b78 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -189,8 +189,8 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): handler = EchoHandler ssl = dict( - cert=tutils.test_data.path("data/verificationcerts/untrusted.crt"), - key=tutils.test_data.path("data/verificationcerts/verification-server.key") + cert=tutils.test_data.path("data/verificationcerts/self-signed.crt"), + key=tutils.test_data.path("data/verificationcerts/self-signed.key") ) def test_mode_default_should_pass(self): @@ -226,58 +226,69 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - tutils.raises( - InvalidCertificateException, - c.convert_to_ssl, - verify_options=SSL.VERIFY_PEER, - ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted.pem")) + with tutils.raises(InvalidCertificateException): + c.convert_to_ssl( + sni=b"example.mitmproxy.org", + verify_options=SSL.VERIFY_PEER, + ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") + ) assert c.ssl_verification_error is not None # Unknown issuing certificate authority for first certificate - assert c.ssl_verification_error['errno'] == 20 + assert c.ssl_verification_error['errno'] == 18 assert c.ssl_verification_error['depth'] == 0 -class TestSSLUpstreamCertVerificationWBadCertChain(tservers.ServerTestBase): +class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase): handler = EchoHandler ssl = dict( - cert=tutils.test_data.path("data/verificationcerts/untrusted-chain.crt"), - key=tutils.test_data.path("data/verificationcerts/verification-server.key")) + cert=tutils.test_data.path("data/verificationcerts/trusted-leaf.crt"), + key=tutils.test_data.path("data/verificationcerts/trusted-leaf.key") + ) - def test_mode_strict_should_fail(self): + def test_should_fail_without_sni(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - tutils.raises( - "certificate verify failed", - c.convert_to_ssl, - verify_options=SSL.VERIFY_PEER, - ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted.pem")) + with tutils.raises(TlsException): + c.convert_to_ssl( + verify_options=SSL.VERIFY_PEER, + ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") + ) - assert c.ssl_verification_error is not None + def test_should_fail(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + + with tutils.raises(InvalidCertificateException): + c.convert_to_ssl( + sni=b"mitmproxy.org", + verify_options=SSL.VERIFY_PEER, + ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") + ) - # Untrusted self-signed certificate at second position in certificate - # chain - assert c.ssl_verification_error['errno'] == 19 - assert c.ssl_verification_error['depth'] == 1 + assert c.ssl_verification_error is not None class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase): handler = EchoHandler ssl = dict( - cert=tutils.test_data.path("data/verificationcerts/trusted-chain.crt"), - key=tutils.test_data.path("data/verificationcerts/verification-server.key")) + cert=tutils.test_data.path("data/verificationcerts/trusted-leaf.crt"), + key=tutils.test_data.path("data/verificationcerts/trusted-leaf.key") + ) def test_mode_strict_w_pemfile_should_pass(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl( + sni=b"example.mitmproxy.org", verify_options=SSL.VERIFY_PEER, - ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted.pem")) + ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") + ) assert c.ssl_verification_error is None @@ -291,8 +302,10 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase): c.connect() c.convert_to_ssl( + sni=b"example.mitmproxy.org", verify_options=SSL.VERIFY_PEER, - ca_path=tutils.test_data.path("data/verificationcerts/")) + ca_path=tutils.test_data.path("data/verificationcerts/") + ) assert c.ssl_verification_error is None -- cgit v1.2.3 From 9d36f8e43fc7a3b3c4bf10a8c1b9819da8999dad Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 1 Nov 2015 18:20:00 +0100 Subject: minor fixes --- netlib/http/request.py | 2 ++ netlib/tcp.py | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/netlib/http/request.py b/netlib/http/request.py index 92d99532..5ebf21a5 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -92,6 +92,8 @@ class Request(Message): Target host. This may be parsed from the raw request (e.g. from a ``GET http://example.com/ HTTP/1.1`` request line) or inferred from the proxy mode (e.g. an IP in transparent mode). + + Setting the host attribute also updates the host header, if present. """ if six.PY2: # pragma: nocover diff --git a/netlib/tcp.py b/netlib/tcp.py index b751d71f..ef5ab4b6 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -559,8 +559,6 @@ class TCPClient(_Connection): @address.setter def address(self, address): - if self.connection: - raise RuntimeError("Cannot change server address after establishing connection") if address: self.__address = Address.wrap(address) else: -- cgit v1.2.3 From 9d12425d5ee942ee3d954a9324c31b74f466d520 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 4 Nov 2015 11:28:02 +0100 Subject: Set default cert expiry to <39 months This sould fix mitmproxy/mitmproxy#815 --- netlib/certutils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index b3ddcbe4..69530245 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -12,7 +12,8 @@ from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL -DEFAULT_EXP = 157680000 # = 24 * 60 * 60 * 365 * 5 +# Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815 +DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3 # Generated with "openssl dhparam". It's too slow to generate this on startup. DEFAULT_DHPARAM = b""" -----BEGIN DH PARAMETERS----- -- cgit v1.2.3 From 3e2eb3fef166822bfad0d2200dadffe541efbc38 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 6 Nov 2015 13:51:15 +1300 Subject: Bump version --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/version.py b/netlib/version.py index 044fde2c..d2c3c369 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 13, 2) +IVERSION = (0, 14, 0) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From 41e91dae0060cfe487a0ff372437b2bad013eea3 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 6 Nov 2015 14:08:38 +1300 Subject: Add CONTRIBUTORS --- CONTRIBUTORS | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 CONTRIBUTORS diff --git a/CONTRIBUTORS b/CONTRIBUTORS new file mode 100644 index 00000000..4b4240f8 --- /dev/null +++ b/CONTRIBUTORS @@ -0,0 +1,18 @@ + 250 Aldo Cortesi + 204 Maximilian Hils + 109 Thomas Kriechbaumer + 8 Chandler Abraham + 8 Kyle Morton + 2 Sean Coates + 2 Israel Nir + 2 Brad Peabody + 2 Pedro Worcel + 2 Matthias Urlichs + 1 kronick + 1 Bradley Baetz + 1 M. Utku Altinkaya + 1 Andrey Plotnikov + 1 Paul + 1 Pritam Baral + 1 Rouli + 1 Tim Becker -- cgit v1.2.3 From 9cab9ee5d6f39b658c1e9260950cc3575d3ad9db Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 7 Nov 2015 09:30:49 +1300 Subject: Bump version for next release cycle --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/version.py b/netlib/version.py index d2c3c369..e836dbe3 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 14, 0) +IVERSION = (0, 14, 1) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From 823718348598efb324298ca29ad4cb7d5097c084 Mon Sep 17 00:00:00 2001 From: Sam Cleveland Date: Wed, 11 Nov 2015 11:32:02 -0600 Subject: Porting netlib to python3.4 Updated utils.py using 2to3-3.4 Updated hexdump to use .format() with .encode() to support python 3.4 Python 3.5 supports .format() on bytes objects, but 3.4 is the current default on Ubuntu. samc$ py.test netlib/test/test_utils.py = test session starts = platform darwin -- Python 3.4.1, pytest-2.8.2, py-1.4.30, pluggy-0.3.1 rootdir: /Users/samc/src/python/netlib, inifile: collected 11 items netlib/test/test_utils.py ........... = 11 passed in 0.19 seconds = --- netlib/utils.py | 16 +-- netlib/utils.py.bak | 368 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 376 insertions(+), 8 deletions(-) create mode 100644 netlib/utils.py.bak diff --git a/netlib/utils.py b/netlib/utils.py index acc7ccd4..62f17012 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,4 +1,4 @@ -from __future__ import absolute_import, print_function, division + import os.path import re import string @@ -61,11 +61,11 @@ def clean_bin(s, keep_spacing=True): """ if isinstance(s, six.text_type): if keep_spacing: - keep = u" \n\r\t" + keep = " \n\r\t" else: - keep = u" " - return u"".join( - ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else u"." + keep = " " + return "".join( + ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else "." for ch in s ) else: @@ -85,9 +85,9 @@ def hexdump(s): A generator of (offset, hex, str) tuples """ for i in range(0, len(s), 16): - offset = b"%.10x" % i + offset = "{:0=10x}".format(i).encode() part = s[i:i + 16] - x = b" ".join(b"%.2x" % i for i in six.iterbytes(part)) + x = b" ".join("{:0=2x}".format(i).encode() for i in six.iterbytes(part)) x = x.ljust(47) # 16*2 + 15 yield (offset, x, clean_bin(part, False)) @@ -122,7 +122,7 @@ class BiDi(object): def __init__(self, **kwargs): self.names = kwargs self.values = {} - for k, v in kwargs.items(): + for k, v in list(kwargs.items()): self.values[v] = k if len(self.names) != len(self.values): raise ValueError("Duplicate values not allowed.") diff --git a/netlib/utils.py.bak b/netlib/utils.py.bak new file mode 100644 index 00000000..acc7ccd4 --- /dev/null +++ b/netlib/utils.py.bak @@ -0,0 +1,368 @@ +from __future__ import absolute_import, print_function, division +import os.path +import re +import string +import unicodedata + +import six + +from six.moves import urllib + + +def always_bytes(unicode_or_bytes, *encode_args): + if isinstance(unicode_or_bytes, six.text_type): + return unicode_or_bytes.encode(*encode_args) + return unicode_or_bytes + + +def always_byte_args(*encode_args): + """Decorator that transparently encodes all arguments passed as unicode""" + def decorator(fun): + def _fun(*args, **kwargs): + args = [always_bytes(arg, *encode_args) for arg in args] + kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} + return fun(*args, **kwargs) + return _fun + return decorator + + +def native(s, *encoding_opts): + """ + Convert :py:class:`bytes` or :py:class:`unicode` to the native + :py:class:`str` type, using latin1 encoding if conversion is necessary. + + https://www.python.org/dev/peps/pep-3333/#a-note-on-string-types + """ + if not isinstance(s, (six.binary_type, six.text_type)): + raise TypeError("%r is neither bytes nor unicode" % s) + if six.PY3: + if isinstance(s, six.binary_type): + return s.decode(*encoding_opts) + else: + if isinstance(s, six.text_type): + return s.encode(*encoding_opts) + return s + + +def isascii(bytes): + try: + bytes.decode("ascii") + except ValueError: + return False + return True + + +def clean_bin(s, keep_spacing=True): + """ + Cleans binary data to make it safe to display. + + Args: + keep_spacing: If False, tabs and newlines will also be replaced. + """ + if isinstance(s, six.text_type): + if keep_spacing: + keep = u" \n\r\t" + else: + keep = u" " + return u"".join( + ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else u"." + for ch in s + ) + else: + if keep_spacing: + keep = (9, 10, 13) # \t, \n, \r, + else: + keep = () + return b"".join( + six.int2byte(ch) if (31 < ch < 127 or ch in keep) else b"." + for ch in six.iterbytes(s) + ) + + +def hexdump(s): + """ + Returns: + A generator of (offset, hex, str) tuples + """ + for i in range(0, len(s), 16): + offset = b"%.10x" % i + part = s[i:i + 16] + x = b" ".join(b"%.2x" % i for i in six.iterbytes(part)) + x = x.ljust(47) # 16*2 + 15 + yield (offset, x, clean_bin(part, False)) + + +def setbit(byte, offset, value): + """ + Set a bit in a byte to 1 if value is truthy, 0 if not. + """ + if value: + return byte | (1 << offset) + else: + return byte & ~(1 << offset) + + +def getbit(byte, offset): + mask = 1 << offset + return bool(byte & mask) + + +class BiDi(object): + + """ + A wee utility class for keeping bi-directional mappings, like field + constants in protocols. Names are attributes on the object, dict-like + access maps values to names: + + CONST = BiDi(a=1, b=2) + assert CONST.a == 1 + assert CONST.get_name(1) == "a" + """ + + def __init__(self, **kwargs): + self.names = kwargs + self.values = {} + for k, v in kwargs.items(): + self.values[v] = k + if len(self.names) != len(self.values): + raise ValueError("Duplicate values not allowed.") + + def __getattr__(self, k): + if k in self.names: + return self.names[k] + raise AttributeError("No such attribute: %s", k) + + def get_name(self, n, default=None): + return self.values.get(n, default) + + +def pretty_size(size): + suffixes = [ + ("B", 2 ** 10), + ("kB", 2 ** 20), + ("MB", 2 ** 30), + ] + for suf, lim in suffixes: + if size >= lim: + continue + else: + x = round(size / float(lim / 2 ** 10), 2) + if x == int(x): + x = int(x) + return str(x) + suf + + +class Data(object): + + def __init__(self, name): + m = __import__(name) + dirname, _ = os.path.split(m.__file__) + self.dirname = os.path.abspath(dirname) + + def path(self, path): + """ + Returns a path to the package data housed at 'path' under this + module.Path can be a path to a file, or to a directory. + + This function will raise ValueError if the path does not exist. + """ + fullpath = os.path.join(self.dirname, '../test/', path) + if not os.path.exists(fullpath): + raise ValueError("dataPath: %s does not exist." % fullpath) + return fullpath + + +_label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(? 255: + return False + if host[-1] == b".": + host = host[:-1] + return all(_label_valid.match(x) for x in host.split(b".")) + + +def is_valid_port(port): + return 0 <= port <= 65535 + + +# PY2 workaround +def decode_parse_result(result, enc): + if hasattr(result, "decode"): + return result.decode(enc) + else: + return urllib.parse.ParseResult(*[x.decode(enc) for x in result]) + + +# PY2 workaround +def encode_parse_result(result, enc): + if hasattr(result, "encode"): + return result.encode(enc) + else: + return urllib.parse.ParseResult(*[x.encode(enc) for x in result]) + + +def parse_url(url): + """ + URL-parsing function that checks that + - port is an integer 0-65535 + - host is a valid IDNA-encoded hostname with no null-bytes + - path is valid ASCII + + Args: + A URL (as bytes or as unicode) + + Returns: + A (scheme, host, port, path) tuple + + Raises: + ValueError, if the URL is not properly formatted. + """ + parsed = urllib.parse.urlparse(url) + + if not parsed.hostname: + raise ValueError("No hostname given") + + if isinstance(url, six.binary_type): + host = parsed.hostname + + # this should not raise a ValueError, + # but we try to be very forgiving here and accept just everything. + # decode_parse_result(parsed, "ascii") + else: + host = parsed.hostname.encode("idna") + parsed = encode_parse_result(parsed, "ascii") + + port = parsed.port + if not port: + port = 443 if parsed.scheme == b"https" else 80 + + full_path = urllib.parse.urlunparse( + (b"", b"", parsed.path, parsed.params, parsed.query, parsed.fragment) + ) + if not full_path.startswith(b"/"): + full_path = b"/" + full_path + + if not is_valid_host(host): + raise ValueError("Invalid Host") + if not is_valid_port(port): + raise ValueError("Invalid Port") + + return parsed.scheme, host, port, full_path + + +def get_header_tokens(headers, key): + """ + Retrieve all tokens for a header key. A number of different headers + follow a pattern where each header line can containe comma-separated + tokens, and headers can be set multiple times. + """ + if key not in headers: + return [] + tokens = headers[key].split(",") + return [token.strip() for token in tokens] + + +def hostport(scheme, host, port): + """ + Returns the host component, with a port specifcation if needed. + """ + if (port, scheme) in [(80, "http"), (443, "https"), (80, b"http"), (443, b"https")]: + return host + else: + if isinstance(host, six.binary_type): + return b"%s:%d" % (host, port) + else: + return "%s:%d" % (host, port) + + +def unparse_url(scheme, host, port, path=""): + """ + Returns a URL string, constructed from the specified components. + + Args: + All args must be str. + """ + return "%s://%s%s" % (scheme, hostport(scheme, host, port), path) + + +def urlencode(s): + """ + Takes a list of (key, value) tuples and returns a urlencoded string. + """ + s = [tuple(i) for i in s] + return urllib.parse.urlencode(s, False) + + +def urldecode(s): + """ + Takes a urlencoded string and returns a list of (key, value) tuples. + """ + return urllib.parse.parse_qsl(s, keep_blank_values=True) + + +def parse_content_type(c): + """ + A simple parser for content-type values. Returns a (type, subtype, + parameters) tuple, where type and subtype are strings, and parameters + is a dict. If the string could not be parsed, return None. + + E.g. the following string: + + text/html; charset=UTF-8 + + Returns: + + ("text", "html", {"charset": "UTF-8"}) + """ + parts = c.split(";", 1) + ts = parts[0].split("/", 1) + if len(ts) != 2: + return None + d = {} + if len(parts) == 2: + for i in parts[1].split(";"): + clause = i.split("=", 1) + if len(clause) == 2: + d[clause[0].strip()] = clause[1].strip() + return ts[0].lower(), ts[1].lower(), d + + +def multipartdecode(headers, content): + """ + Takes a multipart boundary encoded string and returns list of (key, value) tuples. + """ + v = headers.get("content-type") + if v: + v = parse_content_type(v) + if not v: + return [] + try: + boundary = v[2]["boundary"].encode("ascii") + except (KeyError, UnicodeError): + return [] + + rx = re.compile(br'\bname="([^"]+)"') + r = [] + + for i in content.split(b"--" + boundary): + parts = i.splitlines() + if len(parts) > 1 and parts[0][0:2] != b"--": + match = rx.search(parts[1]) + if match: + key = match.group(1) + value = b"".join(parts[3 + parts[2:].index(b""):]) + r.append((key, value)) + return r + return [] -- cgit v1.2.3 From 2d48f12332ff380db3ab66c8f436f78a62b2cd91 Mon Sep 17 00:00:00 2001 From: Sam Cleveland Date: Wed, 11 Nov 2015 19:41:42 -0600 Subject: Revert "Porting netlib to python3.4" This reverts commit 823718348598efb324298ca29ad4cb7d5097c084. --- netlib/utils.py | 16 +-- netlib/utils.py.bak | 368 ---------------------------------------------------- 2 files changed, 8 insertions(+), 376 deletions(-) delete mode 100644 netlib/utils.py.bak diff --git a/netlib/utils.py b/netlib/utils.py index 62f17012..acc7ccd4 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,4 +1,4 @@ - +from __future__ import absolute_import, print_function, division import os.path import re import string @@ -61,11 +61,11 @@ def clean_bin(s, keep_spacing=True): """ if isinstance(s, six.text_type): if keep_spacing: - keep = " \n\r\t" + keep = u" \n\r\t" else: - keep = " " - return "".join( - ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else "." + keep = u" " + return u"".join( + ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else u"." for ch in s ) else: @@ -85,9 +85,9 @@ def hexdump(s): A generator of (offset, hex, str) tuples """ for i in range(0, len(s), 16): - offset = "{:0=10x}".format(i).encode() + offset = b"%.10x" % i part = s[i:i + 16] - x = b" ".join("{:0=2x}".format(i).encode() for i in six.iterbytes(part)) + x = b" ".join(b"%.2x" % i for i in six.iterbytes(part)) x = x.ljust(47) # 16*2 + 15 yield (offset, x, clean_bin(part, False)) @@ -122,7 +122,7 @@ class BiDi(object): def __init__(self, **kwargs): self.names = kwargs self.values = {} - for k, v in list(kwargs.items()): + for k, v in kwargs.items(): self.values[v] = k if len(self.names) != len(self.values): raise ValueError("Duplicate values not allowed.") diff --git a/netlib/utils.py.bak b/netlib/utils.py.bak deleted file mode 100644 index acc7ccd4..00000000 --- a/netlib/utils.py.bak +++ /dev/null @@ -1,368 +0,0 @@ -from __future__ import absolute_import, print_function, division -import os.path -import re -import string -import unicodedata - -import six - -from six.moves import urllib - - -def always_bytes(unicode_or_bytes, *encode_args): - if isinstance(unicode_or_bytes, six.text_type): - return unicode_or_bytes.encode(*encode_args) - return unicode_or_bytes - - -def always_byte_args(*encode_args): - """Decorator that transparently encodes all arguments passed as unicode""" - def decorator(fun): - def _fun(*args, **kwargs): - args = [always_bytes(arg, *encode_args) for arg in args] - kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} - return fun(*args, **kwargs) - return _fun - return decorator - - -def native(s, *encoding_opts): - """ - Convert :py:class:`bytes` or :py:class:`unicode` to the native - :py:class:`str` type, using latin1 encoding if conversion is necessary. - - https://www.python.org/dev/peps/pep-3333/#a-note-on-string-types - """ - if not isinstance(s, (six.binary_type, six.text_type)): - raise TypeError("%r is neither bytes nor unicode" % s) - if six.PY3: - if isinstance(s, six.binary_type): - return s.decode(*encoding_opts) - else: - if isinstance(s, six.text_type): - return s.encode(*encoding_opts) - return s - - -def isascii(bytes): - try: - bytes.decode("ascii") - except ValueError: - return False - return True - - -def clean_bin(s, keep_spacing=True): - """ - Cleans binary data to make it safe to display. - - Args: - keep_spacing: If False, tabs and newlines will also be replaced. - """ - if isinstance(s, six.text_type): - if keep_spacing: - keep = u" \n\r\t" - else: - keep = u" " - return u"".join( - ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else u"." - for ch in s - ) - else: - if keep_spacing: - keep = (9, 10, 13) # \t, \n, \r, - else: - keep = () - return b"".join( - six.int2byte(ch) if (31 < ch < 127 or ch in keep) else b"." - for ch in six.iterbytes(s) - ) - - -def hexdump(s): - """ - Returns: - A generator of (offset, hex, str) tuples - """ - for i in range(0, len(s), 16): - offset = b"%.10x" % i - part = s[i:i + 16] - x = b" ".join(b"%.2x" % i for i in six.iterbytes(part)) - x = x.ljust(47) # 16*2 + 15 - yield (offset, x, clean_bin(part, False)) - - -def setbit(byte, offset, value): - """ - Set a bit in a byte to 1 if value is truthy, 0 if not. - """ - if value: - return byte | (1 << offset) - else: - return byte & ~(1 << offset) - - -def getbit(byte, offset): - mask = 1 << offset - return bool(byte & mask) - - -class BiDi(object): - - """ - A wee utility class for keeping bi-directional mappings, like field - constants in protocols. Names are attributes on the object, dict-like - access maps values to names: - - CONST = BiDi(a=1, b=2) - assert CONST.a == 1 - assert CONST.get_name(1) == "a" - """ - - def __init__(self, **kwargs): - self.names = kwargs - self.values = {} - for k, v in kwargs.items(): - self.values[v] = k - if len(self.names) != len(self.values): - raise ValueError("Duplicate values not allowed.") - - def __getattr__(self, k): - if k in self.names: - return self.names[k] - raise AttributeError("No such attribute: %s", k) - - def get_name(self, n, default=None): - return self.values.get(n, default) - - -def pretty_size(size): - suffixes = [ - ("B", 2 ** 10), - ("kB", 2 ** 20), - ("MB", 2 ** 30), - ] - for suf, lim in suffixes: - if size >= lim: - continue - else: - x = round(size / float(lim / 2 ** 10), 2) - if x == int(x): - x = int(x) - return str(x) + suf - - -class Data(object): - - def __init__(self, name): - m = __import__(name) - dirname, _ = os.path.split(m.__file__) - self.dirname = os.path.abspath(dirname) - - def path(self, path): - """ - Returns a path to the package data housed at 'path' under this - module.Path can be a path to a file, or to a directory. - - This function will raise ValueError if the path does not exist. - """ - fullpath = os.path.join(self.dirname, '../test/', path) - if not os.path.exists(fullpath): - raise ValueError("dataPath: %s does not exist." % fullpath) - return fullpath - - -_label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(? 255: - return False - if host[-1] == b".": - host = host[:-1] - return all(_label_valid.match(x) for x in host.split(b".")) - - -def is_valid_port(port): - return 0 <= port <= 65535 - - -# PY2 workaround -def decode_parse_result(result, enc): - if hasattr(result, "decode"): - return result.decode(enc) - else: - return urllib.parse.ParseResult(*[x.decode(enc) for x in result]) - - -# PY2 workaround -def encode_parse_result(result, enc): - if hasattr(result, "encode"): - return result.encode(enc) - else: - return urllib.parse.ParseResult(*[x.encode(enc) for x in result]) - - -def parse_url(url): - """ - URL-parsing function that checks that - - port is an integer 0-65535 - - host is a valid IDNA-encoded hostname with no null-bytes - - path is valid ASCII - - Args: - A URL (as bytes or as unicode) - - Returns: - A (scheme, host, port, path) tuple - - Raises: - ValueError, if the URL is not properly formatted. - """ - parsed = urllib.parse.urlparse(url) - - if not parsed.hostname: - raise ValueError("No hostname given") - - if isinstance(url, six.binary_type): - host = parsed.hostname - - # this should not raise a ValueError, - # but we try to be very forgiving here and accept just everything. - # decode_parse_result(parsed, "ascii") - else: - host = parsed.hostname.encode("idna") - parsed = encode_parse_result(parsed, "ascii") - - port = parsed.port - if not port: - port = 443 if parsed.scheme == b"https" else 80 - - full_path = urllib.parse.urlunparse( - (b"", b"", parsed.path, parsed.params, parsed.query, parsed.fragment) - ) - if not full_path.startswith(b"/"): - full_path = b"/" + full_path - - if not is_valid_host(host): - raise ValueError("Invalid Host") - if not is_valid_port(port): - raise ValueError("Invalid Port") - - return parsed.scheme, host, port, full_path - - -def get_header_tokens(headers, key): - """ - Retrieve all tokens for a header key. A number of different headers - follow a pattern where each header line can containe comma-separated - tokens, and headers can be set multiple times. - """ - if key not in headers: - return [] - tokens = headers[key].split(",") - return [token.strip() for token in tokens] - - -def hostport(scheme, host, port): - """ - Returns the host component, with a port specifcation if needed. - """ - if (port, scheme) in [(80, "http"), (443, "https"), (80, b"http"), (443, b"https")]: - return host - else: - if isinstance(host, six.binary_type): - return b"%s:%d" % (host, port) - else: - return "%s:%d" % (host, port) - - -def unparse_url(scheme, host, port, path=""): - """ - Returns a URL string, constructed from the specified components. - - Args: - All args must be str. - """ - return "%s://%s%s" % (scheme, hostport(scheme, host, port), path) - - -def urlencode(s): - """ - Takes a list of (key, value) tuples and returns a urlencoded string. - """ - s = [tuple(i) for i in s] - return urllib.parse.urlencode(s, False) - - -def urldecode(s): - """ - Takes a urlencoded string and returns a list of (key, value) tuples. - """ - return urllib.parse.parse_qsl(s, keep_blank_values=True) - - -def parse_content_type(c): - """ - A simple parser for content-type values. Returns a (type, subtype, - parameters) tuple, where type and subtype are strings, and parameters - is a dict. If the string could not be parsed, return None. - - E.g. the following string: - - text/html; charset=UTF-8 - - Returns: - - ("text", "html", {"charset": "UTF-8"}) - """ - parts = c.split(";", 1) - ts = parts[0].split("/", 1) - if len(ts) != 2: - return None - d = {} - if len(parts) == 2: - for i in parts[1].split(";"): - clause = i.split("=", 1) - if len(clause) == 2: - d[clause[0].strip()] = clause[1].strip() - return ts[0].lower(), ts[1].lower(), d - - -def multipartdecode(headers, content): - """ - Takes a multipart boundary encoded string and returns list of (key, value) tuples. - """ - v = headers.get("content-type") - if v: - v = parse_content_type(v) - if not v: - return [] - try: - boundary = v[2]["boundary"].encode("ascii") - except (KeyError, UnicodeError): - return [] - - rx = re.compile(br'\bname="([^"]+)"') - r = [] - - for i in content.split(b"--" + boundary): - parts = i.splitlines() - if len(parts) > 1 and parts[0][0:2] != b"--": - match = rx.search(parts[1]) - if match: - key = match.group(1) - value = b"".join(parts[3 + parts[2:].index(b""):]) - r.append((key, value)) - return r - return [] -- cgit v1.2.3 From 6689a342ae68c75bd52d81ee1959b1946739eca4 Mon Sep 17 00:00:00 2001 From: Sam Cleveland Date: Wed, 11 Nov 2015 19:53:51 -0600 Subject: Porting to Python 3.4 Fixed byte string formatting for hexdump. = test session starts = platform darwin -- Python 3.4.1, pytest-2.8.2, py-1.4.30, pluggy-0.3.1 rootdir: /Users/samc/src/python/netlib, inifile: collected 11 items netlib/test/test_utils.py ........... = 11 passed in 0.23 seconds = --- netlib/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netlib/utils.py b/netlib/utils.py index acc7ccd4..66225897 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -85,9 +85,9 @@ def hexdump(s): A generator of (offset, hex, str) tuples """ for i in range(0, len(s), 16): - offset = b"%.10x" % i + offset = "{:0=10x}".format(i).encode() part = s[i:i + 16] - x = b" ".join(b"%.2x" % i for i in six.iterbytes(part)) + x = b" ".join("{:0=2x}".format(i).encode() for i in six.iterbytes(part)) x = x.ljust(47) # 16*2 + 15 yield (offset, x, clean_bin(part, False)) -- cgit v1.2.3 From 2bd7bcb3711a20b6a166710f2c7d989d8ae5fcc8 Mon Sep 17 00:00:00 2001 From: Sam Cleveland Date: Wed, 11 Nov 2015 20:27:10 -0600 Subject: Porting to Python 3.4 Updated wsgi to support Python 3.4 byte strings. Updated test_wsgi to remove py.test warning for TestApp having an __init__ constructor. samc$ sudo py.test netlib/test/test_wsgi.py -r w = test session starts = platform darwin -- Python 3.4.1, pytest-2.8.2, py-1.4.30, pluggy-0.3.1 rootdir: /Users/samc/src/python/netlib, inifile: collected 6 items netlib/test/test_wsgi.py ...... = 6 passed in 0.20 seconds = --- netlib/wsgi.py | 11 ++++++----- test/test_wsgi.py | 2 -- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index df248a19..d6dfae5d 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -96,16 +96,17 @@ class WSGIAdaptor(object): Make a best-effort attempt to write an error page. If headers are already sent, we just bung the error into the page. """ - c = b""" + c = """

Internal Server Error

-
%s"
+
{err}"
- """.strip() % s.encode() + """.format(err=s).strip().encode() + if not headers_sent: soc.write(b"HTTP/1.1 500 Internal Server Error\r\n") soc.write(b"Content-Type: text/html\r\n") - soc.write(b"Content-Length: %s\r\n" % len(c)) + soc.write("Content-Length: {length}\r\n".format(length=len(c)).encode()) soc.write(b"\r\n") soc.write(c) @@ -119,7 +120,7 @@ class WSGIAdaptor(object): def write(data): if not state["headers_sent"]: - soc.write(b"HTTP/1.1 %s\r\n" % state["status"].encode()) + soc.write("HTTP/1.1 {status}\r\n".format(status=state["status"]).encode()) headers = state["headers"] if 'server' not in headers: headers["Server"] = self.sversion diff --git a/test/test_wsgi.py b/test/test_wsgi.py index fe6f09b5..ec6e8c63 100644 --- a/test/test_wsgi.py +++ b/test/test_wsgi.py @@ -12,8 +12,6 @@ def tflow(): class TestApp: - def __init__(self): - self.called = False def __call__(self, environ, start_response): self.called = True -- cgit v1.2.3 From 5916260849504ccf475f5a190bd7249c7709bad8 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 13 Nov 2015 20:00:54 +0100 Subject: be more conservative about dependency versions --- setup.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index 729910f8..1bfab2ba 100644 --- a/setup.py +++ b/setup.py @@ -15,17 +15,17 @@ with open(os.path.join(here, 'README.mkd'), encoding='utf-8') as f: long_description = f.read() deps = { - "pyasn1>=0.1.7", - "pyOpenSSL>=0.15.1", - "cryptography>=1.0", - "passlib>=1.6.2", - "hpack>=1.0.1", - "six>=1.9.0", - "certifi>=2015.9.6.2", - "backports.ssl_match_hostname>=3.4.0.2", + "pyasn1~=0.1.9", + "pyOpenSSL~=0.15.1", + "cryptography~=1.1.0", + "passlib~=1.6.5", + "hpack~=2.0.1", + "six~=1.10.0", + "certifi>=2015.9.6.2", # no semver here - this should always be on the last release! + "backports.ssl_match_hostname~=3.4.0.2", } if sys.version_info < (3, 0): - deps.add("ipaddress>=1.0.14") + deps.add("ipaddress~=1.0.15") setup( name="netlib", -- cgit v1.2.3 From ce02874e2a7c5dacd12e58652e9665857bd46c62 Mon Sep 17 00:00:00 2001 From: Sam Cleveland Date: Sat, 14 Nov 2015 13:42:43 -0600 Subject: Fixing test_wsgi to remove py.test warnings Renamed TestApp class to ExampleApp to prevent py.test from trying to collect it as a test. --- test/test_wsgi.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/test_wsgi.py b/test/test_wsgi.py index ec6e8c63..8c782b27 100644 --- a/test/test_wsgi.py +++ b/test/test_wsgi.py @@ -10,8 +10,10 @@ def tflow(): return wsgi.Flow(("127.0.0.1", 8888), req) -class TestApp: - +class ExampleApp: + + def __init__(self): + self.called = False def __call__(self, environ, start_response): self.called = True @@ -33,7 +35,7 @@ class TestWSGI: assert r["QUERY_STRING"] == "bar=voing" def test_serve(self): - ta = TestApp() + ta = ExampleApp() w = wsgi.WSGIAdaptor(ta, "foo", 80, "version") f = tflow() f.request.host = "foo" -- cgit v1.2.3 From c1385c9a176b8d8113f05cb5e920392016bda0cd Mon Sep 17 00:00:00 2001 From: Benjamin Lee Date: Tue, 17 Nov 2015 04:51:20 +1100 Subject: Fix to ignore empty header value. According to Augmented BNF in the following RFCs http://tools.ietf.org/html/rfc5234#section-3.6 http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.1 field-value = *( field-content | LWS ) http://tools.ietf.org/html/rfc7230#section-3.2 field-value = *( field-content / obs-fold ) ... the HTTP message header `field-value` is allowed to be empty. --- netlib/http/http1/read.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 0f6de26c..6e3a1b93 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -321,7 +321,7 @@ def _read_headers(rfile): try: name, value = line.split(b":", 1) value = value.strip() - if not name or not value: + if not name: raise ValueError() ret.append([name, value]) except ValueError: -- cgit v1.2.3 From cf1889e1575dbeb1257c37e97193fce1d2e65244 Mon Sep 17 00:00:00 2001 From: Benjamin Lee Date: Tue, 17 Nov 2015 06:46:48 +1100 Subject: WIP. Add breaking test. --- test/http/http1/test_read.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py index 45f61b4f..4524e1d0 100644 --- a/test/http/http1/test_read.py +++ b/test/http/http1/test_read.py @@ -297,6 +297,13 @@ class TestReadHeaders(object): with raises(HttpSyntaxException): self._read(data) + def test_read_empty_value(self): + data = b"bar:" + headers = self._read(data) + # XXX. WIP. break test + # assert headers.fields == [[b"bar", b""]] + with raises(HttpSyntaxException): + self._read(data) def test_read_chunked(): req = treq(content=None) -- cgit v1.2.3 From 52c02bc930b380f741b9bb295aa019e22687d5d3 Mon Sep 17 00:00:00 2001 From: Benjamin Lee Date: Tue, 17 Nov 2015 06:51:22 +1100 Subject: Add test for empty header field value. --- test/http/http1/test_read.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py index 4524e1d0..8a315508 100644 --- a/test/http/http1/test_read.py +++ b/test/http/http1/test_read.py @@ -300,10 +300,7 @@ class TestReadHeaders(object): def test_read_empty_value(self): data = b"bar:" headers = self._read(data) - # XXX. WIP. break test - # assert headers.fields == [[b"bar", b""]] - with raises(HttpSyntaxException): - self._read(data) + assert headers.fields == [[b"bar", b""]] def test_read_chunked(): req = treq(content=None) -- cgit v1.2.3 From 7cb57e206fe148d544f5167a35bd02eadeca495f Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 29 Nov 2015 19:04:19 +0100 Subject: README: mkd -> rst pypi only renders reStructuredText. --- MANIFEST.in | 4 ++-- README.mkd | 24 ------------------------ README.rst | 35 +++++++++++++++++++++++++++++++++++ setup.py | 2 +- 4 files changed, 38 insertions(+), 27 deletions(-) delete mode 100644 README.mkd create mode 100644 README.rst diff --git a/MANIFEST.in b/MANIFEST.in index bd59f003..a68c043e 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,3 @@ -include LICENSE README.mkd -recursive-include test * +include LICENSE CONTRIBUTORS README.rst +graft test recursive-exclude * *.pyc *.pyo *.swo *.swp \ No newline at end of file diff --git a/README.mkd b/README.mkd deleted file mode 100644 index f5e66d99..00000000 --- a/README.mkd +++ /dev/null @@ -1,24 +0,0 @@ -[![Build Status](https://img.shields.io/travis/mitmproxy/netlib/master.svg)](https://travis-ci.org/mitmproxy/netlib) -[![Code Health](https://landscape.io/github/mitmproxy/netlib/master/landscape.svg?style=flat)](https://landscape.io/github/mitmproxy/netlib/master) -[![Coverage Status](https://img.shields.io/coveralls/mitmproxy/netlib/master.svg)](https://coveralls.io/r/mitmproxy/netlib) -[![Downloads](https://img.shields.io/pypi/dm/netlib.svg?color=orange)](https://pypi.python.org/pypi/netlib) -[![Latest Version](https://img.shields.io/pypi/v/netlib.svg)](https://pypi.python.org/pypi/netlib) -[![Supported Python versions](https://img.shields.io/pypi/pyversions/netlib.svg)](https://pypi.python.org/pypi/netlib) - -Netlib is a collection of network utility classes, used by the pathod and -mitmproxy projects. It differs from other projects in some fundamental -respects, because both pathod and mitmproxy often need to violate standards. -This means that protocols are implemented as small, well-contained and flexible -functions, and are designed to allow misbehaviour when needed. - - -Requirements ------------- - -* [Python](http://www.python.org) 2.7.x or a compatible version of pypy. -* Third-party packages listed in [setup.py](https://github.com/mitmproxy/netlib/blob/master/setup.py) - -Hacking -------- - -If you'd like to work on netlib, check out the instructions in mitmproxy's [README](https://github.com/mitmproxy/mitmproxy#hacking). diff --git a/README.rst b/README.rst new file mode 100644 index 00000000..694e3ad9 --- /dev/null +++ b/README.rst @@ -0,0 +1,35 @@ +|travis| |coveralls| |downloads| |latest-release| |python-versions| + +Netlib is a collection of network utility classes, used by the pathod and +mitmproxy projects. It differs from other projects in some fundamental +respects, because both pathod and mitmproxy often need to violate standards. +This means that protocols are implemented as small, well-contained and flexible +functions, and are designed to allow misbehaviour when needed. + + +Hacking +------- + +If you'd like to work on netlib, check out the instructions in mitmproxy's README_. + +.. |travis| image:: https://img.shields.io/travis/mitmproxy/netlib/master.svg + :target: https://travis-ci.org/mitmproxy/netlib + :alt: Build Status + +.. |coveralls| image:: https://img.shields.io/coveralls/mitmproxy/netlib/master.svg + :target: https://coveralls.io/r/mitmproxy/netlib + :alt: Coverage Status + +.. |downloads| image:: https://img.shields.io/pypi/dm/netlib.svg?color=orange + :target: https://pypi.python.org/pypi/netlib + :alt: Downloads + +.. |latest-release| image:: https://img.shields.io/pypi/v/netlib.svg + :target: https://pypi.python.org/pypi/netlib + :alt: Latest Version + +.. |python-versions| image:: https://img.shields.io/pypi/pyversions/netlib.svg + :target: https://pypi.python.org/pypi/netlib + :alt: Supported Python versions + +.. _README: https://github.com/mitmproxy/mitmproxy#hacking \ No newline at end of file diff --git a/setup.py b/setup.py index 1bfab2ba..c4bbb8cf 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ from netlib import version here = os.path.abspath(os.path.dirname(__file__)) -with open(os.path.join(here, 'README.mkd'), encoding='utf-8') as f: +with open(os.path.join(here, 'README.rst'), encoding='utf-8') as f: long_description = f.read() deps = { -- cgit v1.2.3 From 9f224f7dbd52b02421d0b3674b2630df6afa0007 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 29 Nov 2015 19:06:54 +0100 Subject: add 3.5 compat classifiers --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index c4bbb8cf..85af963f 100644 --- a/setup.py +++ b/setup.py @@ -43,6 +43,8 @@ setup( "Programming Language :: Python", "Programming Language :: Python :: 2", "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.5", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", "Topic :: Internet", -- cgit v1.2.3 From 4718f36379c651e50a0e63954abfea8433ef514d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 3 Dec 2015 17:56:57 +0100 Subject: use version specifiers compatible with old setuptools releases --- setup.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index 85af963f..96fa74e5 100644 --- a/setup.py +++ b/setup.py @@ -15,17 +15,17 @@ with open(os.path.join(here, 'README.rst'), encoding='utf-8') as f: long_description = f.read() deps = { - "pyasn1~=0.1.9", - "pyOpenSSL~=0.15.1", - "cryptography~=1.1.0", - "passlib~=1.6.5", - "hpack~=2.0.1", - "six~=1.10.0", - "certifi>=2015.9.6.2", # no semver here - this should always be on the last release! - "backports.ssl_match_hostname~=3.4.0.2", + "pyasn1>=0.1.9, <0.2", + "pyOpenSSL>=0.15.1, <0.16", + "cryptography>=1.1.1, <1.2", + "passlib>=1.6.5, <1.7", + "hpack>=2.0.1, <2.1", + "six>=1.10.0, <1.11", + "certifi>=2015.9.6.2", # no semver here - this should always be on the last release! + "backports.ssl_match_hostname>=3.4.0.2, <3.4.1", } if sys.version_info < (3, 0): - deps.add("ipaddress~=1.0.15") + deps.add("ipaddress>=1.0.15, <1.1") setup( name="netlib", -- cgit v1.2.3 From 71834421bbf63e89eb923b888ea97db437c59ea5 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 3 Dec 2015 18:13:24 +0100 Subject: bump version --- CONTRIBUTORS | 16 +++++++++------- netlib/version.py | 2 +- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/CONTRIBUTORS b/CONTRIBUTORS index 4b4240f8..b1fb2a0f 100644 --- a/CONTRIBUTORS +++ b/CONTRIBUTORS @@ -1,18 +1,20 @@ - 250 Aldo Cortesi - 204 Maximilian Hils + 253 Aldo Cortesi + 210 Maximilian Hils 109 Thomas Kriechbaumer 8 Chandler Abraham 8 Kyle Morton - 2 Sean Coates - 2 Israel Nir + 5 Sam Cleveland + 3 Benjamin Lee 2 Brad Peabody - 2 Pedro Worcel + 2 Israel Nir 2 Matthias Urlichs - 1 kronick + 2 Pedro Worcel + 2 Sean Coates + 1 Andrey Plotnikov 1 Bradley Baetz 1 M. Utku Altinkaya - 1 Andrey Plotnikov 1 Paul 1 Pritam Baral 1 Rouli 1 Tim Becker + 1 kronick diff --git a/netlib/version.py b/netlib/version.py index e836dbe3..aa4ba641 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 14, 1) +IVERSION = (0, 15) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From c10b614f700055976a5e30e5cfb0142fc6b83843 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 25 Dec 2015 15:56:26 +0100 Subject: update ssl_match_hostname dependency, refs #868 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 96fa74e5..d411a498 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ deps = { "hpack>=2.0.1, <2.1", "six>=1.10.0, <1.11", "certifi>=2015.9.6.2", # no semver here - this should always be on the last release! - "backports.ssl_match_hostname>=3.4.0.2, <3.4.1", + "backports.ssl_match_hostname>=3.5.0.1, <3.6", } if sys.version_info < (3, 0): deps.add("ipaddress>=1.0.15, <1.1") -- cgit v1.2.3 From d1e6b5366c97dd31c9b9606db2bb7a8520cfbd2c Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 25 Dec 2015 16:00:50 +0100 Subject: bump version --- CONTRIBUTORS | 2 +- netlib/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTORS b/CONTRIBUTORS index b1fb2a0f..a43d31c9 100644 --- a/CONTRIBUTORS +++ b/CONTRIBUTORS @@ -1,5 +1,5 @@ 253 Aldo Cortesi - 210 Maximilian Hils + 212 Maximilian Hils 109 Thomas Kriechbaumer 8 Chandler Abraham 8 Kyle Morton diff --git a/netlib/version.py b/netlib/version.py index aa4ba641..7a68ca39 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 15) +IVERSION = (0, 15, 1) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From 4bb9f3d35b02bb076dd8df133288492c24295c8a Mon Sep 17 00:00:00 2001 From: Sandor Nemes Date: Fri, 8 Jan 2016 18:04:47 +0100 Subject: Added getter/setter for TCPClient source_address --- netlib/tcp.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/netlib/tcp.py b/netlib/tcp.py index 8e46d4f6..e5e9ec1a 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -565,6 +565,17 @@ class TCPClient(_Connection): else: self.__address = None + @property + def source_address(self): + return self.__source_address + + @source_address.setter + def source_address(self, source_address): + if source_address: + self.__source_address = Address.wrap(source_address) + else: + self.__source_address = None + def close(self): # Make sure to close the real socket, not the SSL proxy. # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, -- cgit v1.2.3 From 739806bfe2d002e5faed6fb343a0af614198d2dd Mon Sep 17 00:00:00 2001 From: Felix Yan Date: Mon, 11 Jan 2016 00:37:43 +0800 Subject: Allow cryptography 1.2.* --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d411a498..0cc81a1b 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ with open(os.path.join(here, 'README.rst'), encoding='utf-8') as f: deps = { "pyasn1>=0.1.9, <0.2", "pyOpenSSL>=0.15.1, <0.16", - "cryptography>=1.1.1, <1.2", + "cryptography>=1.1.1, <1.3", "passlib>=1.6.5, <1.7", "hpack>=2.0.1, <2.1", "six>=1.10.0, <1.11", -- cgit v1.2.3 From b8e8c4d68222c9292daf23e6ace55351fcef1af6 Mon Sep 17 00:00:00 2001 From: Sandor Nemes Date: Mon, 11 Jan 2016 08:10:36 +0100 Subject: Simplified setting the source_address in the TCPClient constructor --- netlib/tcp.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index e5e9ec1a..8902b9dc 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -548,8 +548,7 @@ class TCPClient(_Connection): def __init__(self, address, source_address=None): super(TCPClient, self).__init__(None) self.address = address - self.source_address = Address.wrap( - source_address) if source_address else None + self.source_address = source_address self.cert = None self.ssl_verification_error = None self.sni = None -- cgit v1.2.3 From 9e2d050bb367fafafd08df48225864880527660f Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 13 Jan 2016 12:05:38 +0100 Subject: upgrade cryptography dependency for new wheels --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 0cc81a1b..2da10baa 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ with open(os.path.join(here, 'README.rst'), encoding='utf-8') as f: deps = { "pyasn1>=0.1.9, <0.2", "pyOpenSSL>=0.15.1, <0.16", - "cryptography>=1.1.1, <1.3", + "cryptography>=1.2.1, <1.3", "passlib>=1.6.5, <1.7", "hpack>=2.0.1, <2.1", "six>=1.10.0, <1.11", -- cgit v1.2.3 From 1b487539b1f3ea183eaed26ae756d0cc7d3ec3ea Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 24 Jan 2016 23:24:59 +0100 Subject: move tservers to netlib module --- netlib/tservers.py | 109 +++++++++++++++++++++++++++++++++++++ test/http/http2/test_protocol.py | 3 +- test/test_tcp.py | 3 +- test/tservers.py | 107 ------------------------------------ test/websockets/test_websockets.py | 3 +- 5 files changed, 112 insertions(+), 113 deletions(-) create mode 100644 netlib/tservers.py delete mode 100644 test/tservers.py diff --git a/netlib/tservers.py b/netlib/tservers.py new file mode 100644 index 00000000..44ef8063 --- /dev/null +++ b/netlib/tservers.py @@ -0,0 +1,109 @@ +from __future__ import (absolute_import, print_function, division) + +import threading +from six.moves import queue +from io import StringIO +import OpenSSL + +from netlib import tcp +from netlib import tutils + + +class ServerThread(threading.Thread): + + def __init__(self, server): + self.server = server + threading.Thread.__init__(self) + + def run(self): + self.server.serve_forever() + + def shutdown(self): + self.server.shutdown() + + +class ServerTestBase(object): + ssl = None + handler = None + addr = ("localhost", 0) + + @classmethod + def setup_class(cls): + cls.q = queue.Queue() + s = cls.makeserver() + cls.port = s.address.port + cls.server = ServerThread(s) + cls.server.start() + + @classmethod + def makeserver(cls): + return TServer(cls.ssl, cls.q, cls.handler, cls.addr) + + @classmethod + def teardown_class(cls): + cls.server.shutdown() + + @property + def last_handler(self): + return self.server.server.last_handler + + +class TServer(tcp.TCPServer): + + def __init__(self, ssl, q, handler_klass, addr): + """ + ssl: A dictionary of SSL parameters: + + cert, key, request_client_cert, cipher_list, + dhparams, v3_only + """ + tcp.TCPServer.__init__(self, addr) + + if ssl is True: + self.ssl = dict() + elif isinstance(ssl, dict): + self.ssl = ssl + else: + self.ssl = None + + self.q = q + self.handler_klass = handler_klass + self.last_handler = None + + def handle_client_connection(self, request, client_address): + h = self.handler_klass(request, client_address, self) + self.last_handler = h + if self.ssl is not None: + cert = self.ssl.get( + "cert", + tutils.test_data.path("data/server.crt")) + raw_key = self.ssl.get( + "key", + tutils.test_data.path("data/server.key")) + key = OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, + open(raw_key, "rb").read()) + if self.ssl.get("v3_only", False): + method = OpenSSL.SSL.SSLv3_METHOD + options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 + else: + method = OpenSSL.SSL.SSLv23_METHOD + options = None + h.convert_to_ssl( + cert, key, + method=method, + options=options, + handle_sni=getattr(h, "handle_sni", None), + request_client_cert=self.ssl.get("request_client_cert", None), + cipher_list=self.ssl.get("cipher_list", None), + dhparams=self.ssl.get("dhparams", None), + chain_file=self.ssl.get("chain_file", None), + alpn_select=self.ssl.get("alpn_select", None) + ) + h.handle() + h.finish() + + def handle_error(self, connection, client_address, fp=None): + s = StringIO() + tcp.TCPServer.handle_error(self, connection, client_address, s) + self.q.put(s.getvalue()) diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 6bda96f5..0beec950 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -1,12 +1,11 @@ import OpenSSL import mock -from netlib import tcp, http, tutils +from netlib import tcp, http, tutils, tservers from netlib.exceptions import TcpDisconnect from netlib.http import Headers from netlib.http.http2.connections import HTTP2Protocol, TCPHandler from netlib.http.http2.frame import * -from ... import tservers class TestTCPHandlerWrapper: def test_wrapped(self): diff --git a/test/test_tcp.py b/test/test_tcp.py index 68d54b78..738fb2eb 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -10,8 +10,7 @@ import mock from OpenSSL import SSL import OpenSSL -from netlib import tcp, certutils, tutils -from . import tservers +from netlib import tcp, certutils, tutils, tservers from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \ TcpTimeout, TcpDisconnect, TcpException diff --git a/test/tservers.py b/test/tservers.py deleted file mode 100644 index c47d6a5f..00000000 --- a/test/tservers.py +++ /dev/null @@ -1,107 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import threading -from six.moves import queue -from io import StringIO -import OpenSSL -from netlib import tcp -from netlib import tutils - - -class ServerThread(threading.Thread): - - def __init__(self, server): - self.server = server - threading.Thread.__init__(self) - - def run(self): - self.server.serve_forever() - - def shutdown(self): - self.server.shutdown() - - -class ServerTestBase(object): - ssl = None - handler = None - addr = ("localhost", 0) - - @classmethod - def setup_class(cls): - cls.q = queue.Queue() - s = cls.makeserver() - cls.port = s.address.port - cls.server = ServerThread(s) - cls.server.start() - - @classmethod - def makeserver(cls): - return TServer(cls.ssl, cls.q, cls.handler, cls.addr) - - @classmethod - def teardown_class(cls): - cls.server.shutdown() - - @property - def last_handler(self): - return self.server.server.last_handler - - -class TServer(tcp.TCPServer): - - def __init__(self, ssl, q, handler_klass, addr): - """ - ssl: A dictionary of SSL parameters: - - cert, key, request_client_cert, cipher_list, - dhparams, v3_only - """ - tcp.TCPServer.__init__(self, addr) - - if ssl is True: - self.ssl = dict() - elif isinstance(ssl, dict): - self.ssl = ssl - else: - self.ssl = None - - self.q = q - self.handler_klass = handler_klass - self.last_handler = None - - def handle_client_connection(self, request, client_address): - h = self.handler_klass(request, client_address, self) - self.last_handler = h - if self.ssl is not None: - cert = self.ssl.get( - "cert", - tutils.test_data.path("data/server.crt")) - raw_key = self.ssl.get( - "key", - tutils.test_data.path("data/server.key")) - key = OpenSSL.crypto.load_privatekey( - OpenSSL.crypto.FILETYPE_PEM, - open(raw_key, "rb").read()) - if self.ssl.get("v3_only", False): - method = OpenSSL.SSL.SSLv3_METHOD - options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 - else: - method = OpenSSL.SSL.SSLv23_METHOD - options = None - h.convert_to_ssl( - cert, key, - method=method, - options=options, - handle_sni=getattr(h, "handle_sni", None), - request_client_cert=self.ssl.get("request_client_cert", None), - cipher_list=self.ssl.get("cipher_list", None), - dhparams=self.ssl.get("dhparams", None), - chain_file=self.ssl.get("chain_file", None), - alpn_select=self.ssl.get("alpn_select", None) - ) - h.handle() - h.finish() - - def handle_error(self, connection, client_address, fp=None): - s = StringIO() - tcp.TCPServer.handle_error(self, connection, client_address, s) - self.q.put(s.getvalue()) diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 9a1e5d3d..d53f0d83 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -2,12 +2,11 @@ import os from netlib.http.http1 import read_response, read_request -from netlib import tcp, tutils, websockets, http +from netlib import tcp, websockets, http, tutils, tservers from netlib.http import status_codes from netlib.tutils import treq from netlib.exceptions import * -from .. import tservers class WebSocketsEchoHandler(tcp.BaseHandler): -- cgit v1.2.3 From 2145ded375b0b288ed350bd9fbfe259e59fc8671 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sat, 30 Jan 2016 12:48:09 +0100 Subject: fix pypy on travis --- .travis.yml | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/.travis.yml b/.travis.yml index 0f2b1431..a60a4e69 100644 --- a/.travis.yml +++ b/.travis.yml @@ -17,10 +17,7 @@ matrix: - libssl-dev - python: 3.5 script: - - py.test -n 4 -k "not http2" . - - python: pypy3 - script: - - py.test -n 4 -k "not http2" . + - py.test -s --cov netlib -k "not http2" - python: pypy - python: pypy env: OPENSSL=1.0.2 @@ -36,16 +33,36 @@ matrix: # We allow pypy to fail until Travis fixes their infrastructure to a pypy # with a recent enought CFFI library to run cryptography 1.0+. - python: pypy - - python: pypy3 install: + - | + if [[ $TRAVIS_OS_NAME == "osx" ]] + then + brew update || brew update # try again if it fails + brew outdated openssl || brew upgrade openssl + brew install python + fi + - | + if [ "$TRAVIS_PYTHON_VERSION" = "pypy" ]; then + export PYENV_ROOT="$HOME/.pyenv" + if [ -f "$PYENV_ROOT/bin/pyenv" ]; then + pushd "$PYENV_ROOT" && git pull && popd + else + rm -rf "$PYENV_ROOT" && git clone --depth 1 https://github.com/yyuu/pyenv.git "$PYENV_ROOT" + fi + export PYPY_VERSION="4.0.1" + "$PYENV_ROOT/bin/pyenv" install --skip-existing "pypy-$PYPY_VERSION" + virtualenv --python="$PYENV_ROOT/versions/pypy-$PYPY_VERSION/bin/python" "$HOME/virtualenvs/pypy-$PYPY_VERSION" + source "$HOME/virtualenvs/pypy-$PYPY_VERSION/bin/activate" + fi + - "pip install -U pip setuptools" - "pip install --src . -r requirements.txt" before_script: - "openssl version -a" script: - - "py.test -n 4 --cov netlib" + - "py.test -s --cov netlib" after_success: - coveralls -- cgit v1.2.3 From 283c74a0eab01b817ba8c7d9f0341f9084ceae66 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sat, 30 Jan 2016 13:38:28 +0100 Subject: allow pypy again on travis --- .travis.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index a60a4e69..84f32d71 100644 --- a/.travis.yml +++ b/.travis.yml @@ -29,10 +29,6 @@ matrix: - debian-sid packages: - libssl-dev - allow_failures: - # We allow pypy to fail until Travis fixes their infrastructure to a pypy - # with a recent enought CFFI library to run cryptography 1.0+. - - python: pypy install: - | -- cgit v1.2.3 From d253ebc142d80708a1bdc065d3db05d1394e3819 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sat, 30 Jan 2016 22:03:24 +0100 Subject: fix test request and response headers --- netlib/http/message.py | 2 +- netlib/tutils.py | 4 ++-- test/http/http1/test_assemble.py | 5 +++-- test/http/http1/test_read.py | 4 ++-- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/netlib/http/message.py b/netlib/http/message.py index e4e799ca..28f55fa2 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -193,4 +193,4 @@ class decoded(object): def __exit__(self, type, value, tb): if self.ce: - self.message.encode(self.ce) \ No newline at end of file + self.message.encode(self.ce) diff --git a/netlib/tutils.py b/netlib/tutils.py index e16f1a76..14b4ef06 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -105,7 +105,7 @@ def treq(**kwargs): port=22, path=b"/path", http_version=b"HTTP/1.1", - headers=Headers(header="qvalue"), + headers=Headers(header="qvalue", content_length="7"), content=b"content" ) default.update(kwargs) @@ -121,7 +121,7 @@ def tresp(**kwargs): http_version=b"HTTP/1.1", status_code=200, reason=b"OK", - headers=Headers(header_response="svalue"), + headers=Headers(header_response="svalue", content_length="7"), content=b"message", timestamp_start=time.time(), timestamp_end=time.time(), diff --git a/test/http/http1/test_assemble.py b/test/http/http1/test_assemble.py index ed94292d..31a62438 100644 --- a/test/http/http1/test_assemble.py +++ b/test/http/http1/test_assemble.py @@ -24,10 +24,11 @@ def test_assemble_request(): def test_assemble_request_head(): - c = assemble_request_head(treq()) + c = assemble_request_head(treq(content="foo")) assert b"GET" in c assert b"qvalue" in c - assert b"content" not in c + assert b"content-length" in c + assert b"foo" not in c def test_assemble_response(): diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py index 8a315508..90234070 100644 --- a/test/http/http1/test_read.py +++ b/test/http/http1/test_read.py @@ -159,10 +159,10 @@ def test_expected_http_body_size(): # no length assert expected_http_body_size( - treq() + treq(headers=Headers()) ) == 0 assert expected_http_body_size( - treq(), tresp() + treq(headers=Headers()), tresp(headers=Headers()) ) == -1 -- cgit v1.2.3 From 280b491ab2b743f75483e2916e5344b22d4136e1 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 31 Jan 2016 12:15:44 +0100 Subject: migrate to hyperframe --- netlib/http/http2/connections.py | 81 +++-- netlib/http/http2/frame.py | 651 ----------------------------------- netlib/utils.py | 21 +- setup.py | 7 +- test/http/http2/test_connections.py | 546 +++++++++++++++++++++++++++++ test/http/http2/test_frames.py | 669 ------------------------------------ test/http/http2/test_protocol.py | 538 ----------------------------- 7 files changed, 612 insertions(+), 1901 deletions(-) delete mode 100644 netlib/http/http2/frame.py create mode 100644 test/http/http2/test_connections.py delete mode 100644 test/http/http2/test_frames.py delete mode 100644 test/http/http2/test_protocol.py diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index c493abe6..c963f7c4 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -5,7 +5,8 @@ import time from hpack.hpack import Encoder, Decoder from ... import utils from .. import Headers, Response, Request -from . import frame + +from hyperframe import frame class TCPHandler(object): @@ -36,6 +37,15 @@ class HTTP2Protocol(object): CLIENT_CONNECTION_PREFACE = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + HTTP2_DEFAULT_SETTINGS = { + frame.SettingsFrame.HEADER_TABLE_SIZE: 4096, + frame.SettingsFrame.ENABLE_PUSH: 1, + frame.SettingsFrame.MAX_CONCURRENT_STREAMS: None, + frame.SettingsFrame.INITIAL_WINDOW_SIZE: 2 ** 16 - 1, + frame.SettingsFrame.MAX_FRAME_SIZE: 2 ** 14, + frame.SettingsFrame.MAX_HEADER_LIST_SIZE: None, + } + def __init__( self, tcp_handler=None, @@ -54,7 +64,7 @@ class HTTP2Protocol(object): self.decoder = decoder or Decoder() self.unhandled_frame_cb = unhandled_frame_cb - self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() + self.http2_settings = self.HTTP2_DEFAULT_SETTINGS.copy() self.current_stream_id = None self.connection_preface_performed = False @@ -240,9 +250,9 @@ class HTTP2Protocol(object): magic = self.tcp_handler.rfile.safe_read(magic_length) assert magic == self.CLIENT_CONNECTION_PREFACE - frm = frame.SettingsFrame(state=self, settings={ - frame.SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 0, - frame.SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 1, + frm = frame.SettingsFrame(settings={ + frame.SettingsFrame.ENABLE_PUSH: 0, + frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1, }) self.send_frame(frm, hide=True) self._receive_settings(hide=True) @@ -253,12 +263,12 @@ class HTTP2Protocol(object): self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) - self.send_frame(frame.SettingsFrame(state=self), hide=True) + self.send_frame(frame.SettingsFrame(), hide=True) self._receive_settings(hide=True) # server announces own settings self._receive_settings(hide=True) # server acks my settings def send_frame(self, frm, hide=False): - raw_bytes = frm.to_bytes() + raw_bytes = frm.serialize() self.tcp_handler.wfile.write(raw_bytes) self.tcp_handler.wfile.flush() if not hide and self.dump_frames: # pragma no cover @@ -266,19 +276,19 @@ class HTTP2Protocol(object): def read_frame(self, hide=False): while True: - frm = frame.Frame.from_file(self.tcp_handler.rfile, self) + frm = utils.http2_read_frame(self.tcp_handler.rfile) if not hide and self.dump_frames: # pragma no cover print(frm.human_readable("<<")) if isinstance(frm, frame.PingFrame): - raw_bytes = frame.PingFrame(flags=frame.Frame.FLAG_ACK, payload=frm.payload).to_bytes() + raw_bytes = frame.PingFrame(flags=['ACK'], payload=frm.payload).serialize() self.tcp_handler.wfile.write(raw_bytes) self.tcp_handler.wfile.flush() continue - if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: + if isinstance(frm, frame.SettingsFrame) and 'ACK' not in frm.flags: self._apply_settings(frm.settings, hide) - if isinstance(frm, frame.DataFrame) and frm.length > 0: - self._update_flow_control_window(frm.stream_id, frm.length) + if isinstance(frm, frame.DataFrame) and frm.flow_controlled_length > 0: + self._update_flow_control_window(frm.stream_id, frm.flow_controlled_length) return frm def check_alpn(self): @@ -321,15 +331,13 @@ class HTTP2Protocol(object): old_value = '-' self.http2_settings[setting] = value - frm = frame.SettingsFrame( - state=self, - flags=frame.Frame.FLAG_ACK) + frm = frame.SettingsFrame(flags=['ACK']) self.send_frame(frm, hide) def _update_flow_control_window(self, stream_id, increment): - frm = frame.WindowUpdateFrame(stream_id=0, window_size_increment=increment) + frm = frame.WindowUpdateFrame(stream_id=0, window_increment=increment) self.send_frame(frm) - frm = frame.WindowUpdateFrame(stream_id=stream_id, window_size_increment=increment) + frm = frame.WindowUpdateFrame(stream_id=stream_id, window_increment=increment) self.send_frame(frm) def _create_headers(self, headers, stream_id, end_stream=True): @@ -342,43 +350,40 @@ class HTTP2Protocol(object): header_block_fragment = self.encoder.encode(headers.fields) - chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + chunk_size = self.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE] chunks = range(0, len(header_block_fragment), chunk_size) frms = [frm_cls( - state=self, - flags=frame.Frame.FLAG_NO_FLAGS, + flags=[], stream_id=stream_id, - header_block_fragment=header_block_fragment[i:i+chunk_size]) for frm_cls, i in frame_cls(chunks)] + data=header_block_fragment[i:i+chunk_size]) for frm_cls, i in frame_cls(chunks)] - last_flags = frame.Frame.FLAG_END_HEADERS + frms[-1].flags.add('END_HEADERS') if end_stream: - last_flags |= frame.Frame.FLAG_END_STREAM - frms[-1].flags = last_flags + frms[0].flags.add('END_STREAM') if self.dump_frames: # pragma no cover for frm in frms: print(frm.human_readable(">>")) - return [frm.to_bytes() for frm in frms] + return [frm.serialize() for frm in frms] def _create_body(self, body, stream_id): if body is None or len(body) == 0: return b'' - chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + chunk_size = self.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE] chunks = range(0, len(body), chunk_size) frms = [frame.DataFrame( - state=self, - flags=frame.Frame.FLAG_NO_FLAGS, + flags=[], stream_id=stream_id, - payload=body[i:i+chunk_size]) for i in chunks] - frms[-1].flags = frame.Frame.FLAG_END_STREAM + data=body[i:i+chunk_size]) for i in chunks] + frms[-1].flags.add('END_STREAM') if self.dump_frames: # pragma no cover for frm in frms: print(frm.human_readable(">>")) - return [frm.to_bytes() for frm in frms] + return [frm.serialize() for frm in frms] def _receive_transmission(self, stream_id=None, include_body=True): if not include_body: @@ -386,7 +391,7 @@ class HTTP2Protocol(object): body_expected = True - header_block_fragment = b'' + header_blocks = b'' body = b'' while True: @@ -396,10 +401,10 @@ class HTTP2Protocol(object): (stream_id is None or frm.stream_id == stream_id) ): stream_id = frm.stream_id - header_block_fragment += frm.header_block_fragment - if frm.flags & frame.Frame.FLAG_END_STREAM: + header_blocks += frm.data + if 'END_STREAM' in frm.flags: body_expected = False - if frm.flags & frame.Frame.FLAG_END_HEADERS: + if 'END_HEADERS' in frm.flags: break else: self._handle_unexpected_frame(frm) @@ -407,14 +412,14 @@ class HTTP2Protocol(object): while body_expected: frm = self.read_frame() if isinstance(frm, frame.DataFrame) and frm.stream_id == stream_id: - body += frm.payload - if frm.flags & frame.Frame.FLAG_END_STREAM: + body += frm.data + if 'END_STREAM' in frm.flags: break else: self._handle_unexpected_frame(frm) headers = Headers( - [[str(k), str(v)] for k, v in self.decoder.decode(header_block_fragment)] + [[str(k), str(v)] for k, v in self.decoder.decode(header_blocks)] ) return stream_id, headers, body diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py deleted file mode 100644 index 188629d4..00000000 --- a/netlib/http/http2/frame.py +++ /dev/null @@ -1,651 +0,0 @@ -from __future__ import absolute_import, print_function, division -import struct -from hpack.hpack import Encoder, Decoder - -from ...utils import BiDi -from ...exceptions import HttpSyntaxException - - -ERROR_CODES = BiDi( - NO_ERROR=0x0, - PROTOCOL_ERROR=0x1, - INTERNAL_ERROR=0x2, - FLOW_CONTROL_ERROR=0x3, - SETTINGS_TIMEOUT=0x4, - STREAM_CLOSED=0x5, - FRAME_SIZE_ERROR=0x6, - REFUSED_STREAM=0x7, - CANCEL=0x8, - COMPRESSION_ERROR=0x9, - CONNECT_ERROR=0xa, - ENHANCE_YOUR_CALM=0xb, - INADEQUATE_SECURITY=0xc, - HTTP_1_1_REQUIRED=0xd -) - -CLIENT_CONNECTION_PREFACE = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - -class Frame(object): - - """ - Baseclass Frame - contains header - payload is defined in subclasses - """ - - FLAG_NO_FLAGS = 0x0 - FLAG_ACK = 0x1 - FLAG_END_STREAM = 0x1 - FLAG_END_HEADERS = 0x4 - FLAG_PADDED = 0x8 - FLAG_PRIORITY = 0x20 - - def __init__( - self, - state=None, - length=0, - flags=FLAG_NO_FLAGS, - stream_id=0x0): - valid_flags = 0 - for flag in self.VALID_FLAGS: - valid_flags |= flag - if flags | valid_flags != valid_flags: - raise ValueError('invalid flags detected.') - - if state is None: - class State(object): - pass - - state = State() - state.http2_settings = HTTP2_DEFAULT_SETTINGS.copy() - state.encoder = Encoder() - state.decoder = Decoder() - - self.state = state - - self.length = length - self.type = self.TYPE - self.flags = flags - self.stream_id = stream_id - - @classmethod - def _check_frame_size(cls, length, state): - if state: - settings = state.http2_settings - else: - settings = HTTP2_DEFAULT_SETTINGS.copy() - - max_frame_size = settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] - - if length > max_frame_size: - raise HttpSyntaxException( - "Frame size exceeded: %d, but only %d allowed." % ( - length, max_frame_size)) - - @classmethod - def from_file(cls, fp, state=None): - """ - read a HTTP/2 frame sent by a server or client - fp is a "file like" object that could be backed by a network - stream or a disk or an in memory stream reader - """ - raw_header = fp.safe_read(9) - - fields = struct.unpack("!HBBBL", raw_header) - length = (fields[0] << 8) + fields[1] - flags = fields[3] - stream_id = fields[4] - - if raw_header[:4] == b'HTTP': # pragma no cover - raise HttpSyntaxException("Expected HTTP2 Frame, got HTTP/1 connection") - - cls._check_frame_size(length, state) - - payload = fp.safe_read(length) - return FRAMES[fields[2]].from_bytes( - state, - length, - flags, - stream_id, - payload) - - def to_bytes(self): - payload = self.payload_bytes() - self.length = len(payload) - - self._check_frame_size(self.length, self.state) - - b = struct.pack('!HB', (self.length & 0xFFFF00) >> 8, self.length & 0x0000FF) - b += struct.pack('!B', self.TYPE) - b += struct.pack('!B', self.flags) - b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) - b += payload - - return b - - def payload_bytes(self): # pragma: no cover - raise NotImplementedError() - - def payload_human_readable(self): # pragma: no cover - raise NotImplementedError() - - def human_readable(self, direction="-"): - self.length = len(self.payload_bytes()) - - return "\n".join([ - "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( - direction, self.__class__.__name__, self.length, self.flags, self.stream_id), - self.payload_human_readable(), - "===============================================================", - ]) - - def __eq__(self, other): - return self.to_bytes() == other.to_bytes() - - -class DataFrame(Frame): - TYPE = 0x0 - VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b'', - pad_length=0): - super(DataFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - self.pad_length = pad_length - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.payload = payload[1:-f.pad_length] - else: - f.payload = payload - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('DATA frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += bytes(self.payload) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - return "payload: %s" % str(self.payload) - - -class HeadersFrame(Frame): - TYPE = 0x1 - VALID_FLAGS = [ - Frame.FLAG_END_STREAM, - Frame.FLAG_END_HEADERS, - Frame.FLAG_PADDED, - Frame.FLAG_PRIORITY] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b'', - pad_length=0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(HeadersFrame, self).__init__(state, length, flags, stream_id) - - self.header_block_fragment = header_block_fragment - self.pad_length = pad_length - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.header_block_fragment = payload[1:-f.pad_length] - else: - f.header_block_fragment = payload[0:] - - if f.flags & Frame.FLAG_PRIORITY: - f.stream_dependency, f.weight = struct.unpack( - '!LB', f.header_block_fragment[:5]) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - f.header_block_fragment = f.header_block_fragment[5:] - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('HEADERS frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - if self.flags & self.FLAG_PRIORITY: - b += struct.pack('!LB', - (int(self.exclusive) << 31) | self.stream_dependency, - self.weight) - - b += self.header_block_fragment - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PRIORITY: - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PriorityFrame(Frame): - TYPE = 0x2 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(PriorityFrame, self).__init__(state, length, flags, stream_id) - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.stream_dependency, f.weight = struct.unpack('!LB', payload) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PRIORITY frames MUST be associated with a stream.') - - return struct.pack( - '!LB', - (int( - self.exclusive) << 31) | self.stream_dependency, - self.weight) - - def payload_human_readable(self): - s = [] - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - return "\n".join(s) - - -class RstStreamFrame(Frame): - TYPE = 0x3 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - error_code=0x0): - super(RstStreamFrame, self).__init__(state, length, flags, stream_id) - self.error_code = error_code - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.error_code = struct.unpack('!L', payload)[0] - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'RST_STREAM frames MUST be associated with a stream.') - - return struct.pack('!L', self.error_code) - - def payload_human_readable(self): - return "error code: %#x" % self.error_code - - -class SettingsFrame(Frame): - TYPE = 0x4 - VALID_FLAGS = [Frame.FLAG_ACK] - - SETTINGS = BiDi( - SETTINGS_HEADER_TABLE_SIZE=0x1, - SETTINGS_ENABLE_PUSH=0x2, - SETTINGS_MAX_CONCURRENT_STREAMS=0x3, - SETTINGS_INITIAL_WINDOW_SIZE=0x4, - SETTINGS_MAX_FRAME_SIZE=0x5, - SETTINGS_MAX_HEADER_LIST_SIZE=0x6, - ) - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings=None): - super(SettingsFrame, self).__init__(state, length, flags, stream_id) - - if settings is None: - settings = {} - - self.settings = settings - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - for i in range(0, len(payload), 6): - identifier, value = struct.unpack("!HL", payload[i:i + 6]) - f.settings[identifier] = value - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'SETTINGS frames MUST NOT be associated with a stream.') - - b = b'' - for identifier, value in self.settings.items(): - b += struct.pack("!HL", identifier & 0xFF, value) - - return b - - def payload_human_readable(self): - s = [] - - for identifier, value in self.settings.items(): - s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) - - if not s: - return "settings: None" - else: - return "\n".join(s) - - -class PushPromiseFrame(Frame): - TYPE = 0x5 - VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - promised_stream=0x0, - header_block_fragment=b'', - pad_length=0): - super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) - self.pad_length = pad_length - self.promised_stream = promised_stream - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) - f.header_block_fragment = payload[5:-f.pad_length] - else: - f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) - f.header_block_fragment = payload[4:] - - f.promised_stream &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PUSH_PROMISE frames MUST be associated with a stream.') - - if self.promised_stream == 0x0: - raise ValueError('Promised stream id not valid.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) - b += bytes(self.header_block_fragment) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append("promised stream: %#x" % self.promised_stream) - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PingFrame(Frame): - TYPE = 0x6 - VALID_FLAGS = [Frame.FLAG_ACK] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b''): - super(PingFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.payload = payload - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'PING frames MUST NOT be associated with a stream.') - - b = self.payload[0:8] - b += b'\0' * (8 - len(b)) - return b - - def payload_human_readable(self): - return "opaque data: %s" % str(self.payload) - - -class GoAwayFrame(Frame): - TYPE = 0x7 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x0, - error_code=0x0, - data=b''): - super(GoAwayFrame, self).__init__(state, length, flags, stream_id) - self.last_stream = last_stream - self.error_code = error_code - self.data = data - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) - f.last_stream &= 0x7FFFFFFF - f.data = payload[8:] - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'GOAWAY frames MUST NOT be associated with a stream.') - - b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) - b += bytes(self.data) - return b - - def payload_human_readable(self): - s = [] - s.append("last stream: %#x" % self.last_stream) - s.append("error code: %d" % self.error_code) - s.append("debug data: %s" % str(self.data)) - return "\n".join(s) - - -class WindowUpdateFrame(Frame): - TYPE = 0x8 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0x0): - super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) - self.window_size_increment = window_size_increment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.window_size_increment = struct.unpack("!L", payload)[0] - f.window_size_increment &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: - raise ValueError( - 'Window Size Increment MUST be greater than 0 and less than 2^31.') - - return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) - - def payload_human_readable(self): - return "window size increment: %#x" % self.window_size_increment - - -class ContinuationFrame(Frame): - TYPE = 0x9 - VALID_FLAGS = [Frame.FLAG_END_HEADERS] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b''): - super(ContinuationFrame, self).__init__(state, length, flags, stream_id) - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.header_block_fragment = payload - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'CONTINUATION frames MUST be associated with a stream.') - - return self.header_block_fragment - - def payload_human_readable(self): - s = [] - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - return "\n".join(s) - -_FRAME_CLASSES = [ - DataFrame, - HeadersFrame, - PriorityFrame, - RstStreamFrame, - SettingsFrame, - PushPromiseFrame, - PingFrame, - GoAwayFrame, - WindowUpdateFrame, - ContinuationFrame -] -FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} - - -HTTP2_DEFAULT_SETTINGS = { - SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, - SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, -} diff --git a/netlib/utils.py b/netlib/utils.py index 66225897..c537754a 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -2,12 +2,12 @@ from __future__ import absolute_import, print_function, division import os.path import re import string +import codecs import unicodedata - import six from six.moves import urllib - +import hyperframe def always_bytes(unicode_or_bytes, *encode_args): if isinstance(unicode_or_bytes, six.text_type): @@ -366,3 +366,20 @@ def multipartdecode(headers, content): r.append((key, value)) return r return [] + + +def http2_read_raw_frame(rfile): + field = rfile.peek(3) + length = int(codecs.encode(field, 'hex_codec'), 16) + + if length == 4740180: + raise ValueError("Length field looks more like HTTP/1.1: %s" % rfile.peek(20)) + + raw = rfile.safe_read(9 + length) + return raw + +def http2_read_frame(rfile): + raw = http2_read_raw_frame(rfile) + frame, length = hyperframe.frame.Frame.parse_frame_header(raw[:9]) + frame.parse_body(memoryview(raw[9:])) + return frame diff --git a/setup.py b/setup.py index 2da10baa..ba723bfa 100644 --- a/setup.py +++ b/setup.py @@ -17,11 +17,12 @@ with open(os.path.join(here, 'README.rst'), encoding='utf-8') as f: deps = { "pyasn1>=0.1.9, <0.2", "pyOpenSSL>=0.15.1, <0.16", - "cryptography>=1.2.1, <1.3", + "cryptography>=1.2.2, <1.3", "passlib>=1.6.5, <1.7", - "hpack>=2.0.1, <2.1", + "hpack>=2.0.1, <3.0", + "hyperframe>=3.1.1, <4.0", "six>=1.10.0, <1.11", - "certifi>=2015.9.6.2", # no semver here - this should always be on the last release! + "certifi>=2015.11.20.1", # no semver here - this should always be on the last release! "backports.ssl_match_hostname>=3.5.0.1, <3.6", } if sys.version_info < (3, 0): diff --git a/test/http/http2/test_connections.py b/test/http/http2/test_connections.py new file mode 100644 index 00000000..edbd4f8b --- /dev/null +++ b/test/http/http2/test_connections.py @@ -0,0 +1,546 @@ +import OpenSSL +import mock +import codecs + +from hyperframe.frame import * + +from netlib import tcp, http, utils, tutils, tservers +from netlib.exceptions import TcpDisconnect +from netlib.http import Headers +from netlib.http.http2.connections import HTTP2Protocol, TCPHandler + +class TestTCPHandlerWrapper: + def test_wrapped(self): + h = TCPHandler(rfile='foo', wfile='bar') + p = HTTP2Protocol(h) + assert p.tcp_handler.rfile == 'foo' + assert p.tcp_handler.wfile == 'bar' + + def test_direct(self): + p = HTTP2Protocol(rfile='foo', wfile='bar') + assert isinstance(p.tcp_handler, TCPHandler) + assert p.tcp_handler.rfile == 'foo' + assert p.tcp_handler.wfile == 'bar' + + +class EchoHandler(tcp.BaseHandler): + sni = None + + def handle(self): + while True: + v = self.rfile.safe_read(1) + self.wfile.write(v) + self.wfile.flush() + + +class TestProtocol: + @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface") + @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface") + def test_perform_connection_preface(self, mock_client_method, mock_server_method): + protocol = HTTP2Protocol(is_server=False) + protocol.connection_preface_performed = True + + protocol.perform_connection_preface() + assert not mock_client_method.called + assert not mock_server_method.called + + protocol.perform_connection_preface(force=True) + assert mock_client_method.called + assert not mock_server_method.called + + @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface") + @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface") + def test_perform_connection_preface_server(self, mock_client_method, mock_server_method): + protocol = HTTP2Protocol(is_server=True) + protocol.connection_preface_performed = True + + protocol.perform_connection_preface() + assert not mock_client_method.called + assert not mock_server_method.called + + protocol.perform_connection_preface(force=True) + assert not mock_client_method.called + assert mock_server_method.called + + +class TestCheckALPNMatch(tservers.ServerTestBase): + handler = EchoHandler + ssl = dict( + alpn_select=b'h2', + ) + + if OpenSSL._util.lib.Cryptography_HAS_ALPN: + + def test_check_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl(alpn_protos=[b'h2']) + protocol = HTTP2Protocol(c) + assert protocol.check_alpn() + + +class TestCheckALPNMismatch(tservers.ServerTestBase): + handler = EchoHandler + ssl = dict( + alpn_select=None, + ) + + if OpenSSL._util.lib.Cryptography_HAS_ALPN: + + def test_check_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl(alpn_protos=[b'h2']) + protocol = HTTP2Protocol(c) + tutils.raises(NotImplementedError, protocol.check_alpn) + + +class TestPerformServerConnectionPreface(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + try: + # send magic + self.wfile.write(codecs.decode('505249202a20485454502f322e300d0a0d0a534d0d0a0d0a', 'hex_codec')) + self.wfile.flush() + + # send empty settings frame + self.wfile.write(codecs.decode('000000040000000000', 'hex_codec')) + self.wfile.flush() + + # check empty settings frame + raw = utils.http2_read_raw_frame(self.rfile) + assert raw == codecs.decode('00000c040000000000000200000000000300000001', 'hex_codec') + + # check settings acknowledgement + raw = utils.http2_read_raw_frame(self.rfile) + assert raw == codecs.decode('000000040100000000', 'hex_codec') + + # send settings acknowledgement + self.wfile.write(codecs.decode('000000040100000000', 'hex_codec')) + self.wfile.flush() + except Exception as e: + print(e) + + def test_perform_server_connection_preface(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + protocol = HTTP2Protocol(c) + + assert not protocol.connection_preface_performed + protocol.perform_server_connection_preface() + assert protocol.connection_preface_performed + + frm = protocol.read_frame() + assert isinstance(frm, SettingsFrame) + assert 'ACK' in frm.flags + + tutils.raises(TcpDisconnect, protocol.perform_server_connection_preface, force=True) + + +class TestPerformClientConnectionPreface(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + # check magic + assert self.rfile.read(24) ==\ + '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') + + # check empty settings frame + assert self.rfile.read(9) ==\ + '000000040000000000'.decode('hex') + + # send empty settings frame + self.wfile.write('000000040000000000'.decode('hex')) + self.wfile.flush() + + # check settings acknowledgement + assert self.rfile.read(9) == \ + '000000040100000000'.decode('hex') + + # send settings acknowledgement + self.wfile.write('000000040100000000'.decode('hex')) + self.wfile.flush() + + def test_perform_client_connection_preface(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + protocol = HTTP2Protocol(c) + + assert not protocol.connection_preface_performed + protocol.perform_client_connection_preface() + assert protocol.connection_preface_performed + + +class TestClientStreamIds(): + c = tcp.TCPClient(("127.0.0.1", 0)) + protocol = HTTP2Protocol(c) + + def test_client_stream_ids(self): + assert self.protocol.current_stream_id is None + assert self.protocol._next_stream_id() == 1 + assert self.protocol.current_stream_id == 1 + assert self.protocol._next_stream_id() == 3 + assert self.protocol.current_stream_id == 3 + assert self.protocol._next_stream_id() == 5 + assert self.protocol.current_stream_id == 5 + + +class TestServerStreamIds(): + c = tcp.TCPClient(("127.0.0.1", 0)) + protocol = HTTP2Protocol(c, is_server=True) + + def test_server_stream_ids(self): + assert self.protocol.current_stream_id is None + assert self.protocol._next_stream_id() == 2 + assert self.protocol.current_stream_id == 2 + assert self.protocol._next_stream_id() == 4 + assert self.protocol.current_stream_id == 4 + assert self.protocol._next_stream_id() == 6 + assert self.protocol.current_stream_id == 6 + + +class TestApplySettings(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + # check settings acknowledgement + assert self.rfile.read(9) == '000000040100000000'.decode('hex') + self.wfile.write("OK") + self.wfile.flush() + self.rfile.safe_read(9) # just to keep the connection alive a bit longer + + ssl = True + + def test_apply_settings(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = HTTP2Protocol(c) + + protocol._apply_settings({ + SettingsFrame.ENABLE_PUSH: 'foo', + SettingsFrame.MAX_CONCURRENT_STREAMS: 'bar', + SettingsFrame.INITIAL_WINDOW_SIZE: 'deadbeef', + }) + + assert c.rfile.safe_read(2) == "OK" + + assert protocol.http2_settings[ + SettingsFrame.ENABLE_PUSH] == 'foo' + assert protocol.http2_settings[ + SettingsFrame.MAX_CONCURRENT_STREAMS] == 'bar' + assert protocol.http2_settings[ + SettingsFrame.INITIAL_WINDOW_SIZE] == 'deadbeef' + + +class TestCreateHeaders(): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_create_headers(self): + headers = http.Headers([ + (b':method', b'GET'), + (b':path', b'index.html'), + (b':scheme', b'https'), + (b'foo', b'bar')]) + + bytes = HTTP2Protocol(self.c)._create_headers( + headers, 1, end_stream=True) + assert b''.join(bytes) ==\ + '000014010500000001824488355217caf3a69a3f87408294e7838c767f'\ + .decode('hex') + + bytes = HTTP2Protocol(self.c)._create_headers( + headers, 1, end_stream=False) + assert b''.join(bytes) ==\ + '000014010400000001824488355217caf3a69a3f87408294e7838c767f'\ + .decode('hex') + + def test_create_headers_multiple_frames(self): + headers = http.Headers([ + (b':method', b'GET'), + (b':path', b'/'), + (b':scheme', b'https'), + (b'foo', b'bar'), + (b'server', b'version')]) + + protocol = HTTP2Protocol(self.c) + protocol.http2_settings[SettingsFrame.MAX_FRAME_SIZE] = 8 + bytes = protocol._create_headers(headers, 1, end_stream=True) + assert len(bytes) == 3 + assert bytes[0] == '000008010100000001828487408294e783'.decode('hex') + assert bytes[1] == '0000080900000000018c767f7685ee5b10'.decode('hex') + assert bytes[2] == '00000209040000000163d5'.decode('hex') + + +class TestCreateBody(): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_create_body_empty(self): + protocol = HTTP2Protocol(self.c) + bytes = protocol._create_body(b'', 1) + assert b''.join(bytes) == ''.decode('hex') + + def test_create_body_single_frame(self): + protocol = HTTP2Protocol(self.c) + bytes = protocol._create_body(b'foobar', 1) + assert b''.join(bytes) == '000006000100000001666f6f626172'.decode('hex') + + def test_create_body_multiple_frames(self): + protocol = HTTP2Protocol(self.c) + protocol.http2_settings[SettingsFrame.MAX_FRAME_SIZE] = 5 + bytes = protocol._create_body(b'foobarmehm42', 1) + assert len(bytes) == 3 + assert bytes[0] == '000005000000000001666f6f6261'.decode('hex') + assert bytes[1] == '000005000000000001726d65686d'.decode('hex') + assert bytes[2] == '0000020001000000013432'.decode('hex') + + +class TestReadRequest(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + self.wfile.write( + b'000003010400000001828487'.decode('hex')) + self.wfile.write( + b'000006000100000001666f6f626172'.decode('hex')) + self.wfile.flush() + self.rfile.safe_read(9) # just to keep the connection alive a bit longer + + ssl = True + + def test_read_request(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = HTTP2Protocol(c, is_server=True) + protocol.connection_preface_performed = True + + req = protocol.read_request(NotImplemented) + + assert req.stream_id + assert req.headers.fields == [[':method', 'GET'], [':path', '/'], [':scheme', 'https']] + assert req.content == b'foobar' + + +class TestReadRequestRelative(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + self.wfile.write( + b'00000c0105000000014287d5af7e4d5a777f4481f9'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_asterisk_form_in(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = HTTP2Protocol(c, is_server=True) + protocol.connection_preface_performed = True + + req = protocol.read_request(NotImplemented) + + assert req.form_in == "relative" + assert req.method == "OPTIONS" + assert req.path == "*" + + +class TestReadRequestAbsolute(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + self.wfile.write( + b'00001901050000000182448d9d29aee30c0e492c2a1170426366871c92585422e085'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_absolute_form_in(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = HTTP2Protocol(c, is_server=True) + protocol.connection_preface_performed = True + + req = protocol.read_request(NotImplemented) + + assert req.form_in == "absolute" + assert req.scheme == "http" + assert req.host == "address" + assert req.port == 22 + + +class TestReadRequestConnect(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + self.wfile.write( + b'00001b0105000000014287bdab4e9c17b7ff44871c92585422e08541871c92585422e085'.decode('hex')) + self.wfile.write( + b'00001d0105000000014287bdab4e9c17b7ff44882f91d35d055c87a741882f91d35d055c87a7'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_connect(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = HTTP2Protocol(c, is_server=True) + protocol.connection_preface_performed = True + + req = protocol.read_request(NotImplemented) + assert req.form_in == "authority" + assert req.method == "CONNECT" + assert req.host == "address" + assert req.port == 22 + + req = protocol.read_request(NotImplemented) + assert req.form_in == "authority" + assert req.method == "CONNECT" + assert req.host == "example.com" + assert req.port == 443 + + +class TestReadResponse(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + self.wfile.write( + b'00000801040000002a88628594e78c767f'.decode('hex')) + self.wfile.write( + b'00000600010000002a666f6f626172'.decode('hex')) + self.wfile.flush() + self.rfile.safe_read(9) # just to keep the connection alive a bit longer + + ssl = True + + def test_read_response(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = HTTP2Protocol(c) + protocol.connection_preface_performed = True + + resp = protocol.read_response(NotImplemented, stream_id=42) + + assert resp.http_version == (2, 0) + assert resp.status_code == 200 + assert resp.msg == "" + assert resp.headers.fields == [[':status', '200'], ['etag', 'foobar']] + assert resp.content == b'foobar' + assert resp.timestamp_end + + +class TestReadEmptyResponse(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + def handle(self): + self.wfile.write( + b'00000801050000002a88628594e78c767f'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_read_empty_response(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = HTTP2Protocol(c) + protocol.connection_preface_performed = True + + resp = protocol.read_response(NotImplemented, stream_id=42) + + assert resp.stream_id == 42 + assert resp.http_version == (2, 0) + assert resp.status_code == 200 + assert resp.msg == b'' + assert resp.headers.fields == [[':status', '200'], ['etag', 'foobar']] + assert resp.content == b'' + + +class TestAssembleRequest(object): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_request_simple(self): + bytes = HTTP2Protocol(self.c).assemble_request(http.Request( + '', + 'GET', + 'https', + '', + '', + '/', + (2, 0), + None, + None, + )) + assert len(bytes) == 1 + assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') + + def test_request_with_stream_id(self): + req = http.Request( + '', + 'GET', + 'https', + '', + '', + '/', + (2, 0), + None, + None, + ) + req.stream_id = 0x42 + bytes = HTTP2Protocol(self.c).assemble_request(req) + assert len(bytes) == 1 + assert bytes[0] == '00000d0105000000428284874188089d5c0b8170dc07'.decode('hex') + + def test_request_with_body(self): + bytes = HTTP2Protocol(self.c).assemble_request(http.Request( + '', + 'GET', + 'https', + '', + '', + '/', + (2, 0), + http.Headers([('foo', 'bar')]), + 'foobar', + )) + assert len(bytes) == 2 + assert bytes[0] ==\ + '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') + assert bytes[1] ==\ + '000006000100000001666f6f626172'.decode('hex') + + +class TestAssembleResponse(object): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_simple(self): + bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( + (2, 0), + 200, + )) + assert len(bytes) == 1 + assert bytes[0] ==\ + '00000101050000000288'.decode('hex') + + def test_with_stream_id(self): + resp = http.Response( + (2, 0), + 200, + ) + resp.stream_id = 0x42 + bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(resp) + assert len(bytes) == 1 + assert bytes[0] ==\ + '00000101050000004288'.decode('hex') + + def test_with_body(self): + bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( + (2, 0), + 200, + '', + Headers(foo="bar"), + 'foobar' + )) + assert len(bytes) == 2 + assert bytes[0] ==\ + '00000901040000000288408294e7838c767f'.decode('hex') + assert bytes[1] ==\ + '000006000100000002666f6f626172'.decode('hex') diff --git a/test/http/http2/test_frames.py b/test/http/http2/test_frames.py deleted file mode 100644 index 9f41c74d..00000000 --- a/test/http/http2/test_frames.py +++ /dev/null @@ -1,669 +0,0 @@ -from io import BytesIO - -from netlib import tcp, tutils -from netlib.http.http2.frame import * - - -def hex_to_file(data): - data = data.decode('hex') - return tcp.Reader(BytesIO(data)) - - -def test_invalid_flags(): - tutils.raises( - ValueError, - DataFrame, - flags=ContinuationFrame.FLAG_END_HEADERS, - stream_id=0x1234567, - payload='foobar') - - -def test_frame_equality(): - a = DataFrame( - length=6, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567, - payload='foobar') - b = DataFrame( - length=6, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567, - payload='foobar') - assert a == b - - -def test_too_large_frames(): - f = DataFrame( - length=9000, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567, - payload='foobar' * 3000) - tutils.raises(HttpSyntaxException, f.to_bytes) - - -def test_data_frame_to_bytes(): - f = DataFrame( - length=6, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567, - payload='foobar') - assert f.to_bytes().encode('hex') == '000006000101234567666f6f626172' - - f = DataFrame( - length=11, - flags=(Frame.FLAG_END_STREAM | Frame.FLAG_PADDED), - stream_id=0x1234567, - payload='foobar', - pad_length=3) - assert f.to_bytes().encode('hex') == '00000a00090123456703666f6f626172000000' - - f = DataFrame( - length=6, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload='foobar') - tutils.raises(ValueError, f.to_bytes) - - -def test_data_frame_from_bytes(): - f = Frame.from_file(hex_to_file('000006000101234567666f6f626172')) - assert isinstance(f, DataFrame) - assert f.length == 6 - assert f.TYPE == DataFrame.TYPE - assert f.flags == Frame.FLAG_END_STREAM - assert f.stream_id == 0x1234567 - assert f.payload == 'foobar' - - f = Frame.from_file(hex_to_file('00000a00090123456703666f6f626172000000')) - assert isinstance(f, DataFrame) - assert f.length == 10 - assert f.TYPE == DataFrame.TYPE - assert f.flags == Frame.FLAG_END_STREAM | Frame.FLAG_PADDED - assert f.stream_id == 0x1234567 - assert f.payload == 'foobar' - - -def test_data_frame_human_readable(): - f = DataFrame( - length=11, - flags=(Frame.FLAG_END_STREAM | Frame.FLAG_PADDED), - stream_id=0x1234567, - payload='foobar', - pad_length=3) - assert f.human_readable() - - -def test_headers_frame_to_bytes(): - f = HeadersFrame( - length=6, - flags=(Frame.FLAG_NO_FLAGS), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex')) - assert f.to_bytes().encode('hex') == '000007010001234567668594e75e31d9' - - f = HeadersFrame( - length=10, - flags=(HeadersFrame.FLAG_PADDED), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - pad_length=3) - assert f.to_bytes().encode('hex') == '00000b01080123456703668594e75e31d9000000' - - f = HeadersFrame( - length=10, - flags=(HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - exclusive=True, - stream_dependency=0x7654321, - weight=42) - assert f.to_bytes().encode('hex') == '00000c012001234567876543212a668594e75e31d9' - - f = HeadersFrame( - length=14, - flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - pad_length=3, - exclusive=True, - stream_dependency=0x7654321, - weight=42) - assert f.to_bytes().encode('hex') == '00001001280123456703876543212a668594e75e31d9000000' - - f = HeadersFrame( - length=14, - flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - pad_length=3, - exclusive=False, - stream_dependency=0x7654321, - weight=42) - assert f.to_bytes().encode('hex') == '00001001280123456703076543212a668594e75e31d9000000' - - f = HeadersFrame( - length=6, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment='668594e75e31d9'.decode('hex')) - tutils.raises(ValueError, f.to_bytes) - - -def test_headers_frame_from_bytes(): - f = Frame.from_file(hex_to_file( - '000007010001234567668594e75e31d9')) - assert isinstance(f, HeadersFrame) - assert f.length == 7 - assert f.TYPE == HeadersFrame.TYPE - assert f.flags == Frame.FLAG_NO_FLAGS - assert f.stream_id == 0x1234567 - assert f.header_block_fragment == '668594e75e31d9'.decode('hex') - - f = Frame.from_file(hex_to_file( - '00000b01080123456703668594e75e31d9000000')) - assert isinstance(f, HeadersFrame) - assert f.length == 11 - assert f.TYPE == HeadersFrame.TYPE - assert f.flags == HeadersFrame.FLAG_PADDED - assert f.stream_id == 0x1234567 - assert f.header_block_fragment == '668594e75e31d9'.decode('hex') - - f = Frame.from_file(hex_to_file( - '00000c012001234567876543212a668594e75e31d9')) - assert isinstance(f, HeadersFrame) - assert f.length == 12 - assert f.TYPE == HeadersFrame.TYPE - assert f.flags == HeadersFrame.FLAG_PRIORITY - assert f.stream_id == 0x1234567 - assert f.header_block_fragment == '668594e75e31d9'.decode('hex') - assert f.exclusive == True - assert f.stream_dependency == 0x7654321 - assert f.weight == 42 - - f = Frame.from_file(hex_to_file( - '00001001280123456703876543212a668594e75e31d9000000')) - assert isinstance(f, HeadersFrame) - assert f.length == 16 - assert f.TYPE == HeadersFrame.TYPE - assert f.flags == HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY - assert f.stream_id == 0x1234567 - assert f.header_block_fragment == '668594e75e31d9'.decode('hex') - assert f.exclusive == True - assert f.stream_dependency == 0x7654321 - assert f.weight == 42 - - f = Frame.from_file(hex_to_file( - '00001001280123456703076543212a668594e75e31d9000000')) - assert isinstance(f, HeadersFrame) - assert f.length == 16 - assert f.TYPE == HeadersFrame.TYPE - assert f.flags == HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY - assert f.stream_id == 0x1234567 - assert f.header_block_fragment == '668594e75e31d9'.decode('hex') - assert f.exclusive == False - assert f.stream_dependency == 0x7654321 - assert f.weight == 42 - - -def test_headers_frame_human_readable(): - f = HeadersFrame( - length=7, - flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment=b'', - pad_length=3, - exclusive=False, - stream_dependency=0x7654321, - weight=42) - assert f.human_readable() - - f = HeadersFrame( - length=14, - flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - pad_length=3, - exclusive=False, - stream_dependency=0x7654321, - weight=42) - assert f.human_readable() - - -def test_priority_frame_to_bytes(): - f = PriorityFrame( - length=5, - flags=(Frame.FLAG_NO_FLAGS), - stream_id=0x1234567, - exclusive=True, - stream_dependency=0x0, - weight=42) - assert f.to_bytes().encode('hex') == '000005020001234567800000002a' - - f = PriorityFrame( - length=5, - flags=(Frame.FLAG_NO_FLAGS), - stream_id=0x1234567, - exclusive=False, - stream_dependency=0x7654321, - weight=21) - assert f.to_bytes().encode('hex') == '0000050200012345670765432115' - - f = PriorityFrame( - length=5, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - stream_dependency=0x1234567) - tutils.raises(ValueError, f.to_bytes) - - -def test_priority_frame_from_bytes(): - f = Frame.from_file(hex_to_file('000005020001234567876543212a')) - assert isinstance(f, PriorityFrame) - assert f.length == 5 - assert f.TYPE == PriorityFrame.TYPE - assert f.flags == Frame.FLAG_NO_FLAGS - assert f.stream_id == 0x1234567 - assert f.exclusive == True - assert f.stream_dependency == 0x7654321 - assert f.weight == 42 - - f = Frame.from_file(hex_to_file('0000050200012345670765432115')) - assert isinstance(f, PriorityFrame) - assert f.length == 5 - assert f.TYPE == PriorityFrame.TYPE - assert f.flags == Frame.FLAG_NO_FLAGS - assert f.stream_id == 0x1234567 - assert f.exclusive == False - assert f.stream_dependency == 0x7654321 - assert f.weight == 21 - - -def test_priority_frame_human_readable(): - f = PriorityFrame( - length=5, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - exclusive=False, - stream_dependency=0x7654321, - weight=21) - assert f.human_readable() - - -def test_rst_stream_frame_to_bytes(): - f = RstStreamFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - error_code=0x7654321) - assert f.to_bytes().encode('hex') == '00000403000123456707654321' - - f = RstStreamFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0) - tutils.raises(ValueError, f.to_bytes) - - -def test_rst_stream_frame_from_bytes(): - f = Frame.from_file(hex_to_file('00000403000123456707654321')) - assert isinstance(f, RstStreamFrame) - assert f.length == 4 - assert f.TYPE == RstStreamFrame.TYPE - assert f.flags == Frame.FLAG_NO_FLAGS - assert f.stream_id == 0x1234567 - assert f.error_code == 0x07654321 - - -def test_rst_stream_frame_human_readable(): - f = RstStreamFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - error_code=0x7654321) - assert f.human_readable() - - -def test_settings_frame_to_bytes(): - f = SettingsFrame( - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0) - assert f.to_bytes().encode('hex') == '000000040000000000' - - f = SettingsFrame( - length=0, - flags=SettingsFrame.FLAG_ACK, - stream_id=0x0) - assert f.to_bytes().encode('hex') == '000000040100000000' - - f = SettingsFrame( - length=6, - flags=SettingsFrame.FLAG_ACK, - stream_id=0x0, - settings={ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1}) - assert f.to_bytes().encode('hex') == '000006040100000000000200000001' - - f = SettingsFrame( - length=12, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings={ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) - assert f.to_bytes().encode('hex') == '00000c040000000000000200000001000312345678' - - f = SettingsFrame( - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567) - tutils.raises(ValueError, f.to_bytes) - - -def test_settings_frame_from_bytes(): - f = Frame.from_file(hex_to_file('000000040000000000')) - assert isinstance(f, SettingsFrame) - assert f.length == 0 - assert f.TYPE == SettingsFrame.TYPE - assert f.flags == Frame.FLAG_NO_FLAGS - assert f.stream_id == 0x0 - - f = Frame.from_file(hex_to_file('000000040100000000')) - assert isinstance(f, SettingsFrame) - assert f.length == 0 - assert f.TYPE == SettingsFrame.TYPE - assert f.flags == SettingsFrame.FLAG_ACK - assert f.stream_id == 0x0 - - f = Frame.from_file(hex_to_file('000006040100000000000200000001')) - assert isinstance(f, SettingsFrame) - assert f.length == 6 - assert f.TYPE == SettingsFrame.TYPE - assert f.flags == SettingsFrame.FLAG_ACK, 0x0 - assert f.stream_id == 0x0 - assert len(f.settings) == 1 - assert f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH] == 1 - - f = Frame.from_file(hex_to_file( - '00000c040000000000000200000001000312345678')) - assert isinstance(f, SettingsFrame) - assert f.length == 12 - assert f.TYPE == SettingsFrame.TYPE - assert f.flags == Frame.FLAG_NO_FLAGS - assert f.stream_id == 0x0 - assert len(f.settings) == 2 - assert f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH] == 1 - assert f.settings[SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS] == 0x12345678 - - -def test_settings_frame_human_readable(): - f = SettingsFrame( - length=12, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings={}) - assert f.human_readable() - - f = SettingsFrame( - length=12, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings={ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) - assert f.human_readable() - - -def test_push_promise_frame_to_bytes(): - f = PushPromiseFrame( - length=10, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - promised_stream=0x7654321, - header_block_fragment='foobar') - assert f.to_bytes().encode('hex') == '00000a05000123456707654321666f6f626172' - - f = PushPromiseFrame( - length=14, - flags=HeadersFrame.FLAG_PADDED, - stream_id=0x1234567, - promised_stream=0x7654321, - header_block_fragment='foobar', - pad_length=3) - assert f.to_bytes().encode('hex') == '00000e0508012345670307654321666f6f626172000000' - - f = PushPromiseFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - promised_stream=0x1234567) - tutils.raises(ValueError, f.to_bytes) - - f = PushPromiseFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - promised_stream=0x0) - tutils.raises(ValueError, f.to_bytes) - - -def test_push_promise_frame_from_bytes(): - f = Frame.from_file(hex_to_file('00000a05000123456707654321666f6f626172')) - assert isinstance(f, PushPromiseFrame) - assert f.length == 10 - assert f.TYPE == PushPromiseFrame.TYPE - assert f.flags == Frame.FLAG_NO_FLAGS - assert f.stream_id == 0x1234567 - assert f.header_block_fragment == 'foobar' - - f = Frame.from_file(hex_to_file( - '00000e0508012345670307654321666f6f626172000000')) - assert isinstance(f, PushPromiseFrame) - assert f.length == 14 - assert f.TYPE == PushPromiseFrame.TYPE - assert f.flags == PushPromiseFrame.FLAG_PADDED - assert f.stream_id == 0x1234567 - assert f.header_block_fragment == 'foobar' - - -def test_push_promise_frame_human_readable(): - f = PushPromiseFrame( - length=14, - flags=HeadersFrame.FLAG_PADDED, - stream_id=0x1234567, - promised_stream=0x7654321, - header_block_fragment='foobar', - pad_length=3) - assert f.human_readable() - - -def test_ping_frame_to_bytes(): - f = PingFrame( - length=8, - flags=PingFrame.FLAG_ACK, - stream_id=0x0, - payload=b'foobar') - assert f.to_bytes().encode('hex') == '000008060100000000666f6f6261720000' - - f = PingFrame( - length=8, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b'foobardeadbeef') - assert f.to_bytes().encode('hex') == '000008060000000000666f6f6261726465' - - f = PingFrame( - length=8, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567) - tutils.raises(ValueError, f.to_bytes) - - -def test_ping_frame_from_bytes(): - f = Frame.from_file(hex_to_file('000008060100000000666f6f6261720000')) - assert isinstance(f, PingFrame) - assert f.length == 8 - assert f.TYPE == PingFrame.TYPE - assert f.flags == PingFrame.FLAG_ACK - assert f.stream_id == 0x0 - assert f.payload == b'foobar\0\0' - - f = Frame.from_file(hex_to_file('000008060000000000666f6f6261726465')) - assert isinstance(f, PingFrame) - assert f.length == 8 - assert f.TYPE == PingFrame.TYPE - assert f.flags == Frame.FLAG_NO_FLAGS - assert f.stream_id == 0x0 - assert f.payload == b'foobarde' - - -def test_ping_frame_human_readable(): - f = PingFrame( - length=8, - flags=PingFrame.FLAG_ACK, - stream_id=0x0, - payload=b'foobar') - assert f.human_readable() - - -def test_goaway_frame_to_bytes(): - f = GoAwayFrame( - length=8, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x1234567, - error_code=0x87654321, - data=b'') - assert f.to_bytes().encode('hex') == '0000080700000000000123456787654321' - - f = GoAwayFrame( - length=14, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x1234567, - error_code=0x87654321, - data=b'foobar') - assert f.to_bytes().encode('hex') == '00000e0700000000000123456787654321666f6f626172' - - f = GoAwayFrame( - length=8, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - last_stream=0x1234567, - error_code=0x87654321) - tutils.raises(ValueError, f.to_bytes) - - -def test_goaway_frame_from_bytes(): - f = Frame.from_file(hex_to_file( - '0000080700000000000123456787654321')) - assert isinstance(f, GoAwayFrame) - assert f.length == 8 - assert f.TYPE == GoAwayFrame.TYPE - assert f.flags == Frame.FLAG_NO_FLAGS - assert f.stream_id == 0x0 - assert f.last_stream == 0x1234567 - assert f.error_code == 0x87654321 - assert f.data == b'' - - f = Frame.from_file(hex_to_file( - '00000e0700000000000123456787654321666f6f626172')) - assert isinstance(f, GoAwayFrame) - assert f.length == 14 - assert f.TYPE == GoAwayFrame.TYPE - assert f.flags == Frame.FLAG_NO_FLAGS - assert f.stream_id == 0x0 - assert f.last_stream == 0x1234567 - assert f.error_code == 0x87654321 - assert f.data == b'foobar' - - -def test_go_away_frame_human_readable(): - f = GoAwayFrame( - length=14, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x1234567, - error_code=0x87654321, - data=b'foobar') - assert f.human_readable() - - -def test_window_update_frame_to_bytes(): - f = WindowUpdateFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0x1234567) - assert f.to_bytes().encode('hex') == '00000408000000000001234567' - - f = WindowUpdateFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - window_size_increment=0x7654321) - assert f.to_bytes().encode('hex') == '00000408000123456707654321' - - f = WindowUpdateFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0xdeadbeef) - tutils.raises(ValueError, f.to_bytes) - - f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0) - tutils.raises(ValueError, f.to_bytes) - - -def test_window_update_frame_from_bytes(): - f = Frame.from_file(hex_to_file('00000408000000000001234567')) - assert isinstance(f, WindowUpdateFrame) - assert f.length == 4 - assert f.TYPE == WindowUpdateFrame.TYPE - assert f.flags == Frame.FLAG_NO_FLAGS - assert f.stream_id == 0x0 - assert f.window_size_increment == 0x1234567 - - -def test_window_update_frame_human_readable(): - f = WindowUpdateFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - window_size_increment=0x7654321) - assert f.human_readable() - - -def test_continuation_frame_to_bytes(): - f = ContinuationFrame( - length=6, - flags=ContinuationFrame.FLAG_END_HEADERS, - stream_id=0x1234567, - header_block_fragment='foobar') - assert f.to_bytes().encode('hex') == '000006090401234567666f6f626172' - - f = ContinuationFrame( - length=6, - flags=ContinuationFrame.FLAG_END_HEADERS, - stream_id=0x0, - header_block_fragment='foobar') - tutils.raises(ValueError, f.to_bytes) - - -def test_continuation_frame_from_bytes(): - f = Frame.from_file(hex_to_file('000006090401234567666f6f626172')) - assert isinstance(f, ContinuationFrame) - assert f.length == 6 - assert f.TYPE == ContinuationFrame.TYPE - assert f.flags == ContinuationFrame.FLAG_END_HEADERS - assert f.stream_id == 0x1234567 - assert f.header_block_fragment == 'foobar' - - -def test_continuation_frame_human_readable(): - f = ContinuationFrame( - length=6, - flags=ContinuationFrame.FLAG_END_HEADERS, - stream_id=0x1234567, - header_block_fragment='foobar') - assert f.human_readable() diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py deleted file mode 100644 index 0beec950..00000000 --- a/test/http/http2/test_protocol.py +++ /dev/null @@ -1,538 +0,0 @@ -import OpenSSL -import mock - -from netlib import tcp, http, tutils, tservers -from netlib.exceptions import TcpDisconnect -from netlib.http import Headers -from netlib.http.http2.connections import HTTP2Protocol, TCPHandler -from netlib.http.http2.frame import * - -class TestTCPHandlerWrapper: - def test_wrapped(self): - h = TCPHandler(rfile='foo', wfile='bar') - p = HTTP2Protocol(h) - assert p.tcp_handler.rfile == 'foo' - assert p.tcp_handler.wfile == 'bar' - - def test_direct(self): - p = HTTP2Protocol(rfile='foo', wfile='bar') - assert isinstance(p.tcp_handler, TCPHandler) - assert p.tcp_handler.rfile == 'foo' - assert p.tcp_handler.wfile == 'bar' - - -class EchoHandler(tcp.BaseHandler): - sni = None - - def handle(self): - while True: - v = self.rfile.safe_read(1) - self.wfile.write(v) - self.wfile.flush() - - -class TestProtocol: - @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface") - @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface") - def test_perform_connection_preface(self, mock_client_method, mock_server_method): - protocol = HTTP2Protocol(is_server=False) - protocol.connection_preface_performed = True - - protocol.perform_connection_preface() - assert not mock_client_method.called - assert not mock_server_method.called - - protocol.perform_connection_preface(force=True) - assert mock_client_method.called - assert not mock_server_method.called - - @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface") - @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface") - def test_perform_connection_preface_server(self, mock_client_method, mock_server_method): - protocol = HTTP2Protocol(is_server=True) - protocol.connection_preface_performed = True - - protocol.perform_connection_preface() - assert not mock_client_method.called - assert not mock_server_method.called - - protocol.perform_connection_preface(force=True) - assert not mock_client_method.called - assert mock_server_method.called - - -class TestCheckALPNMatch(tservers.ServerTestBase): - handler = EchoHandler - ssl = dict( - alpn_select=b'h2', - ) - - if OpenSSL._util.lib.Cryptography_HAS_ALPN: - - def test_check_alpn(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl(alpn_protos=[b'h2']) - protocol = HTTP2Protocol(c) - assert protocol.check_alpn() - - -class TestCheckALPNMismatch(tservers.ServerTestBase): - handler = EchoHandler - ssl = dict( - alpn_select=None, - ) - - if OpenSSL._util.lib.Cryptography_HAS_ALPN: - - def test_check_alpn(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl(alpn_protos=[b'h2']) - protocol = HTTP2Protocol(c) - tutils.raises(NotImplementedError, protocol.check_alpn) - - -class TestPerformServerConnectionPreface(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - # send magic - self.wfile.write( - '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex')) - self.wfile.flush() - - # send empty settings frame - self.wfile.write('000000040000000000'.decode('hex')) - self.wfile.flush() - - # check empty settings frame - assert self.rfile.read(9) ==\ - '000000040000000000'.decode('hex') - - # check settings acknowledgement - assert self.rfile.read(9) == \ - '000000040100000000'.decode('hex') - - # send settings acknowledgement - self.wfile.write('000000040100000000'.decode('hex')) - self.wfile.flush() - - def test_perform_server_connection_preface(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - protocol = HTTP2Protocol(c) - - assert not protocol.connection_preface_performed - protocol.perform_server_connection_preface() - assert protocol.connection_preface_performed - - tutils.raises(TcpDisconnect, protocol.perform_server_connection_preface, force=True) - - -class TestPerformClientConnectionPreface(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - # check magic - assert self.rfile.read(24) ==\ - '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') - - # check empty settings frame - assert self.rfile.read(9) ==\ - '000000040000000000'.decode('hex') - - # send empty settings frame - self.wfile.write('000000040000000000'.decode('hex')) - self.wfile.flush() - - # check settings acknowledgement - assert self.rfile.read(9) == \ - '000000040100000000'.decode('hex') - - # send settings acknowledgement - self.wfile.write('000000040100000000'.decode('hex')) - self.wfile.flush() - - def test_perform_client_connection_preface(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - protocol = HTTP2Protocol(c) - - assert not protocol.connection_preface_performed - protocol.perform_client_connection_preface() - assert protocol.connection_preface_performed - - -class TestClientStreamIds(): - c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = HTTP2Protocol(c) - - def test_client_stream_ids(self): - assert self.protocol.current_stream_id is None - assert self.protocol._next_stream_id() == 1 - assert self.protocol.current_stream_id == 1 - assert self.protocol._next_stream_id() == 3 - assert self.protocol.current_stream_id == 3 - assert self.protocol._next_stream_id() == 5 - assert self.protocol.current_stream_id == 5 - - -class TestServerStreamIds(): - c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = HTTP2Protocol(c, is_server=True) - - def test_server_stream_ids(self): - assert self.protocol.current_stream_id is None - assert self.protocol._next_stream_id() == 2 - assert self.protocol.current_stream_id == 2 - assert self.protocol._next_stream_id() == 4 - assert self.protocol.current_stream_id == 4 - assert self.protocol._next_stream_id() == 6 - assert self.protocol.current_stream_id == 6 - - -class TestApplySettings(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - def handle(self): - # check settings acknowledgement - assert self.rfile.read(9) == '000000040100000000'.decode('hex') - self.wfile.write("OK") - self.wfile.flush() - self.rfile.safe_read(9) # just to keep the connection alive a bit longer - - ssl = True - - def test_apply_settings(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = HTTP2Protocol(c) - - protocol._apply_settings({ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 'foo', - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 'bar', - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 'deadbeef', - }) - - assert c.rfile.safe_read(2) == "OK" - - assert protocol.http2_settings[ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH] == 'foo' - assert protocol.http2_settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS] == 'bar' - assert protocol.http2_settings[ - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE] == 'deadbeef' - - -class TestCreateHeaders(): - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_create_headers(self): - headers = http.Headers([ - (b':method', b'GET'), - (b':path', b'index.html'), - (b':scheme', b'https'), - (b'foo', b'bar')]) - - bytes = HTTP2Protocol(self.c)._create_headers( - headers, 1, end_stream=True) - assert b''.join(bytes) ==\ - '000014010500000001824488355217caf3a69a3f87408294e7838c767f'\ - .decode('hex') - - bytes = HTTP2Protocol(self.c)._create_headers( - headers, 1, end_stream=False) - assert b''.join(bytes) ==\ - '000014010400000001824488355217caf3a69a3f87408294e7838c767f'\ - .decode('hex') - - def test_create_headers_multiple_frames(self): - headers = http.Headers([ - (b':method', b'GET'), - (b':path', b'/'), - (b':scheme', b'https'), - (b'foo', b'bar'), - (b'server', b'version')]) - - protocol = HTTP2Protocol(self.c) - protocol.http2_settings[SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] = 8 - bytes = protocol._create_headers(headers, 1, end_stream=True) - assert len(bytes) == 3 - assert bytes[0] == '000008010000000001828487408294e783'.decode('hex') - assert bytes[1] == '0000080900000000018c767f7685ee5b10'.decode('hex') - assert bytes[2] == '00000209050000000163d5'.decode('hex') - - -class TestCreateBody(): - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_create_body_empty(self): - protocol = HTTP2Protocol(self.c) - bytes = protocol._create_body(b'', 1) - assert b''.join(bytes) == ''.decode('hex') - - def test_create_body_single_frame(self): - protocol = HTTP2Protocol(self.c) - bytes = protocol._create_body('foobar', 1) - assert b''.join(bytes) == '000006000100000001666f6f626172'.decode('hex') - - def test_create_body_multiple_frames(self): - protocol = HTTP2Protocol(self.c) - protocol.http2_settings[SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] = 5 - bytes = protocol._create_body('foobarmehm42', 1) - assert len(bytes) == 3 - assert bytes[0] == '000005000000000001666f6f6261'.decode('hex') - assert bytes[1] == '000005000000000001726d65686d'.decode('hex') - assert bytes[2] == '0000020001000000013432'.decode('hex') - - -class TestReadRequest(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - def handle(self): - self.wfile.write( - b'000003010400000001828487'.decode('hex')) - self.wfile.write( - b'000006000100000001666f6f626172'.decode('hex')) - self.wfile.flush() - self.rfile.safe_read(9) # just to keep the connection alive a bit longer - - ssl = True - - def test_read_request(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = HTTP2Protocol(c, is_server=True) - protocol.connection_preface_performed = True - - req = protocol.read_request(NotImplemented) - - assert req.stream_id - assert req.headers.fields == [[':method', 'GET'], [':path', '/'], [':scheme', 'https']] - assert req.content == b'foobar' - - -class TestReadRequestRelative(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - def handle(self): - self.wfile.write( - b'00000c0105000000014287d5af7e4d5a777f4481f9'.decode('hex')) - self.wfile.flush() - - ssl = True - - def test_asterisk_form_in(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = HTTP2Protocol(c, is_server=True) - protocol.connection_preface_performed = True - - req = protocol.read_request(NotImplemented) - - assert req.form_in == "relative" - assert req.method == "OPTIONS" - assert req.path == "*" - - -class TestReadRequestAbsolute(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - def handle(self): - self.wfile.write( - b'00001901050000000182448d9d29aee30c0e492c2a1170426366871c92585422e085'.decode('hex')) - self.wfile.flush() - - ssl = True - - def test_absolute_form_in(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = HTTP2Protocol(c, is_server=True) - protocol.connection_preface_performed = True - - req = protocol.read_request(NotImplemented) - - assert req.form_in == "absolute" - assert req.scheme == "http" - assert req.host == "address" - assert req.port == 22 - - -class TestReadRequestConnect(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - def handle(self): - self.wfile.write( - b'00001b0105000000014287bdab4e9c17b7ff44871c92585422e08541871c92585422e085'.decode('hex')) - self.wfile.write( - b'00001d0105000000014287bdab4e9c17b7ff44882f91d35d055c87a741882f91d35d055c87a7'.decode('hex')) - self.wfile.flush() - - ssl = True - - def test_connect(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = HTTP2Protocol(c, is_server=True) - protocol.connection_preface_performed = True - - req = protocol.read_request(NotImplemented) - assert req.form_in == "authority" - assert req.method == "CONNECT" - assert req.host == "address" - assert req.port == 22 - - req = protocol.read_request(NotImplemented) - assert req.form_in == "authority" - assert req.method == "CONNECT" - assert req.host == "example.com" - assert req.port == 443 - - -class TestReadResponse(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - def handle(self): - self.wfile.write( - b'00000801040000002a88628594e78c767f'.decode('hex')) - self.wfile.write( - b'00000600010000002a666f6f626172'.decode('hex')) - self.wfile.flush() - self.rfile.safe_read(9) # just to keep the connection alive a bit longer - - ssl = True - - def test_read_response(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = HTTP2Protocol(c) - protocol.connection_preface_performed = True - - resp = protocol.read_response(NotImplemented, stream_id=42) - - assert resp.http_version == (2, 0) - assert resp.status_code == 200 - assert resp.msg == "" - assert resp.headers.fields == [[':status', '200'], ['etag', 'foobar']] - assert resp.content == b'foobar' - assert resp.timestamp_end - - -class TestReadEmptyResponse(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - def handle(self): - self.wfile.write( - b'00000801050000002a88628594e78c767f'.decode('hex')) - self.wfile.flush() - - ssl = True - - def test_read_empty_response(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = HTTP2Protocol(c) - protocol.connection_preface_performed = True - - resp = protocol.read_response(NotImplemented, stream_id=42) - - assert resp.stream_id == 42 - assert resp.http_version == (2, 0) - assert resp.status_code == 200 - assert resp.msg == "" - assert resp.headers.fields == [[':status', '200'], ['etag', 'foobar']] - assert resp.content == b'' - - -class TestAssembleRequest(object): - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_request_simple(self): - bytes = HTTP2Protocol(self.c).assemble_request(http.Request( - '', - 'GET', - 'https', - '', - '', - '/', - (2, 0), - None, - None, - )) - assert len(bytes) == 1 - assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') - - def test_request_with_stream_id(self): - req = http.Request( - '', - 'GET', - 'https', - '', - '', - '/', - (2, 0), - None, - None, - ) - req.stream_id = 0x42 - bytes = HTTP2Protocol(self.c).assemble_request(req) - assert len(bytes) == 1 - assert bytes[0] == '00000d0105000000428284874188089d5c0b8170dc07'.decode('hex') - - def test_request_with_body(self): - bytes = HTTP2Protocol(self.c).assemble_request(http.Request( - '', - 'GET', - 'https', - '', - '', - '/', - (2, 0), - http.Headers([('foo', 'bar')]), - 'foobar', - )) - assert len(bytes) == 2 - assert bytes[0] ==\ - '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') - assert bytes[1] ==\ - '000006000100000001666f6f626172'.decode('hex') - - -class TestAssembleResponse(object): - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_simple(self): - bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( - (2, 0), - 200, - )) - assert len(bytes) == 1 - assert bytes[0] ==\ - '00000101050000000288'.decode('hex') - - def test_with_stream_id(self): - resp = http.Response( - (2, 0), - 200, - ) - resp.stream_id = 0x42 - bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(resp) - assert len(bytes) == 1 - assert bytes[0] ==\ - '00000101050000004288'.decode('hex') - - def test_with_body(self): - bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( - (2, 0), - 200, - '', - Headers(foo="bar"), - 'foobar' - )) - assert len(bytes) == 2 - assert bytes[0] ==\ - '00000901040000000288408294e7838c767f'.decode('hex') - assert bytes[1] ==\ - '000006000100000002666f6f626172'.decode('hex') -- cgit v1.2.3 From e98c729bb9b0d3debde6f61c948108bdc9dbafbe Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 31 Jan 2016 14:16:03 +0100 Subject: test on python3 --- .appveyor.yml | 2 +- .travis.yml | 14 ++- netlib/http/http2/connections.py | 40 +++++--- netlib/utils.py | 14 +-- setup.py | 1 + test/http/http2/test_connections.py | 200 +++++++++++++++++------------------- 6 files changed, 140 insertions(+), 131 deletions(-) diff --git a/.appveyor.yml b/.appveyor.yml index dbb6d2fa..cd1354c2 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -8,4 +8,4 @@ install: - "%PYTHON%\\python -c \"from OpenSSL import SSL; print(SSL.SSLeay_version(SSL.SSLEAY_VERSION))\"" build: off # Not a C# project test_script: - - "%PYTHON%\\Scripts\\py.test -n 4" \ No newline at end of file + - "%PYTHON%\\Scripts\\py.test -n 4 --timeout 10" diff --git a/.travis.yml b/.travis.yml index 84f32d71..529c7ed3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,8 +16,16 @@ matrix: packages: - libssl-dev - python: 3.5 - script: - - py.test -s --cov netlib -k "not http2" + - python: 3.5 + env: OPENSSL=1.0.2 + addons: + apt: + sources: + # Debian sid currently holds OpenSSL 1.0.2 + # change this with future releases! + - debian-sid + packages: + - libssl-dev - python: pypy - python: pypy env: OPENSSL=1.0.2 @@ -58,7 +66,7 @@ before_script: - "openssl version -a" script: - - "py.test -s --cov netlib" + - "py.test -s --cov netlib --timeout 10" after_success: - coveralls diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index c963f7c4..91133121 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -8,6 +8,11 @@ from .. import Headers, Response, Request from hyperframe import frame +# TODO: remove once hyperframe released a new version > 3.1.1 +# wrapper for deprecated name in old hyperframe release +frame.SettingsFrame.MAX_FRAME_SIZE = frame.SettingsFrame.SETTINGS_MAX_FRAME_SIZE +frame.SettingsFrame.MAX_HEADER_LIST_SIZE = frame.SettingsFrame.SETTINGS_MAX_HEADER_LIST_SIZE + class TCPHandler(object): @@ -35,7 +40,7 @@ class HTTP2Protocol(object): HTTP_1_1_REQUIRED=0xd ) - CLIENT_CONNECTION_PREFACE = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + CLIENT_CONNECTION_PREFACE = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n' HTTP2_DEFAULT_SETTINGS = { frame.SettingsFrame.HEADER_TABLE_SIZE: 4096, @@ -94,7 +99,7 @@ class HTTP2Protocol(object): timestamp_end = time.time() - authority = headers.get(':authority', '') + authority = headers.get(':authority', b'') method = headers.get(':method', 'GET') scheme = headers.get(':scheme', 'https') path = headers.get(':path', '/') @@ -113,6 +118,8 @@ class HTTP2Protocol(object): form_in = "absolute" # FIXME: verify if path or :host contains what we need scheme, host, port, _ = utils.parse_url(path) + scheme = scheme.decode('ascii') + host = host.decode('ascii') if host is None: host = 'localhost' @@ -122,18 +129,17 @@ class HTTP2Protocol(object): request = Request( form_in, - method, - scheme, - host, + method.encode('ascii'), + scheme.encode('ascii'), + host.encode('ascii'), port, - path, - (2, 0), + path.encode('ascii'), + b'2.0', headers, body, timestamp_start, timestamp_end, ) - # FIXME: We should not do this. request.stream_id = stream_id return request @@ -141,7 +147,7 @@ class HTTP2Protocol(object): def read_response( self, __rfile, - request_method='', + request_method=b'', body_size_limit=None, include_body=True, stream_id=None, @@ -170,9 +176,9 @@ class HTTP2Protocol(object): timestamp_end = None response = Response( - (2, 0), + b'2.0', int(headers.get(':status', 502)), - "", + b'', headers, body, timestamp_start=timestamp_start, @@ -200,13 +206,13 @@ class HTTP2Protocol(object): headers = request.headers.copy() if ':authority' not in headers: - headers.fields.insert(0, (':authority', bytes(authority))) + headers.fields.insert(0, (b':authority', authority.encode('ascii'))) if ':scheme' not in headers: - headers.fields.insert(0, (':scheme', bytes(request.scheme))) + headers.fields.insert(0, (b':scheme', request.scheme.encode('ascii'))) if ':path' not in headers: - headers.fields.insert(0, (':path', bytes(request.path))) + headers.fields.insert(0, (b':path', request.path.encode('ascii'))) if ':method' not in headers: - headers.fields.insert(0, (':method', bytes(request.method))) + headers.fields.insert(0, (b':method', request.method.encode('ascii'))) if hasattr(request, 'stream_id'): stream_id = request.stream_id @@ -223,7 +229,7 @@ class HTTP2Protocol(object): headers = response.headers.copy() if ':status' not in headers: - headers.fields.insert(0, (':status', bytes(str(response.status_code)))) + headers.fields.insert(0, (b':status', str(response.status_code).encode('ascii'))) if hasattr(response, 'stream_id'): stream_id = response.stream_id @@ -419,7 +425,7 @@ class HTTP2Protocol(object): self._handle_unexpected_frame(frm) headers = Headers( - [[str(k), str(v)] for k, v in self.decoder.decode(header_blocks)] + [[k.encode('ascii'), v.encode('ascii')] for k, v in self.decoder.decode(header_blocks)] ) return stream_id, headers, body diff --git a/netlib/utils.py b/netlib/utils.py index c537754a..1c1b617a 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -369,17 +369,17 @@ def multipartdecode(headers, content): def http2_read_raw_frame(rfile): - field = rfile.peek(3) - length = int(codecs.encode(field, 'hex_codec'), 16) + header = rfile.safe_read(9) + length = int(codecs.encode(header[:3], 'hex_codec'), 16) if length == 4740180: raise ValueError("Length field looks more like HTTP/1.1: %s" % rfile.peek(20)) - raw = rfile.safe_read(9 + length) - return raw + body = rfile.safe_read(length) + return [header, body] def http2_read_frame(rfile): - raw = http2_read_raw_frame(rfile) - frame, length = hyperframe.frame.Frame.parse_frame_header(raw[:9]) - frame.parse_body(memoryview(raw[9:])) + header, body = http2_read_raw_frame(rfile) + frame, length = hyperframe.frame.Frame.parse_frame_header(header) + frame.parse_body(memoryview(body)) return frame diff --git a/setup.py b/setup.py index ba723bfa..e842fd74 100644 --- a/setup.py +++ b/setup.py @@ -64,6 +64,7 @@ setup( "pytest>=2.8.0", "pytest-xdist>=1.13.1", "pytest-cov>=2.1.0", + "pytest-timeout>=1.0.0", "coveralls>=0.4.1", "autopep8>=1.0.3", "autoflake>=0.6.6", diff --git a/test/http/http2/test_connections.py b/test/http/http2/test_connections.py index edbd4f8b..22a43266 100644 --- a/test/http/http2/test_connections.py +++ b/test/http/http2/test_connections.py @@ -4,11 +4,12 @@ import codecs from hyperframe.frame import * -from netlib import tcp, http, utils, tutils, tservers +from netlib import tcp, http, utils, tservers +from netlib.tutils import raises from netlib.exceptions import TcpDisconnect -from netlib.http import Headers from netlib.http.http2.connections import HTTP2Protocol, TCPHandler + class TestTCPHandlerWrapper: def test_wrapped(self): h = TCPHandler(rfile='foo', wfile='bar') @@ -92,35 +93,33 @@ class TestCheckALPNMismatch(tservers.ServerTestBase): c.connect() c.convert_to_ssl(alpn_protos=[b'h2']) protocol = HTTP2Protocol(c) - tutils.raises(NotImplementedError, protocol.check_alpn) + with raises(NotImplementedError): + protocol.check_alpn() class TestPerformServerConnectionPreface(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): - try: - # send magic - self.wfile.write(codecs.decode('505249202a20485454502f322e300d0a0d0a534d0d0a0d0a', 'hex_codec')) - self.wfile.flush() + # send magic + self.wfile.write(codecs.decode('505249202a20485454502f322e300d0a0d0a534d0d0a0d0a', 'hex_codec')) + self.wfile.flush() - # send empty settings frame - self.wfile.write(codecs.decode('000000040000000000', 'hex_codec')) - self.wfile.flush() + # send empty settings frame + self.wfile.write(codecs.decode('000000040000000000', 'hex_codec')) + self.wfile.flush() - # check empty settings frame - raw = utils.http2_read_raw_frame(self.rfile) - assert raw == codecs.decode('00000c040000000000000200000000000300000001', 'hex_codec') + # check empty settings frame + raw = utils.http2_read_raw_frame(self.rfile) + assert raw == codecs.decode('00000c040000000000000200000000000300000001', 'hex_codec') - # check settings acknowledgement - raw = utils.http2_read_raw_frame(self.rfile) - assert raw == codecs.decode('000000040100000000', 'hex_codec') + # check settings acknowledgement + raw = utils.http2_read_raw_frame(self.rfile) + assert raw == codecs.decode('000000040100000000', 'hex_codec') - # send settings acknowledgement - self.wfile.write(codecs.decode('000000040100000000', 'hex_codec')) - self.wfile.flush() - except Exception as e: - print(e) + # send settings acknowledgement + self.wfile.write(codecs.decode('000000040100000000', 'hex_codec')) + self.wfile.flush() def test_perform_server_connection_preface(self): c = tcp.TCPClient(("127.0.0.1", self.port)) @@ -131,11 +130,8 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase): protocol.perform_server_connection_preface() assert protocol.connection_preface_performed - frm = protocol.read_frame() - assert isinstance(frm, SettingsFrame) - assert 'ACK' in frm.flags - - tutils.raises(TcpDisconnect, protocol.perform_server_connection_preface, force=True) + with raises(TcpDisconnect): + protocol.perform_server_connection_preface(force=True) class TestPerformClientConnectionPreface(tservers.ServerTestBase): @@ -143,23 +139,22 @@ class TestPerformClientConnectionPreface(tservers.ServerTestBase): def handle(self): # check magic - assert self.rfile.read(24) ==\ - '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') + assert self.rfile.read(24) == HTTP2Protocol.CLIENT_CONNECTION_PREFACE # check empty settings frame assert self.rfile.read(9) ==\ - '000000040000000000'.decode('hex') + codecs.decode('000000040000000000', 'hex_codec') # send empty settings frame - self.wfile.write('000000040000000000'.decode('hex')) + self.wfile.write(codecs.decode('000000040000000000', 'hex_codec')) self.wfile.flush() # check settings acknowledgement assert self.rfile.read(9) == \ - '000000040100000000'.decode('hex') + codecs.decode('000000040100000000', 'hex_codec') # send settings acknowledgement - self.wfile.write('000000040100000000'.decode('hex')) + self.wfile.write(codecs.decode('000000040100000000', 'hex_codec')) self.wfile.flush() def test_perform_client_connection_preface(self): @@ -172,7 +167,7 @@ class TestPerformClientConnectionPreface(tservers.ServerTestBase): assert protocol.connection_preface_performed -class TestClientStreamIds(): +class TestClientStreamIds(object): c = tcp.TCPClient(("127.0.0.1", 0)) protocol = HTTP2Protocol(c) @@ -186,7 +181,7 @@ class TestClientStreamIds(): assert self.protocol.current_stream_id == 5 -class TestServerStreamIds(): +class TestServerStreamIds(object): c = tcp.TCPClient(("127.0.0.1", 0)) protocol = HTTP2Protocol(c, is_server=True) @@ -204,7 +199,7 @@ class TestApplySettings(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): # check settings acknowledgement - assert self.rfile.read(9) == '000000040100000000'.decode('hex') + assert self.rfile.read(9) == codecs.decode('000000040100000000', 'hex_codec') self.wfile.write("OK") self.wfile.flush() self.rfile.safe_read(9) # just to keep the connection alive a bit longer @@ -223,7 +218,7 @@ class TestApplySettings(tservers.ServerTestBase): SettingsFrame.INITIAL_WINDOW_SIZE: 'deadbeef', }) - assert c.rfile.safe_read(2) == "OK" + assert c.rfile.safe_read(2) == b"OK" assert protocol.http2_settings[ SettingsFrame.ENABLE_PUSH] == 'foo' @@ -233,7 +228,7 @@ class TestApplySettings(tservers.ServerTestBase): SettingsFrame.INITIAL_WINDOW_SIZE] == 'deadbeef' -class TestCreateHeaders(): +class TestCreateHeaders(object): c = tcp.TCPClient(("127.0.0.1", 0)) def test_create_headers(self): @@ -246,14 +241,12 @@ class TestCreateHeaders(): bytes = HTTP2Protocol(self.c)._create_headers( headers, 1, end_stream=True) assert b''.join(bytes) ==\ - '000014010500000001824488355217caf3a69a3f87408294e7838c767f'\ - .decode('hex') + codecs.decode('000014010500000001824488355217caf3a69a3f87408294e7838c767f', 'hex_codec') bytes = HTTP2Protocol(self.c)._create_headers( headers, 1, end_stream=False) assert b''.join(bytes) ==\ - '000014010400000001824488355217caf3a69a3f87408294e7838c767f'\ - .decode('hex') + codecs.decode('000014010400000001824488355217caf3a69a3f87408294e7838c767f', 'hex_codec') def test_create_headers_multiple_frames(self): headers = http.Headers([ @@ -267,41 +260,42 @@ class TestCreateHeaders(): protocol.http2_settings[SettingsFrame.MAX_FRAME_SIZE] = 8 bytes = protocol._create_headers(headers, 1, end_stream=True) assert len(bytes) == 3 - assert bytes[0] == '000008010100000001828487408294e783'.decode('hex') - assert bytes[1] == '0000080900000000018c767f7685ee5b10'.decode('hex') - assert bytes[2] == '00000209040000000163d5'.decode('hex') + assert bytes[0] == codecs.decode('000008010100000001828487408294e783', 'hex_codec') + assert bytes[1] == codecs.decode('0000080900000000018c767f7685ee5b10', 'hex_codec') + assert bytes[2] == codecs.decode('00000209040000000163d5', 'hex_codec') -class TestCreateBody(): +class TestCreateBody(object): c = tcp.TCPClient(("127.0.0.1", 0)) def test_create_body_empty(self): protocol = HTTP2Protocol(self.c) bytes = protocol._create_body(b'', 1) - assert b''.join(bytes) == ''.decode('hex') + assert b''.join(bytes) == b'' def test_create_body_single_frame(self): protocol = HTTP2Protocol(self.c) bytes = protocol._create_body(b'foobar', 1) - assert b''.join(bytes) == '000006000100000001666f6f626172'.decode('hex') + assert b''.join(bytes) == codecs.decode('000006000100000001666f6f626172', 'hex_codec') def test_create_body_multiple_frames(self): protocol = HTTP2Protocol(self.c) protocol.http2_settings[SettingsFrame.MAX_FRAME_SIZE] = 5 bytes = protocol._create_body(b'foobarmehm42', 1) assert len(bytes) == 3 - assert bytes[0] == '000005000000000001666f6f6261'.decode('hex') - assert bytes[1] == '000005000000000001726d65686d'.decode('hex') - assert bytes[2] == '0000020001000000013432'.decode('hex') + assert bytes[0] == codecs.decode('000005000000000001666f6f6261', 'hex_codec') + assert bytes[1] == codecs.decode('000005000000000001726d65686d', 'hex_codec') + assert bytes[2] == codecs.decode('0000020001000000013432', 'hex_codec') class TestReadRequest(tservers.ServerTestBase): class handler(tcp.BaseHandler): + def handle(self): self.wfile.write( - b'000003010400000001828487'.decode('hex')) + codecs.decode('000003010400000001828487', 'hex_codec')) self.wfile.write( - b'000006000100000001666f6f626172'.decode('hex')) + codecs.decode('000006000100000001666f6f626172', 'hex_codec')) self.wfile.flush() self.rfile.safe_read(9) # just to keep the connection alive a bit longer @@ -317,7 +311,7 @@ class TestReadRequest(tservers.ServerTestBase): req = protocol.read_request(NotImplemented) assert req.stream_id - assert req.headers.fields == [[':method', 'GET'], [':path', '/'], [':scheme', 'https']] + assert req.headers.fields == [[b':method', b'GET'], [b':path', b'/'], [b':scheme', b'https']] assert req.content == b'foobar' @@ -325,7 +319,7 @@ class TestReadRequestRelative(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): self.wfile.write( - b'00000c0105000000014287d5af7e4d5a777f4481f9'.decode('hex')) + codecs.decode('00000c0105000000014287d5af7e4d5a777f4481f9', 'hex_codec')) self.wfile.flush() ssl = True @@ -348,7 +342,7 @@ class TestReadRequestAbsolute(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): self.wfile.write( - b'00001901050000000182448d9d29aee30c0e492c2a1170426366871c92585422e085'.decode('hex')) + codecs.decode('00001901050000000182448d9d29aee30c0e492c2a1170426366871c92585422e085', 'hex_codec')) self.wfile.flush() ssl = True @@ -372,9 +366,9 @@ class TestReadRequestConnect(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): self.wfile.write( - b'00001b0105000000014287bdab4e9c17b7ff44871c92585422e08541871c92585422e085'.decode('hex')) + codecs.decode('00001b0105000000014287bdab4e9c17b7ff44871c92585422e08541871c92585422e085', 'hex_codec')) self.wfile.write( - b'00001d0105000000014287bdab4e9c17b7ff44882f91d35d055c87a741882f91d35d055c87a7'.decode('hex')) + codecs.decode('00001d0105000000014287bdab4e9c17b7ff44882f91d35d055c87a741882f91d35d055c87a7', 'hex_codec')) self.wfile.flush() ssl = True @@ -403,9 +397,9 @@ class TestReadResponse(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): self.wfile.write( - b'00000801040000002a88628594e78c767f'.decode('hex')) + codecs.decode('00000801040000002a88628594e78c767f', 'hex_codec')) self.wfile.write( - b'00000600010000002a666f6f626172'.decode('hex')) + codecs.decode('00000600010000002a666f6f626172', 'hex_codec')) self.wfile.flush() self.rfile.safe_read(9) # just to keep the connection alive a bit longer @@ -420,10 +414,10 @@ class TestReadResponse(tservers.ServerTestBase): resp = protocol.read_response(NotImplemented, stream_id=42) - assert resp.http_version == (2, 0) + assert resp.http_version == '2.0' assert resp.status_code == 200 - assert resp.msg == "" - assert resp.headers.fields == [[':status', '200'], ['etag', 'foobar']] + assert resp.msg == '' + assert resp.headers.fields == [[b':status', b'200'], [b'etag', b'foobar']] assert resp.content == b'foobar' assert resp.timestamp_end @@ -432,7 +426,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): self.wfile.write( - b'00000801050000002a88628594e78c767f'.decode('hex')) + codecs.decode('00000801050000002a88628594e78c767f', 'hex_codec')) self.wfile.flush() ssl = True @@ -447,10 +441,10 @@ class TestReadEmptyResponse(tservers.ServerTestBase): resp = protocol.read_response(NotImplemented, stream_id=42) assert resp.stream_id == 42 - assert resp.http_version == (2, 0) + assert resp.http_version == '2.0' assert resp.status_code == 200 - assert resp.msg == b'' - assert resp.headers.fields == [[':status', '200'], ['etag', 'foobar']] + assert resp.msg == '' + assert resp.headers.fields == [[b':status', b'200'], [b'etag', b'foobar']] assert resp.content == b'' @@ -459,53 +453,53 @@ class TestAssembleRequest(object): def test_request_simple(self): bytes = HTTP2Protocol(self.c).assemble_request(http.Request( - '', - 'GET', - 'https', - '', - '', - '/', - (2, 0), + b'', + b'GET', + b'https', + b'', + b'', + b'/', + b'2.0', None, None, )) assert len(bytes) == 1 - assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') + assert bytes[0] == codecs.decode('00000d0105000000018284874188089d5c0b8170dc07', 'hex_codec') def test_request_with_stream_id(self): req = http.Request( - '', - 'GET', - 'https', - '', - '', - '/', - (2, 0), + b'', + b'GET', + b'https', + b'', + b'', + b'/', + b'2.0', None, None, ) req.stream_id = 0x42 bytes = HTTP2Protocol(self.c).assemble_request(req) assert len(bytes) == 1 - assert bytes[0] == '00000d0105000000428284874188089d5c0b8170dc07'.decode('hex') + assert bytes[0] == codecs.decode('00000d0105000000428284874188089d5c0b8170dc07', 'hex_codec') def test_request_with_body(self): bytes = HTTP2Protocol(self.c).assemble_request(http.Request( - '', - 'GET', - 'https', - '', - '', - '/', - (2, 0), - http.Headers([('foo', 'bar')]), - 'foobar', + b'', + b'GET', + b'https', + b'', + b'', + b'/', + b'2.0', + http.Headers([(b'foo', b'bar')]), + b'foobar', )) assert len(bytes) == 2 assert bytes[0] ==\ - '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') + codecs.decode('0000150104000000018284874188089d5c0b8170dc07408294e7838c767f', 'hex_codec') assert bytes[1] ==\ - '000006000100000001666f6f626172'.decode('hex') + codecs.decode('000006000100000001666f6f626172', 'hex_codec') class TestAssembleResponse(object): @@ -513,34 +507,34 @@ class TestAssembleResponse(object): def test_simple(self): bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( - (2, 0), + b'2.0', 200, )) assert len(bytes) == 1 assert bytes[0] ==\ - '00000101050000000288'.decode('hex') + codecs.decode('00000101050000000288', 'hex_codec') def test_with_stream_id(self): resp = http.Response( - (2, 0), + b'2.0', 200, ) resp.stream_id = 0x42 bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(resp) assert len(bytes) == 1 assert bytes[0] ==\ - '00000101050000004288'.decode('hex') + codecs.decode('00000101050000004288', 'hex_codec') def test_with_body(self): bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( - (2, 0), + b'2.0', 200, - '', - Headers(foo="bar"), - 'foobar' + b'', + http.Headers(foo=b"bar"), + b'foobar' )) assert len(bytes) == 2 assert bytes[0] ==\ - '00000901040000000288408294e7838c767f'.decode('hex') + codecs.decode('00000901040000000288408294e7838c767f', 'hex_codec') assert bytes[1] ==\ - '000006000100000002666f6f626172'.decode('hex') + codecs.decode('000006000100000002666f6f626172', 'hex_codec') -- cgit v1.2.3 From 7c83a709ea06f3b538f446860f3c7ed463a29b1f Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 1 Feb 2016 19:24:30 +0100 Subject: add test for Reader.peek() --- test/test_tcp.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/test_tcp.py b/test/test_tcp.py index 738fb2eb..a68bf1e6 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -713,6 +713,20 @@ class TestFileLike: tutils.raises(TcpReadIncomplete, s.safe_read, 10) +class TestPeek(tservers.ServerTestBase): + handler = EchoHandler + + def test_peek(self): + testval = b"peek!\n" + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.wfile.write(testval) + c.wfile.flush() + + assert c.rfile.peek(4) == "peek"[:4] + assert c.rfile.peek(6) == testval + + class TestAddress: def test_simple(self): -- cgit v1.2.3 From bda49dd178fee1361f3585bd7efad67883298e5a Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 1 Feb 2016 19:38:14 +0100 Subject: fix #113, make Reader.peek() work on Python 3 --- netlib/tcp.py | 30 +++++++++++++++++++++++++----- test/test_tcp.py | 2 +- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 8902b9dc..57a9b737 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -25,6 +25,10 @@ from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, Tl version_check.check_pyopenssl_version() +if six.PY2: + socket_fileobject = socket._fileobject +else: + socket_fileobject = socket.SocketIO EINTR = 4 @@ -270,7 +274,7 @@ class Reader(_FileLike): TlsException if there was an error with pyOpenSSL. NotImplementedError if the underlying file object is not a (pyOpenSSL) socket """ - if isinstance(self.o, socket._fileobject): + if isinstance(self.o, socket_fileobject): try: return self.o._sock.recv(length, socket.MSG_PEEK) except socket.error as e: @@ -423,8 +427,17 @@ class _Connection(object): def __init__(self, connection): if connection: self.connection = connection - self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) - self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) + # Ideally, we would use the Buffered IO in Python 3 by default. + # Unfortunately, the implementation of .peek() is broken for n>1 bytes, + # as it may just return what's left in the buffer and not all the bytes we want. + # As a workaround, we just use unbuffered sockets directly. + # https://mail.python.org/pipermail/python-dev/2009-June/089986.html + if six.PY2: + self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) + else: + self.rfile = Reader(socket.SocketIO(self.connection, "rb")) + self.wfile = Writer(socket.SocketIO(self.connection, "wb")) else: self.connection = None self.rfile = None @@ -663,8 +676,15 @@ class TCPClient(_Connection): connection.connect(self.address()) if not self.source_address: self.source_address = Address(connection.getsockname()) - self.rfile = Reader(connection.makefile('rb', self.rbufsize)) - self.wfile = Writer(connection.makefile('wb', self.wbufsize)) + + # See _Connection.__init__ why we do this dance. + if six.PY2: + self.rfile = Reader(connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(connection.makefile('wb', self.wbufsize)) + else: + self.rfile = Reader(socket.SocketIO(connection, "rb")) + self.wfile = Writer(socket.SocketIO(connection, "wb")) + except (socket.error, IOError) as err: raise TcpException( 'Error connecting to "%s": %s' % diff --git a/test/test_tcp.py b/test/test_tcp.py index a68bf1e6..20a295dd 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -723,7 +723,7 @@ class TestPeek(tservers.ServerTestBase): c.wfile.write(testval) c.wfile.flush() - assert c.rfile.peek(4) == "peek"[:4] + assert c.rfile.peek(4) == b"peek"[:4] assert c.rfile.peek(6) == testval -- cgit v1.2.3 From a3af0ce71d5b4368f1d9de8d17ff5e20086edcc4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 1 Feb 2016 20:10:18 +0100 Subject: tests++ --- netlib/tcp.py | 2 +- test/test_tcp.py | 27 +++++++++++++++++++++++---- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 57a9b737..1523370b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -272,7 +272,7 @@ class Reader(_FileLike): Raises: TcpException if there was an error with the socket TlsException if there was an error with pyOpenSSL. - NotImplementedError if the underlying file object is not a (pyOpenSSL) socket + NotImplementedError if the underlying file object is not a [pyOpenSSL] socket """ if isinstance(self.o, socket_fileobject): try: diff --git a/test/test_tcp.py b/test/test_tcp.py index 20a295dd..2b091ef0 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -12,7 +12,7 @@ import OpenSSL from netlib import tcp, certutils, tutils, tservers from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \ - TcpTimeout, TcpDisconnect, TcpException + TcpTimeout, TcpDisconnect, TcpException, NetlibException class EchoHandler(tcp.BaseHandler): @@ -716,15 +716,34 @@ class TestFileLike: class TestPeek(tservers.ServerTestBase): handler = EchoHandler + def _connect(self, c): + c.connect() + def test_peek(self): testval = b"peek!\n" c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() + self._connect(c) c.wfile.write(testval) c.wfile.flush() - assert c.rfile.peek(4) == b"peek"[:4] - assert c.rfile.peek(6) == testval + assert c.rfile.peek(4) == b"peek" + assert c.rfile.peek(6) == b"peek!\n" + assert c.rfile.readline() == testval + + c.close() + with tutils.raises(NetlibException): + if c.rfile.peek(1) == b"": + # Workaround for Python 2 on Unix: + # Peeking a closed connection does not raise an exception here. + raise NetlibException() + + +class TestPeekSSL(TestPeek): + ssl = True + + def _connect(self, c): + c.connect() + c.convert_to_ssl() class TestAddress: -- cgit v1.2.3 From 931b5459e92ec237914d7cca9034c5a348033bdb Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 1 Feb 2016 20:19:34 +0100 Subject: remove code duplication --- netlib/tcp.py | 38 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 1523370b..682db29a 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -424,20 +424,26 @@ class _Connection(object): rbufsize = -1 wbufsize = -1 + def _makefile(self): + """ + Set up .rfile and .wfile attributes from .connection + """ + # Ideally, we would use the Buffered IO in Python 3 by default. + # Unfortunately, the implementation of .peek() is broken for n>1 bytes, + # as it may just return what's left in the buffer and not all the bytes we want. + # As a workaround, we just use unbuffered sockets directly. + # https://mail.python.org/pipermail/python-dev/2009-June/089986.html + if six.PY2: + self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) + else: + self.rfile = Reader(socket.SocketIO(self.connection, "rb")) + self.wfile = Writer(socket.SocketIO(self.connection, "wb")) + def __init__(self, connection): if connection: self.connection = connection - # Ideally, we would use the Buffered IO in Python 3 by default. - # Unfortunately, the implementation of .peek() is broken for n>1 bytes, - # as it may just return what's left in the buffer and not all the bytes we want. - # As a workaround, we just use unbuffered sockets directly. - # https://mail.python.org/pipermail/python-dev/2009-June/089986.html - if six.PY2: - self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) - self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) - else: - self.rfile = Reader(socket.SocketIO(self.connection, "rb")) - self.wfile = Writer(socket.SocketIO(self.connection, "wb")) + self._makefile() else: self.connection = None self.rfile = None @@ -676,20 +682,12 @@ class TCPClient(_Connection): connection.connect(self.address()) if not self.source_address: self.source_address = Address(connection.getsockname()) - - # See _Connection.__init__ why we do this dance. - if six.PY2: - self.rfile = Reader(connection.makefile('rb', self.rbufsize)) - self.wfile = Writer(connection.makefile('wb', self.wbufsize)) - else: - self.rfile = Reader(socket.SocketIO(connection, "rb")) - self.wfile = Writer(socket.SocketIO(connection, "wb")) - except (socket.error, IOError) as err: raise TcpException( 'Error connecting to "%s": %s' % (self.address.host, err)) self.connection = connection + self._makefile() def settimeout(self, n): self.connection.settimeout(n) -- cgit v1.2.3 From e222858f01095c61178590123eea7b49b5d7853b Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Tue, 2 Feb 2016 17:39:49 +0100 Subject: bump dependency and remove deprecated fields --- netlib/http/http2/connections.py | 5 ----- setup.py | 4 ++-- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index 91133121..5e877286 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -8,11 +8,6 @@ from .. import Headers, Response, Request from hyperframe import frame -# TODO: remove once hyperframe released a new version > 3.1.1 -# wrapper for deprecated name in old hyperframe release -frame.SettingsFrame.MAX_FRAME_SIZE = frame.SettingsFrame.SETTINGS_MAX_FRAME_SIZE -frame.SettingsFrame.MAX_HEADER_LIST_SIZE = frame.SettingsFrame.SETTINGS_MAX_HEADER_LIST_SIZE - class TCPHandler(object): diff --git a/setup.py b/setup.py index e842fd74..5bb17b19 100644 --- a/setup.py +++ b/setup.py @@ -19,8 +19,8 @@ deps = { "pyOpenSSL>=0.15.1, <0.16", "cryptography>=1.2.2, <1.3", "passlib>=1.6.5, <1.7", - "hpack>=2.0.1, <3.0", - "hyperframe>=3.1.1, <4.0", + "hpack>=2.1.0, <3.0", + "hyperframe>=3.2.0, <4.0", "six>=1.10.0, <1.11", "certifi>=2015.11.20.1", # no semver here - this should always be on the last release! "backports.ssl_match_hostname>=3.5.0.1, <3.6", -- cgit v1.2.3 From a188ae5ac55c4f9564d7590c827be9a7eb9afba4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 2 Feb 2016 18:15:55 +0100 Subject: allow creation of certs without CN --- netlib/certutils.py | 5 ++++- test/test_certutils.py | 20 ++++++++++++++------ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index e6d71c39..a0111381 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -101,7 +101,8 @@ def dummy_cert(privkey, cacert, commonname, sans): cert.gmtime_adj_notBefore(-3600 * 48) cert.gmtime_adj_notAfter(DEFAULT_EXP) cert.set_issuer(cacert.get_subject()) - cert.get_subject().CN = commonname + if commonname is not None: + cert.get_subject().CN = commonname cert.set_serial_number(int(time.time() * 10000)) if ss: cert.set_version(2) @@ -294,6 +295,8 @@ class CertStore(object): @staticmethod def asterisk_forms(dn): + if dn is None: + return [] parts = dn.split(b".") parts.reverse() curr_dn = b"" diff --git a/test/test_certutils.py b/test/test_certutils.py index 991d59d6..027dcc93 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -41,8 +41,12 @@ class TestCertStore: ca2 = certutils.CertStore.from_store(d, "test") assert ca2.get_cert(b"foo", []) - assert ca.default_ca.get_serial_number( - ) == ca2.default_ca.get_serial_number() + assert ca.default_ca.get_serial_number() == ca2.default_ca.get_serial_number() + + def test_create_no_common_name(self): + with tutils.tmpdir() as d: + ca = certutils.CertStore.from_store(d, "test") + assert ca.get_cert(None, [])[0].cn is None def test_create_tmp(self): with tutils.tmpdir() as d: @@ -54,10 +58,6 @@ class TestCertStore: r = ca.get_cert(b"*.foo.com", []) assert r[1] == ca.default_privatekey - def test_add_cert(self): - with tutils.tmpdir() as d: - certutils.CertStore.from_store(d, "test") - def test_sans(self): with tutils.tmpdir() as d: ca = certutils.CertStore.from_store(d, "test") @@ -105,6 +105,14 @@ class TestDummyCert: ) assert r.cn == b"foo.com" + r = certutils.dummy_cert( + ca.default_privatekey, + ca.default_ca, + None, + [] + ) + assert r.cn is None + class TestSSLCert: -- cgit v1.2.3 From 4bad98cfceb576ea285ed72580448256b47fb5c2 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 5 Feb 2016 23:39:48 +0100 Subject: use setup.y environment markers --- setup.py | 46 ++++++++++++++++++++++------------------------ 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/setup.py b/setup.py index 5bb17b19..a3691ab9 100644 --- a/setup.py +++ b/setup.py @@ -14,20 +14,6 @@ here = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(here, 'README.rst'), encoding='utf-8') as f: long_description = f.read() -deps = { - "pyasn1>=0.1.9, <0.2", - "pyOpenSSL>=0.15.1, <0.16", - "cryptography>=1.2.2, <1.3", - "passlib>=1.6.5, <1.7", - "hpack>=2.1.0, <3.0", - "hyperframe>=3.2.0, <4.0", - "six>=1.10.0, <1.11", - "certifi>=2015.11.20.1", # no semver here - this should always be on the last release! - "backports.ssl_match_hostname>=3.5.0.1, <3.6", -} -if sys.version_info < (3, 0): - deps.add("ipaddress>=1.0.15, <1.1") - setup( name="netlib", version=version.VERSION, @@ -57,18 +43,30 @@ setup( packages=find_packages(), include_package_data=True, zip_safe=False, - install_requires=list(deps), + install_requires=[ + "pyasn1>=0.1.9, <0.2", + "pyOpenSSL>=0.15.1, <0.16", + "cryptography>=1.2.2, <1.3", + "passlib>=1.6.5, <1.7", + "hpack>=2.1.0, <3.0", + "hyperframe>=3.2.0, <4.0", + "six>=1.10.0, <1.11", + "certifi>=2015.11.20.1", # no semver here - this should always be on the last release! + "backports.ssl_match_hostname>=3.5.0.1, <3.6", + ], extras_require={ + # Do not use a range operator here: https://bitbucket.org/pypa/setuptools/issues/380 + # Ubuntu Trusty and other still ship with setuptools < 17.1 + ':python_version == "2.7"': [ + "ipaddress>=1.0.15, <1.1", + ], 'dev': [ - "mock>=1.0.1", - "pytest>=2.8.0", - "pytest-xdist>=1.13.1", - "pytest-cov>=2.1.0", - "pytest-timeout>=1.0.0", - "coveralls>=0.4.1", - "autopep8>=1.0.3", - "autoflake>=0.6.6", - "wheel>=0.24.0", + "mock>=1.3.0, <1.4", + "pytest>=2.8.7, <2.9", + "pytest-xdist>=1.14, <1.15", + "pytest-cov>=2.2.1, <2.3", + "pytest-timeout>=1.0.0, <1.1", + "coveralls>=1.1, <1.2" ] }, ) -- cgit v1.2.3 From cbee3bdfa6b36aabf9b36412f8fa9b6b44371be7 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 6 Feb 2016 00:25:50 +0100 Subject: minor fixes --- MANIFEST.in | 1 + setup.cfg | 2 ++ setup.py | 2 +- test/tools/getcertnames | 27 +++++++++++++++++++++++++++ tools/getcertnames | 27 --------------------------- 5 files changed, 31 insertions(+), 28 deletions(-) create mode 100644 setup.cfg create mode 100644 test/tools/getcertnames delete mode 100755 tools/getcertnames diff --git a/MANIFEST.in b/MANIFEST.in index a68c043e..db0e2ed6 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,4 @@ include LICENSE CONTRIBUTORS README.rst graft test +prune test/tools recursive-exclude * *.pyc *.pyo *.swo *.swp \ No newline at end of file diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..3480374b --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[bdist_wheel] +universal=1 \ No newline at end of file diff --git a/setup.py b/setup.py index a3691ab9..bcaecad4 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ setup( "Topic :: Software Development :: Testing", "Topic :: Software Development :: Testing :: Traffic Generation", ], - packages=find_packages(), + packages=find_packages(exclude=["test", "test.*"]), include_package_data=True, zip_safe=False, install_requires=[ diff --git a/test/tools/getcertnames b/test/tools/getcertnames new file mode 100644 index 00000000..e33619f7 --- /dev/null +++ b/test/tools/getcertnames @@ -0,0 +1,27 @@ +#!/usr/bin/env python +import sys +sys.path.insert(0, "../../") +from netlib import tcp + + +def get_remote_cert(host, port, sni): + c = tcp.TCPClient((host, port)) + c.connect() + c.convert_to_ssl(sni=sni) + return c.cert + +if len(sys.argv) > 2: + port = int(sys.argv[2]) +else: + port = 443 +if len(sys.argv) > 3: + sni = sys.argv[3] +else: + sni = None + +cert = get_remote_cert(sys.argv[1], port, sni) +print "CN:", cert.cn +if cert.altnames: + print "SANs:", + for i in cert.altnames: + print "\t", i diff --git a/tools/getcertnames b/tools/getcertnames deleted file mode 100755 index e33619f7..00000000 --- a/tools/getcertnames +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env python -import sys -sys.path.insert(0, "../../") -from netlib import tcp - - -def get_remote_cert(host, port, sni): - c = tcp.TCPClient((host, port)) - c.connect() - c.convert_to_ssl(sni=sni) - return c.cert - -if len(sys.argv) > 2: - port = int(sys.argv[2]) -else: - port = 443 -if len(sys.argv) > 3: - sni = sys.argv[3] -else: - sni = None - -cert = get_remote_cert(sys.argv[1], port, sni) -print "CN:", cert.cn -if cert.altnames: - print "SANs:", - for i in cert.altnames: - print "\t", i -- cgit v1.2.3 From 8f8796f9d9d49e1e968cb8c48b09f26b2a11dcb2 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Feb 2016 00:40:55 +0100 Subject: expose OpenSSL's HAS_ALPN --- netlib/tcp.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 682db29a..85b4b0e2 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -31,6 +31,7 @@ else: socket_fileobject = socket.SocketIO EINTR = 4 +HAS_ALPN = OpenSSL._util.lib.Cryptography_HAS_ALPN # To enable all SSL methods use: SSLv23 # then add options to disable certain methods @@ -542,7 +543,7 @@ class _Connection(object): if log_ssl_key: context.set_info_callback(log_ssl_key) - if OpenSSL._util.lib.Cryptography_HAS_ALPN: + if HAS_ALPN: if alpn_protos is not None: # advertise application layer protocols context.set_alpn_protos(alpn_protos) @@ -696,7 +697,7 @@ class TCPClient(_Connection): return self.connection.gettimeout() def get_alpn_proto_negotiated(self): - if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: + if HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() else: return b"" @@ -802,7 +803,7 @@ class BaseHandler(_Connection): self.connection.settimeout(n) def get_alpn_proto_negotiated(self): - if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established: + if HAS_ALPN and self.ssl_established: return self.connection.get_alpn_proto_negotiated() else: return b"" -- cgit v1.2.3 From 4873547de3c65ba7c14cace4bca7b17368b2900d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Feb 2016 02:10:48 +0100 Subject: minor fixes --- netlib/http/headers.py | 2 +- netlib/http/request.py | 2 +- netlib/odict.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/netlib/http/headers.py b/netlib/http/headers.py index f64e6200..6eb9db92 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -194,7 +194,7 @@ class Headers(MutableMapping): return Headers(copy.copy(self.fields)) # Implement the StateObject protocol from mitmproxy - def get_state(self, short=False): + def get_state(self): return tuple(tuple(field) for field in self.fields) def load_state(self, state): diff --git a/netlib/http/request.py b/netlib/http/request.py index 5ebf21a5..6dabb189 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -292,7 +292,7 @@ class Request(Message): return None @multipart_form.setter - def multipart_form(self): + def multipart_form(self, value): raise NotImplementedError() # Legacy diff --git a/netlib/odict.py b/netlib/odict.py index 1124b23a..90317e5e 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -169,7 +169,7 @@ class ODict(object): return count # Implement the StateObject protocol from mitmproxy - def get_state(self, short=False): + def get_state(self): return [tuple(i) for i in self.lst] def load_state(self, state): -- cgit v1.2.3 From fe0ed63c4a3486402f65638b476149ebba752055 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Feb 2016 04:16:58 +0100 Subject: add Serializable ABC --- netlib/certutils.py | 22 +++++++++++++++++----- netlib/http/headers.py | 7 +++---- netlib/http/message.py | 33 ++++++++++++++++++++++++++++++--- netlib/http/request.py | 5 ++--- netlib/http/response.py | 5 ++--- netlib/odict.py | 10 ++++++---- netlib/tcp.py | 17 ++++++++++++++++- netlib/utils.py | 26 +++++++++++++++++++++++++- test/http/test_headers.py | 2 +- test/http/test_request.py | 2 +- test/http/test_response.py | 2 +- test/test_odict.py | 2 +- 12 files changed, 105 insertions(+), 28 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index a0111381..ecdc0624 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -13,6 +13,8 @@ from pyasn1.error import PyAsn1Error import OpenSSL # Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815 +from netlib.utils import Serializable + DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3 # Generated with "openssl dhparam". It's too slow to generate this on startup. DEFAULT_DHPARAM = b""" @@ -361,7 +363,7 @@ class _GeneralNames(univ.SequenceOf): constraint.ValueSizeConstraint(1, 1024) -class SSLCert(object): +class SSLCert(Serializable): def __init__(self, cert): """ @@ -375,15 +377,25 @@ class SSLCert(object): def __ne__(self, other): return not self.__eq__(other) + def get_state(self): + return self.to_pem() + + def set_state(self, state): + self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state) + + @classmethod + def from_state(cls, state): + cls.from_pem(state) + @classmethod - def from_pem(klass, txt): + def from_pem(cls, txt): x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) - return klass(x509) + return cls(x509) @classmethod - def from_der(klass, der): + def from_der(cls, der): pem = ssl.DER_cert_to_PEM_cert(der) - return klass.from_pem(pem) + return cls.from_pem(pem) def to_pem(self): return OpenSSL.crypto.dump_certificate( diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 6eb9db92..78404796 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -14,7 +14,7 @@ except ImportError: # pragma: nocover import six -from netlib.utils import always_byte_args, always_bytes +from netlib.utils import always_byte_args, always_bytes, Serializable if six.PY2: # pragma: nocover _native = lambda x: x @@ -27,7 +27,7 @@ else: _always_byte_args = always_byte_args("utf-8", "surrogateescape") -class Headers(MutableMapping): +class Headers(MutableMapping, Serializable): """ Header class which allows both convenient access to individual headers as well as direct access to the underlying raw data. Provides a full dictionary interface. @@ -193,11 +193,10 @@ class Headers(MutableMapping): def copy(self): return Headers(copy.copy(self.fields)) - # Implement the StateObject protocol from mitmproxy def get_state(self): return tuple(tuple(field) for field in self.fields) - def load_state(self, state): + def set_state(self, state): self.fields = [list(field) for field in state] @classmethod diff --git a/netlib/http/message.py b/netlib/http/message.py index 28f55fa2..3d65f93e 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -4,9 +4,10 @@ import warnings import six +from netlib.utils import Serializable +from .headers import Headers from .. import encoding, utils - CONTENT_MISSING = 0 if six.PY2: # pragma: nocover @@ -18,7 +19,7 @@ else: _always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape") -class MessageData(object): +class MessageData(Serializable): def __eq__(self, other): if isinstance(other, MessageData): return self.__dict__ == other.__dict__ @@ -27,8 +28,24 @@ class MessageData(object): def __ne__(self, other): return not self.__eq__(other) + def set_state(self, state): + for k, v in state.items(): + if k == "headers": + v = Headers.from_state(v) + setattr(self, k, v) + + def get_state(self): + state = vars(self).copy() + state["headers"] = state["headers"].get_state() + return state + + @classmethod + def from_state(cls, state): + state["headers"] = Headers.from_state(state["headers"]) + return cls(**state) + -class Message(object): +class Message(Serializable): def __init__(self, data): self.data = data @@ -40,6 +57,16 @@ class Message(object): def __ne__(self, other): return not self.__eq__(other) + def get_state(self): + return self.data.get_state() + + def set_state(self, state): + self.data.set_state(state) + + @classmethod + def from_state(cls, state): + return cls(**state) + @property def headers(self): """ diff --git a/netlib/http/request.py b/netlib/http/request.py index 6dabb189..0e0f88ce 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -16,9 +16,8 @@ from .message import Message, _native, _always_bytes, MessageData class RequestData(MessageData): def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None, timestamp_start=None, timestamp_end=None): - if not headers: - headers = Headers() - assert isinstance(headers, Headers) + if not isinstance(headers, Headers): + headers = Headers(headers) self.first_line_format = first_line_format self.method = method diff --git a/netlib/http/response.py b/netlib/http/response.py index 66e5ded6..8f4d6215 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -12,9 +12,8 @@ from ..odict import ODict class ResponseData(MessageData): def __init__(self, http_version, status_code, reason=None, headers=None, content=None, timestamp_start=None, timestamp_end=None): - if not headers: - headers = Headers() - assert isinstance(headers, Headers) + if not isinstance(headers, Headers): + headers = Headers(headers) self.http_version = http_version self.status_code = status_code diff --git a/netlib/odict.py b/netlib/odict.py index 90317e5e..1e6e381a 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -3,6 +3,8 @@ import re import copy import six +from .utils import Serializable + def safe_subn(pattern, repl, target, *args, **kwargs): """ @@ -13,7 +15,7 @@ def safe_subn(pattern, repl, target, *args, **kwargs): return re.subn(str(pattern), str(repl), target, *args, **kwargs) -class ODict(object): +class ODict(Serializable): """ A dictionary-like object for managing ordered (key, value) data. Think @@ -172,12 +174,12 @@ class ODict(object): def get_state(self): return [tuple(i) for i in self.lst] - def load_state(self, state): + def set_state(self, state): self.lst = [list(i) for i in state] @classmethod - def from_state(klass, state): - return klass([list(i) for i in state]) + def from_state(cls, state): + return cls([list(i) for i in state]) class ODictCaseless(ODict): diff --git a/netlib/tcp.py b/netlib/tcp.py index 85b4b0e2..2e91a70c 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -16,6 +16,7 @@ import six import OpenSSL from OpenSSL import SSL +from netlib.utils import Serializable from . import certutils, version_check # This is a rather hackish way to make sure that @@ -298,7 +299,7 @@ class Reader(_FileLike): raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") -class Address(object): +class Address(Serializable): """ This class wraps an IPv4/IPv6 tuple to provide named attributes and @@ -309,6 +310,20 @@ class Address(object): self.address = tuple(address) self.use_ipv6 = use_ipv6 + def get_state(self): + return { + "address": self.address, + "use_ipv6": self.use_ipv6 + } + + def set_state(self, state): + self.address = state["address"] + self.use_ipv6 = state["use_ipv6"] + + @classmethod + def from_state(cls, state): + return Address(**state) + @classmethod def wrap(cls, t): if isinstance(t, cls): diff --git a/netlib/utils.py b/netlib/utils.py index 1c1b617a..a0c2035c 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,14 +1,38 @@ from __future__ import absolute_import, print_function, division import os.path import re -import string import codecs import unicodedata +from abc import ABCMeta, abstractmethod + import six from six.moves import urllib import hyperframe + +@six.add_metaclass(ABCMeta) +class Serializable(object): + """ + ABC for Python's pickle protocol __getstate__ and __setstate__. + """ + + @classmethod + @abstractmethod + def from_state(cls, state): + obj = cls.__new__(cls) + obj.__setstate__(state) + return obj + + @abstractmethod + def get_state(self): + raise NotImplementedError() + + @abstractmethod + def set_state(self, state): + raise NotImplementedError() + + def always_bytes(unicode_or_bytes, *encode_args): if isinstance(unicode_or_bytes, six.text_type): return unicode_or_bytes.encode(*encode_args) diff --git a/test/http/test_headers.py b/test/http/test_headers.py index 8bddc0b2..d50fee3e 100644 --- a/test/http/test_headers.py +++ b/test/http/test_headers.py @@ -148,5 +148,5 @@ class TestHeaders(object): headers2 = Headers() assert headers != headers2 - headers2.load_state(headers.get_state()) + headers2.set_state(headers.get_state()) assert headers == headers2 diff --git a/test/http/test_request.py b/test/http/test_request.py index 8cf69ffe..1deee387 100644 --- a/test/http/test_request.py +++ b/test/http/test_request.py @@ -12,7 +12,7 @@ from .test_message import _test_decoded_attr, _test_passthrough_attr class TestRequestData(object): def test_init(self): - with raises(AssertionError): + with raises(ValueError): treq(headers="foobar") assert isinstance(treq(headers=None).headers, Headers) diff --git a/test/http/test_response.py b/test/http/test_response.py index a1f4abd7..c7d90b16 100644 --- a/test/http/test_response.py +++ b/test/http/test_response.py @@ -8,7 +8,7 @@ from .test_message import _test_passthrough_attr, _test_decoded_attr class TestResponseData(object): def test_init(self): - with raises(AssertionError): + with raises(ValueError): tresp(headers="foobar") assert isinstance(tresp(headers=None).headers, Headers) diff --git a/test/test_odict.py b/test/test_odict.py index 88197026..f0985ef6 100644 --- a/test/test_odict.py +++ b/test/test_odict.py @@ -24,7 +24,7 @@ class TestODict(object): nd = odict.ODict.from_state(state) assert nd == od b = odict.ODict() - b.load_state(state) + b.set_state(state) assert b == od def test_in_any(self): -- cgit v1.2.3 From 173ff0b235cdb45a8923f313807d9804830c2a2b Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Feb 2016 04:28:49 +0100 Subject: fix py3 compat --- netlib/certutils.py | 3 ++- netlib/http/message.py | 5 ++--- netlib/tcp.py | 5 ++--- test/http/test_request.py | 2 +- test/http/test_response.py | 4 +++- 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index ecdc0624..616a778e 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -12,8 +12,9 @@ from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL +from .utils import Serializable + # Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815 -from netlib.utils import Serializable DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3 # Generated with "openssl dhparam". It's too slow to generate this on startup. diff --git a/netlib/http/message.py b/netlib/http/message.py index 3d65f93e..e3d8ce37 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -4,7 +4,6 @@ import warnings import six -from netlib.utils import Serializable from .headers import Headers from .. import encoding, utils @@ -19,7 +18,7 @@ else: _always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape") -class MessageData(Serializable): +class MessageData(utils.Serializable): def __eq__(self, other): if isinstance(other, MessageData): return self.__dict__ == other.__dict__ @@ -45,7 +44,7 @@ class MessageData(Serializable): return cls(**state) -class Message(Serializable): +class Message(utils.Serializable): def __init__(self, data): self.data = data diff --git a/netlib/tcp.py b/netlib/tcp.py index 2e91a70c..c8548aea 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -16,8 +16,7 @@ import six import OpenSSL from OpenSSL import SSL -from netlib.utils import Serializable -from . import certutils, version_check +from . import certutils, version_check, utils # This is a rather hackish way to make sure that # the latest version of pyOpenSSL is actually installed. @@ -299,7 +298,7 @@ class Reader(_FileLike): raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") -class Address(Serializable): +class Address(utils.Serializable): """ This class wraps an IPv4/IPv6 tuple to provide named attributes and diff --git a/test/http/test_request.py b/test/http/test_request.py index 1deee387..900b2cd1 100644 --- a/test/http/test_request.py +++ b/test/http/test_request.py @@ -12,7 +12,7 @@ from .test_message import _test_decoded_attr, _test_passthrough_attr class TestRequestData(object): def test_init(self): - with raises(ValueError): + with raises(ValueError if six.PY2 else TypeError): treq(headers="foobar") assert isinstance(treq(headers=None).headers, Headers) diff --git a/test/http/test_response.py b/test/http/test_response.py index c7d90b16..14588000 100644 --- a/test/http/test_response.py +++ b/test/http/test_response.py @@ -1,5 +1,7 @@ from __future__ import absolute_import, print_function, division +import six + from netlib.http import Headers from netlib.odict import ODict, ODictCaseless from netlib.tutils import raises, tresp @@ -8,7 +10,7 @@ from .test_message import _test_passthrough_attr, _test_decoded_attr class TestResponseData(object): def test_init(self): - with raises(ValueError): + with raises(ValueError if six.PY2 else TypeError): tresp(headers="foobar") assert isinstance(tresp(headers=None).headers, Headers) -- cgit v1.2.3 From 655b521749efd5a600d342a1d95b67d32da280a8 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Feb 2016 04:33:10 +0100 Subject: fix docstrings --- netlib/utils.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/netlib/utils.py b/netlib/utils.py index a0c2035c..d2fc7195 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -14,22 +14,29 @@ import hyperframe @six.add_metaclass(ABCMeta) class Serializable(object): """ - ABC for Python's pickle protocol __getstate__ and __setstate__. + Abstract Base Class that defines an API to save an object's state and restore it later on. """ @classmethod @abstractmethod def from_state(cls, state): - obj = cls.__new__(cls) - obj.__setstate__(state) - return obj + """ + Create a new object from the given state. + """ + raise NotImplementedError() @abstractmethod def get_state(self): + """ + Retrieve object state. + """ raise NotImplementedError() @abstractmethod def set_state(self, state): + """ + Set object state to the given state. + """ raise NotImplementedError() -- cgit v1.2.3 From ead9b0ab8c399feeb25e0851f2dadf654acf51f5 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Feb 2016 15:09:25 +0100 Subject: fix http version string --- netlib/http/http2/connections.py | 4 ++-- test/http/http2/test_connections.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index 5e877286..52fa7193 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -129,7 +129,7 @@ class HTTP2Protocol(object): host.encode('ascii'), port, path.encode('ascii'), - b'2.0', + b"HTTP/2.0", headers, body, timestamp_start, @@ -171,7 +171,7 @@ class HTTP2Protocol(object): timestamp_end = None response = Response( - b'2.0', + b"HTTP/2.0", int(headers.get(':status', 502)), b'', headers, diff --git a/test/http/http2/test_connections.py b/test/http/http2/test_connections.py index 22a43266..a115fc7c 100644 --- a/test/http/http2/test_connections.py +++ b/test/http/http2/test_connections.py @@ -414,7 +414,7 @@ class TestReadResponse(tservers.ServerTestBase): resp = protocol.read_response(NotImplemented, stream_id=42) - assert resp.http_version == '2.0' + assert resp.http_version == "HTTP/2.0" assert resp.status_code == 200 assert resp.msg == '' assert resp.headers.fields == [[b':status', b'200'], [b'etag', b'foobar']] @@ -441,7 +441,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): resp = protocol.read_response(NotImplemented, stream_id=42) assert resp.stream_id == 42 - assert resp.http_version == '2.0' + assert resp.http_version == "HTTP/2.0" assert resp.status_code == 200 assert resp.msg == '' assert resp.headers.fields == [[b':status', b'200'], [b'etag', b'foobar']] @@ -459,7 +459,7 @@ class TestAssembleRequest(object): b'', b'', b'/', - b'2.0', + b"HTTP/2.0", None, None, )) @@ -474,7 +474,7 @@ class TestAssembleRequest(object): b'', b'', b'/', - b'2.0', + b"HTTP/2.0", None, None, ) @@ -491,7 +491,7 @@ class TestAssembleRequest(object): b'', b'', b'/', - b'2.0', + b"HTTP/2.0", http.Headers([(b'foo', b'bar')]), b'foobar', )) @@ -507,7 +507,7 @@ class TestAssembleResponse(object): def test_simple(self): bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( - b'2.0', + b"HTTP/2.0", 200, )) assert len(bytes) == 1 @@ -516,7 +516,7 @@ class TestAssembleResponse(object): def test_with_stream_id(self): resp = http.Response( - b'2.0', + b"HTTP/2.0", 200, ) resp.stream_id = 0x42 @@ -527,7 +527,7 @@ class TestAssembleResponse(object): def test_with_body(self): bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( - b'2.0', + b"HTTP/2.0", 200, b'', http.Headers(foo=b"bar"), -- cgit v1.2.3 From 1dcb8b14acc3ba1f474ee9673bf4271e576fab9f Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Feb 2016 15:09:29 +0100 Subject: bump version --- netlib/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netlib/version.py b/netlib/version.py index 7a68ca39..8ff869cd 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 15, 1) +IVERSION = (0, 16) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3 From 1af231fe01350d056f3b6c14a695caf9eadda520 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 8 Feb 2016 20:26:17 +0100 Subject: change ci notifications --- .travis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 529c7ed3..651fdae8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -79,8 +79,8 @@ notifications: on_failure: always slack: rooms: - - mitmproxy:YaDGC9Gt9TEM7o8zkC2OLNsu - on_success: change + - mitmproxy:YaDGC9Gt9TEM7o8zkC2OLNsu#ci + on_success: always on_failure: always # exclude cryptography from cache -- cgit v1.2.3 From aafa69a73829a7ec291a2d6fa0c4522caf287d17 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 14 Feb 2016 17:25:30 +0100 Subject: bump version --- CONTRIBUTORS | 6 ++++-- netlib/version.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/CONTRIBUTORS b/CONTRIBUTORS index a43d31c9..3a4b9b46 100644 --- a/CONTRIBUTORS +++ b/CONTRIBUTORS @@ -1,10 +1,11 @@ 253 Aldo Cortesi - 212 Maximilian Hils - 109 Thomas Kriechbaumer + 230 Maximilian Hils + 123 Thomas Kriechbaumer 8 Chandler Abraham 8 Kyle Morton 5 Sam Cleveland 3 Benjamin Lee + 3 Sandor Nemes 2 Brad Peabody 2 Israel Nir 2 Matthias Urlichs @@ -12,6 +13,7 @@ 2 Sean Coates 1 Andrey Plotnikov 1 Bradley Baetz + 1 Felix Yan 1 M. Utku Altinkaya 1 Paul 1 Pritam Baral diff --git a/netlib/version.py b/netlib/version.py index 8ff869cd..bc35c30f 100644 --- a/netlib/version.py +++ b/netlib/version.py @@ -1,6 +1,6 @@ from __future__ import (absolute_import, print_function, division) -IVERSION = (0, 16) +IVERSION = (0, 17) VERSION = ".".join(str(i) for i in IVERSION) MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) NAME = "netlib" -- cgit v1.2.3