diff options
Diffstat (limited to 'test/http/http2/test_protocol.py')
-rw-r--r-- | test/http/http2/test_protocol.py | 247 |
1 files changed, 173 insertions, 74 deletions
diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 8a27bbb1..3044179f 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -1,10 +1,25 @@ import OpenSSL +import mock from netlib import tcp, odict, http, tutils from netlib.http import http2 +from netlib.http.http2 import HTTP2Protocol from netlib.http.http2.frame import * from ... import tservers +class TestTCPHandlerWrapper: + def test_wrapped(self): + h = http2.TCPHandler(rfile='foo', wfile='bar') + p = HTTP2Protocol(h) + assert p.tcp_handler.rfile == 'foo' + assert p.tcp_handler.wfile == 'bar' + + def test_direct(self): + p = HTTP2Protocol(rfile='foo', wfile='bar') + assert isinstance(p.tcp_handler, http2.TCPHandler) + assert p.tcp_handler.rfile == 'foo' + assert p.tcp_handler.wfile == 'bar' + class EchoHandler(tcp.BaseHandler): sni = None @@ -16,10 +31,40 @@ class EchoHandler(tcp.BaseHandler): self.wfile.flush() +class TestProtocol: + @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface") + @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface") + def test_perform_connection_preface(self, mock_client_method, mock_server_method): + protocol = HTTP2Protocol(is_server=False) + protocol.connection_preface_performed = True + + protocol.perform_connection_preface() + assert not mock_client_method.called + assert not mock_server_method.called + + protocol.perform_connection_preface(force=True) + assert mock_client_method.called + assert not mock_server_method.called + + @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface") + @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface") + def test_perform_connection_preface_server(self, mock_client_method, mock_server_method): + protocol = HTTP2Protocol(is_server=True) + protocol.connection_preface_performed = True + + protocol.perform_connection_preface() + assert not mock_client_method.called + assert not mock_server_method.called + + protocol.perform_connection_preface(force=True) + assert not mock_client_method.called + assert mock_server_method.called + + class TestCheckALPNMatch(tservers.ServerTestBase): handler = EchoHandler ssl = dict( - alpn_select=http2.HTTP2Protocol.ALPN_PROTO_H2, + alpn_select=HTTP2Protocol.ALPN_PROTO_H2, ) if OpenSSL._util.lib.Cryptography_HAS_ALPN: @@ -27,8 +72,8 @@ class TestCheckALPNMatch(tservers.ServerTestBase): def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) - protocol = http2.HTTP2Protocol(c) + c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2]) + protocol = HTTP2Protocol(c) assert protocol.check_alpn() @@ -43,8 +88,8 @@ class TestCheckALPNMismatch(tservers.ServerTestBase): def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) - protocol = http2.HTTP2Protocol(c) + c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2]) + protocol = HTTP2Protocol(c) tutils.raises(NotImplementedError, protocol.check_alpn) @@ -76,8 +121,13 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase): def test_perform_server_connection_preface(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) + + assert not protocol.connection_preface_performed protocol.perform_server_connection_preface() + assert protocol.connection_preface_performed + + tutils.raises(tcp.NetLibIncomplete, protocol.perform_server_connection_preface, force=True) class TestPerformClientConnectionPreface(tservers.ServerTestBase): @@ -107,13 +157,16 @@ class TestPerformClientConnectionPreface(tservers.ServerTestBase): def test_perform_client_connection_preface(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) + + assert not protocol.connection_preface_performed protocol.perform_client_connection_preface() + assert protocol.connection_preface_performed class TestClientStreamIds(): c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) def test_client_stream_ids(self): assert self.protocol.current_stream_id is None @@ -127,7 +180,7 @@ class TestClientStreamIds(): class TestServerStreamIds(): c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c, is_server=True) + protocol = HTTP2Protocol(c, is_server=True) def test_server_stream_ids(self): assert self.protocol.current_stream_id is None @@ -154,7 +207,7 @@ class TestApplySettings(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) protocol._apply_settings({ SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 'foo', @@ -182,13 +235,13 @@ class TestCreateHeaders(): (b':scheme', b'https'), (b'foo', b'bar')] - bytes = http2.HTTP2Protocol(self.c)._create_headers( + bytes = HTTP2Protocol(self.c)._create_headers( headers, 1, end_stream=True) assert b''.join(bytes) ==\ '000014010500000001824488355217caf3a69a3f87408294e7838c767f'\ .decode('hex') - bytes = http2.HTTP2Protocol(self.c)._create_headers( + bytes = HTTP2Protocol(self.c)._create_headers( headers, 1, end_stream=False) assert b''.join(bytes) ==\ '000014010400000001824488355217caf3a69a3f87408294e7838c767f'\ @@ -199,7 +252,7 @@ class TestCreateHeaders(): class TestCreateBody(): c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) def test_create_body_empty(self): bytes = self.protocol._create_body(b'', 1) @@ -215,41 +268,30 @@ class TestCreateBody(): # TODO: add test for too large frames -class TestAssembleRequest(): - c = tcp.TCPClient(("127.0.0.1", 0)) +class TestReadRequest(tservers.ServerTestBase): + class handler(tcp.BaseHandler): - def test_assemble_request_simple(self): - bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request( - '', - 'GET', - 'https', - '', - '', - '/', - (2, 0), - None, - None, - )) - assert len(bytes) == 1 - assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') + def handle(self): + self.wfile.write( + b'000003010400000001828487'.decode('hex')) + self.wfile.write( + b'000006000100000001666f6f626172'.decode('hex')) + self.wfile.flush() - def test_assemble_request_with_body(self): - bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request( - '', - 'GET', - 'https', - '', - '', - '/', - (2, 0), - odict.ODictCaseless([('foo', 'bar')]), - 'foobar', - )) - assert len(bytes) == 2 - assert bytes[0] ==\ - '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') - assert bytes[1] ==\ - '000006000100000001666f6f626172'.decode('hex') + ssl = True + + def test_read_request(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = HTTP2Protocol(c, is_server=True) + protocol.connection_preface_performed = True + + resp = protocol.read_request() + + assert resp.stream_id + assert resp.headers.lst == [[u':method', u'GET'], [u':path', u'/'], [u':scheme', u'https']] + assert resp.body == b'foobar' class TestReadResponse(tservers.ServerTestBase): @@ -268,7 +310,7 @@ class TestReadResponse(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) protocol.connection_preface_performed = True resp = protocol.read_response() @@ -278,6 +320,23 @@ class TestReadResponse(tservers.ServerTestBase): assert resp.msg == "" assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] assert resp.body == b'foobar' + assert resp.timestamp_end + + def test_read_response_no_body(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = HTTP2Protocol(c) + protocol.connection_preface_performed = True + + resp = protocol.read_response(include_body=False) + + assert resp.httpversion == (2, 0) + assert resp.status_code == 200 + assert resp.msg == "" + assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] + assert resp.body == b'foobar' # TODO: this should be true: assert resp.body == http.CONTENT_MISSING + assert not resp.timestamp_end class TestReadEmptyResponse(tservers.ServerTestBase): @@ -294,7 +353,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) protocol.connection_preface_performed = True resp = protocol.read_response() @@ -307,37 +366,66 @@ class TestReadEmptyResponse(tservers.ServerTestBase): assert resp.body == b'' -class TestReadRequest(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - self.wfile.write( - b'000003010400000001828487'.decode('hex')) - self.wfile.write( - b'000006000100000001666f6f626172'.decode('hex')) - self.wfile.flush() - - ssl = True +class TestAssembleRequest(object): + c = tcp.TCPClient(("127.0.0.1", 0)) - def test_read_request(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c, is_server=True) - protocol.connection_preface_performed = True + def test_request_simple(self): + bytes = HTTP2Protocol(self.c).assemble_request(http.Request( + '', + 'GET', + 'https', + '', + '', + '/', + (2, 0), + None, + None, + )) + assert len(bytes) == 1 + assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') - resp = protocol.read_request() + def test_request_with_stream_id(self): + req = http.Request( + '', + 'GET', + 'https', + '', + '', + '/', + (2, 0), + None, + None, + ) + req.stream_id = 0x42 + bytes = HTTP2Protocol(self.c).assemble_request(req) + assert len(bytes) == 1 + print(bytes[0].encode('hex')) + assert bytes[0] == '00000d0105000000428284874188089d5c0b8170dc07'.decode('hex') - assert resp.stream_id - assert resp.headers.lst == [[u':method', u'GET'], [u':path', u'/'], [u':scheme', u'https']] - assert resp.body == b'foobar' + def test_request_with_body(self): + bytes = HTTP2Protocol(self.c).assemble_request(http.Request( + '', + 'GET', + 'https', + '', + '', + '/', + (2, 0), + odict.ODictCaseless([('foo', 'bar')]), + 'foobar', + )) + assert len(bytes) == 2 + assert bytes[0] ==\ + '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') + assert bytes[1] ==\ + '000006000100000001666f6f626172'.decode('hex') -class TestCreateResponse(): +class TestAssembleResponse(object): c = tcp.TCPClient(("127.0.0.1", 0)) - def test_create_response_simple(self): - bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( + def test_simple(self): + bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( (2, 0), 200, )) @@ -345,8 +433,19 @@ class TestCreateResponse(): assert bytes[0] ==\ '00000101050000000288'.decode('hex') - def test_create_response_with_body(self): - bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( + def test_with_stream_id(self): + resp = http.Response( + (2, 0), + 200, + ) + resp.stream_id = 0x42 + bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(resp) + assert len(bytes) == 1 + assert bytes[0] ==\ + '00000101050000004288'.decode('hex') + + def test_with_body(self): + bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( (2, 0), 200, '', |