diff options
Diffstat (limited to 'test/http/http2/test_protocol.py')
| -rw-r--r-- | test/http/http2/test_protocol.py | 247 | 
1 files changed, 173 insertions, 74 deletions
diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 8a27bbb1..3044179f 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -1,10 +1,25 @@  import OpenSSL +import mock  from netlib import tcp, odict, http, tutils  from netlib.http import http2 +from netlib.http.http2 import HTTP2Protocol  from netlib.http.http2.frame import *  from ... import tservers +class TestTCPHandlerWrapper: +    def test_wrapped(self): +        h = http2.TCPHandler(rfile='foo', wfile='bar') +        p = HTTP2Protocol(h) +        assert p.tcp_handler.rfile == 'foo' +        assert p.tcp_handler.wfile == 'bar' + +    def test_direct(self): +        p = HTTP2Protocol(rfile='foo', wfile='bar') +        assert isinstance(p.tcp_handler, http2.TCPHandler) +        assert p.tcp_handler.rfile == 'foo' +        assert p.tcp_handler.wfile == 'bar' +  class EchoHandler(tcp.BaseHandler):      sni = None @@ -16,10 +31,40 @@ class EchoHandler(tcp.BaseHandler):              self.wfile.flush() +class TestProtocol: +    @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface") +    @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface") +    def test_perform_connection_preface(self, mock_client_method, mock_server_method): +        protocol = HTTP2Protocol(is_server=False) +        protocol.connection_preface_performed = True + +        protocol.perform_connection_preface() +        assert not mock_client_method.called +        assert not mock_server_method.called + +        protocol.perform_connection_preface(force=True) +        assert mock_client_method.called +        assert not mock_server_method.called + +    @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface") +    @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface") +    def test_perform_connection_preface_server(self, mock_client_method, mock_server_method): +        protocol = HTTP2Protocol(is_server=True) +        protocol.connection_preface_performed = True + +        protocol.perform_connection_preface() +        assert not mock_client_method.called +        assert not mock_server_method.called + +        protocol.perform_connection_preface(force=True) +        assert not mock_client_method.called +        assert mock_server_method.called + +  class TestCheckALPNMatch(tservers.ServerTestBase):      handler = EchoHandler      ssl = dict( -        alpn_select=http2.HTTP2Protocol.ALPN_PROTO_H2, +        alpn_select=HTTP2Protocol.ALPN_PROTO_H2,      )      if OpenSSL._util.lib.Cryptography_HAS_ALPN: @@ -27,8 +72,8 @@ class TestCheckALPNMatch(tservers.ServerTestBase):          def test_check_alpn(self):              c = tcp.TCPClient(("127.0.0.1", self.port))              c.connect() -            c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) -            protocol = http2.HTTP2Protocol(c) +            c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2]) +            protocol = HTTP2Protocol(c)              assert protocol.check_alpn() @@ -43,8 +88,8 @@ class TestCheckALPNMismatch(tservers.ServerTestBase):          def test_check_alpn(self):              c = tcp.TCPClient(("127.0.0.1", self.port))              c.connect() -            c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) -            protocol = http2.HTTP2Protocol(c) +            c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2]) +            protocol = HTTP2Protocol(c)              tutils.raises(NotImplementedError, protocol.check_alpn) @@ -76,8 +121,13 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase):      def test_perform_server_connection_preface(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          c.connect() -        protocol = http2.HTTP2Protocol(c) +        protocol = HTTP2Protocol(c) + +        assert not protocol.connection_preface_performed          protocol.perform_server_connection_preface() +        assert protocol.connection_preface_performed + +        tutils.raises(tcp.NetLibIncomplete, protocol.perform_server_connection_preface, force=True)  class TestPerformClientConnectionPreface(tservers.ServerTestBase): @@ -107,13 +157,16 @@ class TestPerformClientConnectionPreface(tservers.ServerTestBase):      def test_perform_client_connection_preface(self):          c = tcp.TCPClient(("127.0.0.1", self.port))          c.connect() -        protocol = http2.HTTP2Protocol(c) +        protocol = HTTP2Protocol(c) + +        assert not protocol.connection_preface_performed          protocol.perform_client_connection_preface() +        assert protocol.connection_preface_performed  class TestClientStreamIds():      c = tcp.TCPClient(("127.0.0.1", 0)) -    protocol = http2.HTTP2Protocol(c) +    protocol = HTTP2Protocol(c)      def test_client_stream_ids(self):          assert self.protocol.current_stream_id is None @@ -127,7 +180,7 @@ class TestClientStreamIds():  class TestServerStreamIds():      c = tcp.TCPClient(("127.0.0.1", 0)) -    protocol = http2.HTTP2Protocol(c, is_server=True) +    protocol = HTTP2Protocol(c, is_server=True)      def test_server_stream_ids(self):          assert self.protocol.current_stream_id is None @@ -154,7 +207,7 @@ class TestApplySettings(tservers.ServerTestBase):          c = tcp.TCPClient(("127.0.0.1", self.port))          c.connect()          c.convert_to_ssl() -        protocol = http2.HTTP2Protocol(c) +        protocol = HTTP2Protocol(c)          protocol._apply_settings({              SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 'foo', @@ -182,13 +235,13 @@ class TestCreateHeaders():              (b':scheme', b'https'),              (b'foo', b'bar')] -        bytes = http2.HTTP2Protocol(self.c)._create_headers( +        bytes = HTTP2Protocol(self.c)._create_headers(              headers, 1, end_stream=True)          assert b''.join(bytes) ==\              '000014010500000001824488355217caf3a69a3f87408294e7838c767f'\              .decode('hex') -        bytes = http2.HTTP2Protocol(self.c)._create_headers( +        bytes = HTTP2Protocol(self.c)._create_headers(              headers, 1, end_stream=False)          assert b''.join(bytes) ==\              '000014010400000001824488355217caf3a69a3f87408294e7838c767f'\ @@ -199,7 +252,7 @@ class TestCreateHeaders():  class TestCreateBody():      c = tcp.TCPClient(("127.0.0.1", 0)) -    protocol = http2.HTTP2Protocol(c) +    protocol = HTTP2Protocol(c)      def test_create_body_empty(self):          bytes = self.protocol._create_body(b'', 1) @@ -215,41 +268,30 @@ class TestCreateBody():          # TODO: add test for too large frames -class TestAssembleRequest(): -    c = tcp.TCPClient(("127.0.0.1", 0)) +class TestReadRequest(tservers.ServerTestBase): +    class handler(tcp.BaseHandler): -    def test_assemble_request_simple(self): -        bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request( -            '', -            'GET', -            'https', -            '', -            '', -            '/', -            (2, 0), -            None, -            None, -        )) -        assert len(bytes) == 1 -        assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') +        def handle(self): +            self.wfile.write( +                b'000003010400000001828487'.decode('hex')) +            self.wfile.write( +                b'000006000100000001666f6f626172'.decode('hex')) +            self.wfile.flush() -    def test_assemble_request_with_body(self): -        bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request( -            '', -            'GET', -            'https', -            '', -            '', -            '/', -            (2, 0), -            odict.ODictCaseless([('foo', 'bar')]), -            'foobar', -        )) -        assert len(bytes) == 2 -        assert bytes[0] ==\ -            '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') -        assert bytes[1] ==\ -            '000006000100000001666f6f626172'.decode('hex') +    ssl = True + +    def test_read_request(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.convert_to_ssl() +        protocol = HTTP2Protocol(c, is_server=True) +        protocol.connection_preface_performed = True + +        resp = protocol.read_request() + +        assert resp.stream_id +        assert resp.headers.lst == [[u':method', u'GET'], [u':path', u'/'], [u':scheme', u'https']] +        assert resp.body == b'foobar'  class TestReadResponse(tservers.ServerTestBase): @@ -268,7 +310,7 @@ class TestReadResponse(tservers.ServerTestBase):          c = tcp.TCPClient(("127.0.0.1", self.port))          c.connect()          c.convert_to_ssl() -        protocol = http2.HTTP2Protocol(c) +        protocol = HTTP2Protocol(c)          protocol.connection_preface_performed = True          resp = protocol.read_response() @@ -278,6 +320,23 @@ class TestReadResponse(tservers.ServerTestBase):          assert resp.msg == ""          assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]          assert resp.body == b'foobar' +        assert resp.timestamp_end + +    def test_read_response_no_body(self): +        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c.connect() +        c.convert_to_ssl() +        protocol = HTTP2Protocol(c) +        protocol.connection_preface_performed = True + +        resp = protocol.read_response(include_body=False) + +        assert resp.httpversion == (2, 0) +        assert resp.status_code == 200 +        assert resp.msg == "" +        assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] +        assert resp.body == b'foobar'  # TODO: this should be true: assert resp.body == http.CONTENT_MISSING +        assert not resp.timestamp_end  class TestReadEmptyResponse(tservers.ServerTestBase): @@ -294,7 +353,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase):          c = tcp.TCPClient(("127.0.0.1", self.port))          c.connect()          c.convert_to_ssl() -        protocol = http2.HTTP2Protocol(c) +        protocol = HTTP2Protocol(c)          protocol.connection_preface_performed = True          resp = protocol.read_response() @@ -307,37 +366,66 @@ class TestReadEmptyResponse(tservers.ServerTestBase):          assert resp.body == b'' -class TestReadRequest(tservers.ServerTestBase): -    class handler(tcp.BaseHandler): - -        def handle(self): -            self.wfile.write( -                b'000003010400000001828487'.decode('hex')) -            self.wfile.write( -                b'000006000100000001666f6f626172'.decode('hex')) -            self.wfile.flush() - -    ssl = True +class TestAssembleRequest(object): +    c = tcp.TCPClient(("127.0.0.1", 0)) -    def test_read_request(self): -        c = tcp.TCPClient(("127.0.0.1", self.port)) -        c.connect() -        c.convert_to_ssl() -        protocol = http2.HTTP2Protocol(c, is_server=True) -        protocol.connection_preface_performed = True +    def test_request_simple(self): +        bytes = HTTP2Protocol(self.c).assemble_request(http.Request( +            '', +            'GET', +            'https', +            '', +            '', +            '/', +            (2, 0), +            None, +            None, +        )) +        assert len(bytes) == 1 +        assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') -        resp = protocol.read_request() +    def test_request_with_stream_id(self): +        req = http.Request( +            '', +            'GET', +            'https', +            '', +            '', +            '/', +            (2, 0), +            None, +            None, +        ) +        req.stream_id = 0x42 +        bytes = HTTP2Protocol(self.c).assemble_request(req) +        assert len(bytes) == 1 +        print(bytes[0].encode('hex')) +        assert bytes[0] == '00000d0105000000428284874188089d5c0b8170dc07'.decode('hex') -        assert resp.stream_id -        assert resp.headers.lst == [[u':method', u'GET'], [u':path', u'/'], [u':scheme', u'https']] -        assert resp.body == b'foobar' +    def test_request_with_body(self): +        bytes = HTTP2Protocol(self.c).assemble_request(http.Request( +            '', +            'GET', +            'https', +            '', +            '', +            '/', +            (2, 0), +            odict.ODictCaseless([('foo', 'bar')]), +            'foobar', +        )) +        assert len(bytes) == 2 +        assert bytes[0] ==\ +            '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') +        assert bytes[1] ==\ +            '000006000100000001666f6f626172'.decode('hex') -class TestCreateResponse(): +class TestAssembleResponse(object):      c = tcp.TCPClient(("127.0.0.1", 0)) -    def test_create_response_simple(self): -        bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( +    def test_simple(self): +        bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response(              (2, 0),              200,          )) @@ -345,8 +433,19 @@ class TestCreateResponse():          assert bytes[0] ==\              '00000101050000000288'.decode('hex') -    def test_create_response_with_body(self): -        bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( +    def test_with_stream_id(self): +        resp = http.Response( +            (2, 0), +            200, +        ) +        resp.stream_id = 0x42 +        bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(resp) +        assert len(bytes) == 1 +        assert bytes[0] ==\ +            '00000101050000004288'.decode('hex') + +    def test_with_body(self): +        bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response(              (2, 0),              200,              '',  | 
