diff options
-rw-r--r-- | netlib/http/http1/protocol.py | 14 | ||||
-rw-r--r-- | netlib/http/semantics.py | 43 | ||||
-rw-r--r-- | netlib/odict.py | 3 | ||||
-rw-r--r-- | netlib/tutils.py | 4 | ||||
-rw-r--r-- | netlib/utils.py | 56 | ||||
-rw-r--r-- | test/http/http1/test_protocol.py | 265 | ||||
-rw-r--r-- | test/http/http2/test_protocol.py | 247 | ||||
-rw-r--r-- | test/http/test_exceptions.py | 29 | ||||
-rw-r--r-- | test/http/test_semantics.py | 389 | ||||
-rw-r--r-- | test/test_utils.py | 31 | ||||
-rw-r--r-- | test/websockets/test_websockets.py | 7 |
11 files changed, 778 insertions, 310 deletions
diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 2e85a762..31e9cc85 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -360,20 +360,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): @classmethod - def request_preamble(self, method, resource, http_major="1", http_minor="1"): - return '%s %s HTTP/%s.%s' % ( - method, resource, http_major, http_minor - ) - - - @classmethod - def response_preamble(self, code, message=None, http_major="1", http_minor="1"): - if message is None: - message = status_codes.RESPONSES.get(code) - return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) - - - @classmethod def has_chunked_encoding(self, headers): return "chunked" in [ i.lower() for i in utils.get_header_tokens(headers, "transfer-encoding") diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index e7ae2b5f..974fe6e6 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -7,7 +7,7 @@ import urllib import urlparse from .. import utils, odict -from . import cookies +from . import cookies, exceptions from netlib import utils, encoding HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" @@ -18,10 +18,10 @@ CONTENT_MISSING = 0 class ProtocolMixin(object): - def read_request(self): + def read_request(self, *args, **kwargs): # pragma: no cover raise NotImplemented - def read_response(self): + def read_response(self, *args, **kwargs): # pragma: no cover raise NotImplemented def assemble(self, message): @@ -32,14 +32,23 @@ class ProtocolMixin(object): else: raise ValueError("HTTP message not supported.") - def assemble_request(self, request): + def assemble_request(self, request): # pragma: no cover raise NotImplemented - def assemble_response(self, response): + def assemble_response(self, response): # pragma: no cover raise NotImplemented class Request(object): + # This list is adopted legacy code. + # We probably don't need to strip off keep-alive. + _headers_to_strip_off = [ + 'Proxy-Connection', + 'Keep-Alive', + 'Connection', + 'Transfer-Encoding', + 'Upgrade', + ] def __init__( self, @@ -71,7 +80,6 @@ class Request(object): self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end - def __eq__(self, other): try: self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] @@ -114,7 +122,7 @@ class Request(object): self.httpversion[1], ) else: - raise http.HttpError(400, "Invalid request form") + raise exceptions.HttpError(400, "Invalid request form") def anticache(self): """ @@ -143,7 +151,7 @@ class Request(object): if self.headers["accept-encoding"]: self.headers["accept-encoding"] = [ ', '.join( - e for e in encoding.ENCODINGS if e in self.headers["accept-encoding"][0])] + e for e in encoding.ENCODINGS if e in self.headers.get_first("accept-encoding"))] def update_host_header(self): """ @@ -317,12 +325,12 @@ class Request(object): self.scheme, self.host, self.port, self.path = parts @property - def content(self): + def content(self): # pragma: no cover # TODO: remove deprecated getter return self.body @content.setter - def content(self, content): + def content(self, content): # pragma: no cover # TODO: remove deprecated setter self.body = content @@ -343,6 +351,11 @@ class EmptyRequest(Request): class Response(object): + _headers_to_strip_off = [ + 'Proxy-Connection', + 'Alternate-Protocol', + 'Alt-Svc', + ] def __init__( self, @@ -368,7 +381,6 @@ class Response(object): self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end - def __eq__(self, other): try: self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] @@ -393,7 +405,6 @@ class Response(object): size=size ) - def get_cookies(self): """ Get the contents of all Set-Cookie headers. @@ -430,21 +441,21 @@ class Response(object): self.headers["Set-Cookie"] = values @property - def content(self): + def content(self): # pragma: no cover # TODO: remove deprecated getter return self.body @content.setter - def content(self, content): + def content(self, content): # pragma: no cover # TODO: remove deprecated setter self.body = content @property - def code(self): + def code(self): # pragma: no cover # TODO: remove deprecated getter return self.status_code @code.setter - def code(self, code): + def code(self, code): # pragma: no cover # TODO: remove deprecated setter self.status_code = code diff --git a/netlib/odict.py b/netlib/odict.py index d02de08d..11d5d52a 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -91,8 +91,9 @@ class ODict(object): self.lst = self._filter_lst(k, self.lst) def __contains__(self, k): + k = self._kconv(k) for i in self.lst: - if self._kconv(i[0]) == self._kconv(k): + if self._kconv(i[0]) == k: return True return False diff --git a/netlib/tutils.py b/netlib/tutils.py index 5018b9e8..3c471d0d 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -119,7 +119,7 @@ def tresp(content="message"): "OK", headers, content, - time.time(), - time.time(), + timestamp_start=time.time(), + timestamp_end=time.time(), ) return resp diff --git a/netlib/utils.py b/netlib/utils.py index 35ea0ec7..2dfcafc6 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -4,6 +4,7 @@ import cgi import urllib import urlparse import string +import re def isascii(s): @@ -239,3 +240,58 @@ def urldecode(s): Takes a urlencoded string and returns a list of (key, value) tuples. """ return cgi.parse_qsl(s, keep_blank_values=True) + + +def parse_content_type(c): + """ + A simple parser for content-type values. Returns a (type, subtype, + parameters) tuple, where type and subtype are strings, and parameters + is a dict. If the string could not be parsed, return None. + + E.g. the following string: + + text/html; charset=UTF-8 + + Returns: + + ("text", "html", {"charset": "UTF-8"}) + """ + parts = c.split(";", 1) + ts = parts[0].split("/", 1) + if len(ts) != 2: + return None + d = {} + if len(parts) == 2: + for i in parts[1].split(";"): + clause = i.split("=", 1) + if len(clause) == 2: + d[clause[0].strip()] = clause[1].strip() + return ts[0].lower(), ts[1].lower(), d + + +def multipartdecode(hdrs, content): + """ + Takes a multipart boundary encoded string and returns list of (key, value) tuples. + """ + v = hdrs.get_first("content-type") + if v: + v = parse_content_type(v) + if not v: + return [] + boundary = v[2].get("boundary") + if not boundary: + return [] + + rx = re.compile(r'\bname="([^"]+)"') + r = [] + + for i in content.split("--" + boundary): + parts = i.splitlines() + if len(parts) > 1 and parts[0][0:2] != "--": + match = rx.search(parts[1]) + if match: + key = match.group(1) + value = "".join(parts[3 + parts[2:].index(""):]) + r.append((key, value)) + return r + return [] diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index e3c3ff43..ff70b87d 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -3,16 +3,40 @@ import textwrap import binascii from netlib import http, odict, tcp, tutils +from netlib.http import semantics from netlib.http.http1 import HTTP1Protocol from ... import tservers -def mock_protocol(data='', chunked=False): +class NoContentLengthHTTPHandler(tcp.BaseHandler): + def handle(self): + self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n") + self.wfile.flush() + + +def mock_protocol(data=''): rfile = cStringIO.StringIO(data) wfile = cStringIO.StringIO() return HTTP1Protocol(rfile=rfile, wfile=wfile) +def match_http_string(data): + return textwrap.dedent(data).strip().replace('\n', '\r\n') + + +def test_stripped_chunked_encoding_no_content(): + """ + https://github.com/mitmproxy/mitmproxy/issues/186 + """ + + r = tutils.treq(content="") + r.headers["Transfer-Encoding"] = ["chunked"] + assert "Content-Length" in mock_protocol()._assemble_request_headers(r) + + r = tutils.tresp(content="") + r.headers["Transfer-Encoding"] = ["chunked"] + assert "Content-Length" in mock_protocol()._assemble_response_headers(r) + def test_has_chunked_encoding(): h = odict.ODictCaseless() @@ -75,7 +99,6 @@ def test_connection_close(): assert HTTP1Protocol.connection_close((1, 1), h) - def test_read_http_body_request(): h = odict.ODictCaseless() data = "testing" @@ -85,7 +108,7 @@ def test_read_http_body_request(): def test_read_http_body_response(): h = odict.ODictCaseless() data = "testing" - assert mock_protocol(data, chunked=True).read_http_body(h, None, "GET", 200, False) == "testing" + assert mock_protocol(data).read_http_body(h, None, "GET", 200, False) == "testing" def test_read_http_body(): @@ -129,13 +152,13 @@ def test_read_http_body(): # test no content length: limit > actual content h = odict.ODictCaseless() data = "testing" - assert len(mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False)) == 7 + assert len(mock_protocol(data).read_http_body(h, 100, "GET", 200, False)) == 7 # test no content length: limit < actual content data = "testing" tutils.raises( http.HttpError, - mock_protocol(data, chunked=True).read_http_body, + mock_protocol(data).read_http_body, h, 4, "GET", 200, False ) @@ -143,7 +166,7 @@ def test_read_http_body(): h = odict.ODictCaseless() h["transfer-encoding"] = ["chunked"] data = "5\r\naaaaa\r\n0\r\n\r\n" - assert mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False) == "aaaaa" + assert mock_protocol(data).read_http_body(h, 100, "GET", 200, False) == "aaaaa" def test_expected_http_body_size(): @@ -167,6 +190,13 @@ def test_expected_http_body_size(): assert HTTP1Protocol.expected_http_body_size(h, True, "GET", None) == 0 +def test_get_request_line(): + data = "\nfoo" + p = mock_protocol(data) + assert p._get_request_line() == "foo" + assert not p._get_request_line() + + def test_parse_http_protocol(): assert HTTP1Protocol._parse_http_protocol("HTTP/1.1") == (1, 1) assert HTTP1Protocol._parse_http_protocol("HTTP/0.0") == (0, 0) @@ -269,96 +299,7 @@ class TestReadHeaders: assert self._read(data) is None -class NoContentLengthHTTPHandler(tcp.BaseHandler): - - def handle(self): - self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n") - self.wfile.flush() - - -class TestReadResponseNoContentLength(tservers.ServerTestBase): - handler = NoContentLengthHTTPHandler - - def test_no_content_length(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - resp = HTTP1Protocol(c).read_response("GET", None) - assert resp.body == "bar\r\n\r\n" - - -def test_read_response(): - def tst(data, method, body_size_limit, include_body=True): - data = textwrap.dedent(data) - return mock_protocol(data).read_response( - method, body_size_limit, include_body=include_body - ) - - tutils.raises("server disconnect", tst, "", "GET", None) - tutils.raises("invalid server response", tst, "foo", "GET", None) - data = """ - HTTP/1.1 200 OK - """ - assert tst(data, "GET", None) == http.Response( - (1, 1), 200, 'OK', odict.ODictCaseless(), '' - ) - data = """ - HTTP/1.1 200 - """ - assert tst(data, "GET", None) == http.Response( - (1, 1), 200, '', odict.ODictCaseless(), '' - ) - data = """ - HTTP/x 200 OK - """ - tutils.raises("invalid http version", tst, data, "GET", None) - data = """ - HTTP/1.1 xx OK - """ - tutils.raises("invalid server response", tst, data, "GET", None) - - data = """ - HTTP/1.1 100 CONTINUE - - HTTP/1.1 200 OK - """ - assert tst(data, "GET", None) == http.Response( - (1, 1), 100, 'CONTINUE', odict.ODictCaseless(), '' - ) - - data = """ - HTTP/1.1 200 OK - Content-Length: 3 - - foo - """ - assert tst(data, "GET", None).body == 'foo' - assert tst(data, "HEAD", None).body == '' - - data = """ - HTTP/1.1 200 OK - \tContent-Length: 3 - - foo - """ - tutils.raises("invalid headers", tst, data, "GET", None) - - data = """ - HTTP/1.1 200 OK - Content-Length: 3 - - foo - """ - assert tst(data, "GET", None, include_body=False).body is None - - -def test_get_request_line(): - data = "\nfoo" - p = mock_protocol(data) - assert p._get_request_line() == "foo" - assert not p._get_request_line() - - -class TestReadRequest(): +class TestReadRequest(object): def tst(self, data, **kwargs): return mock_protocol(data).read_request(**kwargs) @@ -385,6 +326,10 @@ class TestReadRequest(): "\r\n" ) + def test_empty(self): + v = self.tst("", allow_empty=True) + assert isinstance(v, semantics.EmptyRequest) + def test_asterisk_form_in(self): v = self.tst("OPTIONS * HTTP/1.1") assert v.form_in == "relative" @@ -427,3 +372,131 @@ class TestReadRequest(): assert p.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" assert v.body == "foo" assert p.tcp_handler.rfile.read(3) == "bar" + + +class TestReadResponse(object): + def tst(self, data, method, body_size_limit, include_body=True): + data = textwrap.dedent(data) + return mock_protocol(data).read_response( + method, body_size_limit, include_body=include_body + ) + + def test_errors(self): + tutils.raises("server disconnect", self.tst, "", "GET", None) + tutils.raises("invalid server response", self.tst, "foo", "GET", None) + + def test_simple(self): + data = """ + HTTP/1.1 200 + """ + assert self.tst(data, "GET", None) == http.Response( + (1, 1), 200, '', odict.ODictCaseless(), '' + ) + + def test_simple_message(self): + data = """ + HTTP/1.1 200 OK + """ + assert self.tst(data, "GET", None) == http.Response( + (1, 1), 200, 'OK', odict.ODictCaseless(), '' + ) + + def test_invalid_http_version(self): + data = """ + HTTP/x 200 OK + """ + tutils.raises("invalid http version", self.tst, data, "GET", None) + + def test_invalid_status_code(self): + data = """ + HTTP/1.1 xx OK + """ + tutils.raises("invalid server response", self.tst, data, "GET", None) + + def test_valid_with_continue(self): + data = """ + HTTP/1.1 100 CONTINUE + + HTTP/1.1 200 OK + """ + assert self.tst(data, "GET", None) == http.Response( + (1, 1), 100, 'CONTINUE', odict.ODictCaseless(), '' + ) + + def test_simple_body(self): + data = """ + HTTP/1.1 200 OK + Content-Length: 3 + + foo + """ + assert self.tst(data, "GET", None).body == 'foo' + assert self.tst(data, "HEAD", None).body == '' + + def test_invalid_headers(self): + data = """ + HTTP/1.1 200 OK + \tContent-Length: 3 + + foo + """ + tutils.raises("invalid headers", self.tst, data, "GET", None) + + def test_without_body(self): + data = """ + HTTP/1.1 200 OK + Content-Length: 3 + + foo + """ + assert self.tst(data, "GET", None, include_body=False).body is None + + +class TestReadResponseNoContentLength(tservers.ServerTestBase): + handler = NoContentLengthHTTPHandler + + def test_no_content_length(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + resp = HTTP1Protocol(c).read_response("GET", None) + assert resp.body == "bar\r\n\r\n" + + +class TestAssembleRequest(object): + def test_simple(self): + req = tutils.treq() + b = HTTP1Protocol().assemble_request(req) + assert b == match_http_string(""" + GET /path HTTP/1.1 + header: qvalue + Host: address:22 + Content-Length: 7 + + content""") + + def test_body_missing(self): + req = tutils.treq(content=semantics.CONTENT_MISSING) + tutils.raises(http.HttpError, HTTP1Protocol().assemble_request, req) + + def test_not_a_request(self): + tutils.raises(AssertionError, HTTP1Protocol().assemble_request, 'foo') + + +class TestAssembleResponse(object): + def test_simple(self): + resp = tutils.tresp() + b = HTTP1Protocol().assemble_response(resp) + print(b) + assert b == match_http_string(""" + HTTP/1.1 200 OK + header_response: svalue + Content-Length: 7 + + message""") + + def test_body_missing(self): + resp = tutils.tresp(content=semantics.CONTENT_MISSING) + tutils.raises(http.HttpError, HTTP1Protocol().assemble_response, resp) + + def test_not_a_request(self): + tutils.raises(AssertionError, HTTP1Protocol().assemble_response, 'foo') diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 8a27bbb1..3044179f 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -1,10 +1,25 @@ import OpenSSL +import mock from netlib import tcp, odict, http, tutils from netlib.http import http2 +from netlib.http.http2 import HTTP2Protocol from netlib.http.http2.frame import * from ... import tservers +class TestTCPHandlerWrapper: + def test_wrapped(self): + h = http2.TCPHandler(rfile='foo', wfile='bar') + p = HTTP2Protocol(h) + assert p.tcp_handler.rfile == 'foo' + assert p.tcp_handler.wfile == 'bar' + + def test_direct(self): + p = HTTP2Protocol(rfile='foo', wfile='bar') + assert isinstance(p.tcp_handler, http2.TCPHandler) + assert p.tcp_handler.rfile == 'foo' + assert p.tcp_handler.wfile == 'bar' + class EchoHandler(tcp.BaseHandler): sni = None @@ -16,10 +31,40 @@ class EchoHandler(tcp.BaseHandler): self.wfile.flush() +class TestProtocol: + @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface") + @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface") + def test_perform_connection_preface(self, mock_client_method, mock_server_method): + protocol = HTTP2Protocol(is_server=False) + protocol.connection_preface_performed = True + + protocol.perform_connection_preface() + assert not mock_client_method.called + assert not mock_server_method.called + + protocol.perform_connection_preface(force=True) + assert mock_client_method.called + assert not mock_server_method.called + + @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface") + @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface") + def test_perform_connection_preface_server(self, mock_client_method, mock_server_method): + protocol = HTTP2Protocol(is_server=True) + protocol.connection_preface_performed = True + + protocol.perform_connection_preface() + assert not mock_client_method.called + assert not mock_server_method.called + + protocol.perform_connection_preface(force=True) + assert not mock_client_method.called + assert mock_server_method.called + + class TestCheckALPNMatch(tservers.ServerTestBase): handler = EchoHandler ssl = dict( - alpn_select=http2.HTTP2Protocol.ALPN_PROTO_H2, + alpn_select=HTTP2Protocol.ALPN_PROTO_H2, ) if OpenSSL._util.lib.Cryptography_HAS_ALPN: @@ -27,8 +72,8 @@ class TestCheckALPNMatch(tservers.ServerTestBase): def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) - protocol = http2.HTTP2Protocol(c) + c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2]) + protocol = HTTP2Protocol(c) assert protocol.check_alpn() @@ -43,8 +88,8 @@ class TestCheckALPNMismatch(tservers.ServerTestBase): def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) - protocol = http2.HTTP2Protocol(c) + c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2]) + protocol = HTTP2Protocol(c) tutils.raises(NotImplementedError, protocol.check_alpn) @@ -76,8 +121,13 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase): def test_perform_server_connection_preface(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) + + assert not protocol.connection_preface_performed protocol.perform_server_connection_preface() + assert protocol.connection_preface_performed + + tutils.raises(tcp.NetLibIncomplete, protocol.perform_server_connection_preface, force=True) class TestPerformClientConnectionPreface(tservers.ServerTestBase): @@ -107,13 +157,16 @@ class TestPerformClientConnectionPreface(tservers.ServerTestBase): def test_perform_client_connection_preface(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) + + assert not protocol.connection_preface_performed protocol.perform_client_connection_preface() + assert protocol.connection_preface_performed class TestClientStreamIds(): c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) def test_client_stream_ids(self): assert self.protocol.current_stream_id is None @@ -127,7 +180,7 @@ class TestClientStreamIds(): class TestServerStreamIds(): c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c, is_server=True) + protocol = HTTP2Protocol(c, is_server=True) def test_server_stream_ids(self): assert self.protocol.current_stream_id is None @@ -154,7 +207,7 @@ class TestApplySettings(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) protocol._apply_settings({ SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 'foo', @@ -182,13 +235,13 @@ class TestCreateHeaders(): (b':scheme', b'https'), (b'foo', b'bar')] - bytes = http2.HTTP2Protocol(self.c)._create_headers( + bytes = HTTP2Protocol(self.c)._create_headers( headers, 1, end_stream=True) assert b''.join(bytes) ==\ '000014010500000001824488355217caf3a69a3f87408294e7838c767f'\ .decode('hex') - bytes = http2.HTTP2Protocol(self.c)._create_headers( + bytes = HTTP2Protocol(self.c)._create_headers( headers, 1, end_stream=False) assert b''.join(bytes) ==\ '000014010400000001824488355217caf3a69a3f87408294e7838c767f'\ @@ -199,7 +252,7 @@ class TestCreateHeaders(): class TestCreateBody(): c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) def test_create_body_empty(self): bytes = self.protocol._create_body(b'', 1) @@ -215,41 +268,30 @@ class TestCreateBody(): # TODO: add test for too large frames -class TestAssembleRequest(): - c = tcp.TCPClient(("127.0.0.1", 0)) +class TestReadRequest(tservers.ServerTestBase): + class handler(tcp.BaseHandler): - def test_assemble_request_simple(self): - bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request( - '', - 'GET', - 'https', - '', - '', - '/', - (2, 0), - None, - None, - )) - assert len(bytes) == 1 - assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') + def handle(self): + self.wfile.write( + b'000003010400000001828487'.decode('hex')) + self.wfile.write( + b'000006000100000001666f6f626172'.decode('hex')) + self.wfile.flush() - def test_assemble_request_with_body(self): - bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request( - '', - 'GET', - 'https', - '', - '', - '/', - (2, 0), - odict.ODictCaseless([('foo', 'bar')]), - 'foobar', - )) - assert len(bytes) == 2 - assert bytes[0] ==\ - '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') - assert bytes[1] ==\ - '000006000100000001666f6f626172'.decode('hex') + ssl = True + + def test_read_request(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = HTTP2Protocol(c, is_server=True) + protocol.connection_preface_performed = True + + resp = protocol.read_request() + + assert resp.stream_id + assert resp.headers.lst == [[u':method', u'GET'], [u':path', u'/'], [u':scheme', u'https']] + assert resp.body == b'foobar' class TestReadResponse(tservers.ServerTestBase): @@ -268,7 +310,7 @@ class TestReadResponse(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) protocol.connection_preface_performed = True resp = protocol.read_response() @@ -278,6 +320,23 @@ class TestReadResponse(tservers.ServerTestBase): assert resp.msg == "" assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] assert resp.body == b'foobar' + assert resp.timestamp_end + + def test_read_response_no_body(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = HTTP2Protocol(c) + protocol.connection_preface_performed = True + + resp = protocol.read_response(include_body=False) + + assert resp.httpversion == (2, 0) + assert resp.status_code == 200 + assert resp.msg == "" + assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] + assert resp.body == b'foobar' # TODO: this should be true: assert resp.body == http.CONTENT_MISSING + assert not resp.timestamp_end class TestReadEmptyResponse(tservers.ServerTestBase): @@ -294,7 +353,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) protocol.connection_preface_performed = True resp = protocol.read_response() @@ -307,37 +366,66 @@ class TestReadEmptyResponse(tservers.ServerTestBase): assert resp.body == b'' -class TestReadRequest(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - self.wfile.write( - b'000003010400000001828487'.decode('hex')) - self.wfile.write( - b'000006000100000001666f6f626172'.decode('hex')) - self.wfile.flush() - - ssl = True +class TestAssembleRequest(object): + c = tcp.TCPClient(("127.0.0.1", 0)) - def test_read_request(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c, is_server=True) - protocol.connection_preface_performed = True + def test_request_simple(self): + bytes = HTTP2Protocol(self.c).assemble_request(http.Request( + '', + 'GET', + 'https', + '', + '', + '/', + (2, 0), + None, + None, + )) + assert len(bytes) == 1 + assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') - resp = protocol.read_request() + def test_request_with_stream_id(self): + req = http.Request( + '', + 'GET', + 'https', + '', + '', + '/', + (2, 0), + None, + None, + ) + req.stream_id = 0x42 + bytes = HTTP2Protocol(self.c).assemble_request(req) + assert len(bytes) == 1 + print(bytes[0].encode('hex')) + assert bytes[0] == '00000d0105000000428284874188089d5c0b8170dc07'.decode('hex') - assert resp.stream_id - assert resp.headers.lst == [[u':method', u'GET'], [u':path', u'/'], [u':scheme', u'https']] - assert resp.body == b'foobar' + def test_request_with_body(self): + bytes = HTTP2Protocol(self.c).assemble_request(http.Request( + '', + 'GET', + 'https', + '', + '', + '/', + (2, 0), + odict.ODictCaseless([('foo', 'bar')]), + 'foobar', + )) + assert len(bytes) == 2 + assert bytes[0] ==\ + '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') + assert bytes[1] ==\ + '000006000100000001666f6f626172'.decode('hex') -class TestCreateResponse(): +class TestAssembleResponse(object): c = tcp.TCPClient(("127.0.0.1", 0)) - def test_create_response_simple(self): - bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( + def test_simple(self): + bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( (2, 0), 200, )) @@ -345,8 +433,19 @@ class TestCreateResponse(): assert bytes[0] ==\ '00000101050000000288'.decode('hex') - def test_create_response_with_body(self): - bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( + def test_with_stream_id(self): + resp = http.Response( + (2, 0), + 200, + ) + resp.stream_id = 0x42 + bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(resp) + assert len(bytes) == 1 + assert bytes[0] ==\ + '00000101050000004288'.decode('hex') + + def test_with_body(self): + bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( (2, 0), 200, '', diff --git a/test/http/test_exceptions.py b/test/http/test_exceptions.py index aa57f831..0131c7ef 100644 --- a/test/http/test_exceptions.py +++ b/test/http/test_exceptions.py @@ -1,6 +1,27 @@ from netlib.http.exceptions import * +from netlib import odict -def test_HttpAuthenticationError(): - x = HttpAuthenticationError({"foo": "bar"}) - assert str(x) - assert "foo" in x.headers +class TestHttpError: + def test_simple(self): + e = HttpError(404, "Not found") + assert str(e) + +class TestHttpAuthenticationError: + def test_init(self): + headers = odict.ODictCaseless([("foo", "bar")]) + x = HttpAuthenticationError(headers) + assert str(x) + assert isinstance(x.headers, odict.ODictCaseless) + assert x.code == 407 + assert x.headers == headers + print(x.headers.keys()) + assert "foo" in x.headers.keys() + + def test_header_conversion(self): + headers = {"foo": "bar"} + x = HttpAuthenticationError(headers) + assert isinstance(x.headers, odict.ODictCaseless) + assert x.headers.lst == headers.items() + + def test_repr(self): + assert repr(HttpAuthenticationError()) == "Proxy Authentication Required" diff --git a/test/http/test_semantics.py b/test/http/test_semantics.py index d58a44d2..59364eae 100644 --- a/test/http/test_semantics.py +++ b/test/http/test_semantics.py @@ -1,18 +1,277 @@ import cStringIO import textwrap import binascii +import mock from mock import MagicMock -from netlib import http, odict, tcp, tutils -from netlib.http import http1 +from netlib import http, odict, tcp, tutils, utils +from netlib.http import semantics from netlib.http.semantics import CONTENT_MISSING from .. import tservers -def test_httperror(): - e = http.exceptions.HttpError(404, "Not found") - assert str(e) +class TestProtocolMixin(object): + @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response") + @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request") + def test_assemble_request(self, mock_request_method, mock_response_method): + p = semantics.ProtocolMixin() + p.assemble(tutils.treq()) + assert mock_request_method.called + assert not mock_response_method.called + + @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response") + @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request") + def test_assemble_response(self, mock_request_method, mock_response_method): + p = semantics.ProtocolMixin() + p.assemble(tutils.tresp()) + assert not mock_request_method.called + assert mock_response_method.called + + def test_assemble_foo(self): + p = semantics.ProtocolMixin() + tutils.raises(ValueError, p.assemble, 'foo') + +class TestRequest(object): + def test_repr(self): + r = tutils.treq() + assert repr(r) + + def test_headers_odict(self): + tutils.raises(AssertionError, semantics.Request, + 'form_in', + 'method', + 'scheme', + 'host', + 'port', + 'path', + (1, 1), + 'foobar', + ) + + req = semantics.Request( + 'form_in', + 'method', + 'scheme', + 'host', + 'port', + 'path', + (1, 1), + ) + assert isinstance(req.headers, odict.ODictCaseless) + + def test_equal(self): + a = tutils.treq() + b = tutils.treq() + assert a == b + + assert not a == 'foo' + assert not b == 'foo' + assert not 'foo' == a + assert not 'foo' == b + + def test_legacy_first_line(self): + req = tutils.treq() + + req.form_in = 'relative' + assert req.legacy_first_line() == "GET /path HTTP/1.1" + + req.form_in = 'authority' + assert req.legacy_first_line() == "GET address:22 HTTP/1.1" + + req.form_in = 'absolute' + assert req.legacy_first_line() == "GET http://address:22/path HTTP/1.1" + + req.form_in = 'foobar' + tutils.raises(http.HttpError, req.legacy_first_line) + + def test_anticache(self): + req = tutils.treq() + req.headers.add("If-Modified-Since", "foo") + req.headers.add("If-None-Match", "bar") + req.anticache() + assert "If-Modified-Since" not in req.headers + assert "If-None-Match" not in req.headers + + def test_anticomp(self): + req = tutils.treq() + req.headers.add("Accept-Encoding", "foobar") + req.anticomp() + assert req.headers["Accept-Encoding"] == ["identity"] + + def test_constrain_encoding(self): + req = tutils.treq() + req.headers.add("Accept-Encoding", "identity, gzip, foo") + req.constrain_encoding() + assert "foo" not in req.headers.get_first("Accept-Encoding") + + def test_update_host(self): + req = tutils.treq() + req.headers.add("Host", "") + req.host = "foobar" + req.update_host_header() + assert req.headers.get_first("Host") == "foobar" + + def test_get_form(self): + req = tutils.treq() + assert req.get_form() == odict.ODict() + + @mock.patch("netlib.http.semantics.Request.get_form_multipart") + @mock.patch("netlib.http.semantics.Request.get_form_urlencoded") + def test_get_form_with_url_encoded(self, mock_method_urlencoded, mock_method_multipart): + req = tutils.treq() + assert req.get_form() == odict.ODict() + + req = tutils.treq() + req.body = "foobar" + req.headers["Content-Type"] = [semantics.HDR_FORM_URLENCODED] + req.get_form() + assert req.get_form_urlencoded.called + assert not req.get_form_multipart.called + + @mock.patch("netlib.http.semantics.Request.get_form_multipart") + @mock.patch("netlib.http.semantics.Request.get_form_urlencoded") + def test_get_form_with_multipart(self, mock_method_urlencoded, mock_method_multipart): + req = tutils.treq() + req.body = "foobar" + req.headers["Content-Type"] = [semantics.HDR_FORM_MULTIPART] + req.get_form() + assert not req.get_form_urlencoded.called + assert req.get_form_multipart.called + + def test_get_form_urlencoded(self): + req = tutils.treq("foobar") + assert req.get_form_urlencoded() == odict.ODict() + + req.headers["Content-Type"] = [semantics.HDR_FORM_URLENCODED] + assert req.get_form_urlencoded() == odict.ODict(utils.urldecode(req.body)) + + def test_get_form_multipart(self): + req = tutils.treq("foobar") + assert req.get_form_multipart() == odict.ODict() + + req.headers["Content-Type"] = [semantics.HDR_FORM_MULTIPART] + assert req.get_form_multipart() == odict.ODict( + utils.multipartdecode( + req.headers, + req.body)) + + def test_set_form_urlencoded(self): + req = tutils.treq() + req.set_form_urlencoded(odict.ODict([('foo', 'bar'), ('rab', 'oof')])) + assert req.headers.get_first("Content-Type") == semantics.HDR_FORM_URLENCODED + assert req.body + + def test_get_path_components(self): + req = tutils.treq() + assert req.get_path_components() + # TODO: add meaningful assertions + + def test_set_path_components(self): + req = tutils.treq() + req.set_path_components(["foo", "bar"]) + # TODO: add meaningful assertions + + def test_get_query(self): + req = tutils.treq() + assert req.get_query().lst == [] + + req.url = "http://localhost:80/foo?bar=42" + assert req.get_query().lst == [("bar", "42")] + + def test_set_query(self): + req = tutils.treq() + req.set_query(odict.ODict([])) + + def test_pretty_host(self): + r = tutils.treq() + assert r.pretty_host(True) == "address" + assert r.pretty_host(False) == "address" + r.headers["host"] = ["other"] + assert r.pretty_host(True) == "other" + assert r.pretty_host(False) == "address" + r.host = None + assert r.pretty_host(True) == "other" + assert r.pretty_host(False) is None + del r.headers["host"] + assert r.pretty_host(True) is None + assert r.pretty_host(False) is None + + # Invalid IDNA + r.headers["host"] = [".disqus.com"] + assert r.pretty_host(True) == ".disqus.com" + + def test_pretty_url(self): + req = tutils.treq() + req.form_out = "authority" + assert req.pretty_url(True) == "address:22" + assert req.pretty_url(False) == "address:22" + + req.form_out = "relative" + assert req.pretty_url(True) == "http://address:22/path" + assert req.pretty_url(False) == "http://address:22/path" + + def test_get_cookies_none(self): + h = odict.ODictCaseless() + r = tutils.treq() + r.headers = h + assert len(r.get_cookies()) == 0 + + def test_get_cookies_single(self): + h = odict.ODictCaseless() + h["Cookie"] = ["cookiename=cookievalue"] + r = tutils.treq() + r.headers = h + result = r.get_cookies() + assert len(result) == 1 + assert result['cookiename'] == ['cookievalue'] + + def test_get_cookies_double(self): + h = odict.ODictCaseless() + h["Cookie"] = [ + "cookiename=cookievalue;othercookiename=othercookievalue" + ] + r = tutils.treq() + r.headers = h + result = r.get_cookies() + assert len(result) == 2 + assert result['cookiename'] == ['cookievalue'] + assert result['othercookiename'] == ['othercookievalue'] + + def test_get_cookies_withequalsign(self): + h = odict.ODictCaseless() + h["Cookie"] = [ + "cookiename=coo=kievalue;othercookiename=othercookievalue" + ] + r = tutils.treq() + r.headers = h + result = r.get_cookies() + assert len(result) == 2 + assert result['cookiename'] == ['coo=kievalue'] + assert result['othercookiename'] == ['othercookievalue'] + + def test_set_cookies(self): + h = odict.ODictCaseless() + h["Cookie"] = ["cookiename=cookievalue"] + r = tutils.treq() + r.headers = h + result = r.get_cookies() + result["cookiename"] = ["foo"] + r.set_cookies(result) + assert r.get_cookies()["cookiename"] == ["foo"] + + def test_set_url(self): + r = tutils.treq_absolute() + r.url = "https://otheraddress:42/ORLY" + assert r.scheme == "https" + assert r.host == "otheraddress" + assert r.port == 42 + assert r.path == "/ORLY" + + try: + r.url = "//localhost:80/foo@bar" + assert False + except: + assert True -class TestRequest: # def test_asterisk_form_in(self): # f = tutils.tflow(req=None) # protocol = mock_protocol("OPTIONS * HTTP/1.1") @@ -92,105 +351,35 @@ class TestRequest: # "Host: address\r\n" # "Content-Length: 0\r\n\r\n") - def test_set_url(self): - r = tutils.treq_absolute() - r.url = "https://otheraddress:42/ORLY" - assert r.scheme == "https" - assert r.host == "otheraddress" - assert r.port == 42 - assert r.path == "/ORLY" - - def test_repr(self): - r = tutils.treq() - assert repr(r) - - def test_pretty_host(self): - r = tutils.treq() - assert r.pretty_host(True) == "address" - assert r.pretty_host(False) == "address" - r.headers["host"] = ["other"] - assert r.pretty_host(True) == "other" - assert r.pretty_host(False) == "address" - r.host = None - assert r.pretty_host(True) == "other" - assert r.pretty_host(False) is None - del r.headers["host"] - assert r.pretty_host(True) is None - assert r.pretty_host(False) is None - - # Invalid IDNA - r.headers["host"] = [".disqus.com"] - assert r.pretty_host(True) == ".disqus.com" - - def test_get_form_for_urlencoded(self): - r = tutils.treq() - r.headers.add("content-type", "application/x-www-form-urlencoded") - r.get_form_urlencoded = MagicMock() - - r.get_form() - - assert r.get_form_urlencoded.called - - def test_get_form_for_multipart(self): - r = tutils.treq() - r.headers.add("content-type", "multipart/form-data") - r.get_form_multipart = MagicMock() - - r.get_form() +class TestEmptyRequest(object): + def test_init(self): + req = semantics.EmptyRequest() + assert req - assert r.get_form_multipart.called - - def test_get_cookies_none(self): - h = odict.ODictCaseless() - r = tutils.treq() - r.headers = h - assert len(r.get_cookies()) == 0 - - def test_get_cookies_single(self): - h = odict.ODictCaseless() - h["Cookie"] = ["cookiename=cookievalue"] - r = tutils.treq() - r.headers = h - result = r.get_cookies() - assert len(result) == 1 - assert result['cookiename'] == ['cookievalue'] - - def test_get_cookies_double(self): - h = odict.ODictCaseless() - h["Cookie"] = [ - "cookiename=cookievalue;othercookiename=othercookievalue" - ] - r = tutils.treq() - r.headers = h - result = r.get_cookies() - assert len(result) == 2 - assert result['cookiename'] == ['cookievalue'] - assert result['othercookiename'] == ['othercookievalue'] +class TestResponse(object): + def test_headers_odict(self): + tutils.raises(AssertionError, semantics.Response, + (1, 1), + 200, + headers='foobar', + ) - def test_get_cookies_withequalsign(self): - h = odict.ODictCaseless() - h["Cookie"] = [ - "cookiename=coo=kievalue;othercookiename=othercookievalue" - ] - r = tutils.treq() - r.headers = h - result = r.get_cookies() - assert len(result) == 2 - assert result['cookiename'] == ['coo=kievalue'] - assert result['othercookiename'] == ['othercookievalue'] + resp = semantics.Response( + (1, 1), + 200, + ) + assert isinstance(resp.headers, odict.ODictCaseless) - def test_set_cookies(self): - h = odict.ODictCaseless() - h["Cookie"] = ["cookiename=cookievalue"] - r = tutils.treq() - r.headers = h - result = r.get_cookies() - result["cookiename"] = ["foo"] - r.set_cookies(result) - assert r.get_cookies()["cookiename"] == ["foo"] + def test_equal(self): + a = tutils.tresp() + b = tutils.tresp() + assert a == b + assert not a == 'foo' + assert not b == 'foo' + assert not 'foo' == a + assert not 'foo' == b -class TestResponse(object): def test_repr(self): r = tutils.tresp() assert "unknown content type" in repr(r) diff --git a/test/test_utils.py b/test/test_utils.py index 5e681eb6..aafa1571 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -101,3 +101,34 @@ def test_get_header_tokens(): assert utils.get_header_tokens(h, "foo") == ["bar", "voing"] h["foo"] = ["bar, voing", "oink"] assert utils.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] + + + + + +def test_multipartdecode(): + boundary = 'somefancyboundary' + headers = odict.ODict( + [('content-type', ('multipart/form-data; boundary=%s' % boundary))]) + content = "--{0}\n" \ + "Content-Disposition: form-data; name=\"field1\"\n\n" \ + "value1\n" \ + "--{0}\n" \ + "Content-Disposition: form-data; name=\"field2\"\n\n" \ + "value2\n" \ + "--{0}--".format(boundary) + + form = utils.multipartdecode(headers, content) + + assert len(form) == 2 + assert form[0] == ('field1', 'value1') + assert form[1] == ('field2', 'value2') + + +def test_parse_content_type(): + p = utils.parse_content_type + assert p("text/html") == ("text", "html", {}) + assert p("text") is None + + v = p("text/html; charset=UTF-8") + assert v == ('text', 'html', {'charset': 'UTF-8'}) diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 28dbb833..9fa98172 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -3,6 +3,7 @@ import os from nose.tools import raises from netlib import tcp, http, websockets, tutils +from netlib.http import status_codes from netlib.http.exceptions import * from netlib.http.http1 import HTTP1Protocol from .. import tservers @@ -38,7 +39,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): req = http1_protocol.read_request() key = self.protocol.check_client_handshake(req.headers) - preamble = http1_protocol.response_preamble(101) + preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) self.wfile.write(preamble + "\r\n") headers = self.protocol.server_handshake_headers(key) self.wfile.write(headers.format() + "\r\n") @@ -62,7 +63,7 @@ class WebSocketsClient(tcp.TCPClient): http1_protocol = HTTP1Protocol(self) - preamble = http1_protocol.request_preamble("GET", "/") + preamble = 'GET / HTTP/1.1' self.wfile.write(preamble + "\r\n") headers = self.protocol.client_handshake_headers() self.client_nonce = headers.get_first("sec-websocket-key") @@ -162,7 +163,7 @@ class BadHandshakeHandler(WebSocketsEchoHandler): client_hs = http1_protocol.read_request() self.protocol.check_client_handshake(client_hs.headers) - preamble = http1_protocol.response_preamble(101) + preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) self.wfile.write(preamble + "\r\n") headers = self.protocol.server_handshake_headers("malformed key") self.wfile.write(headers.format() + "\r\n") |