import OpenSSL import mock import codecs from hyperframe.frame import * from netlib import tcp, http, utils from netlib.tutils import raises from netlib.exceptions import TcpDisconnect from netlib.http.http2.connections import HTTP2Protocol, TCPHandler from ... import tservers class TestTCPHandlerWrapper: def test_wrapped(self): h = TCPHandler(rfile='foo', wfile='bar') p = HTTP2Protocol(h) assert p.tcp_handler.rfile == 'foo' assert p.tcp_handler.wfile == 'bar' def test_direct(self): p = HTTP2Protocol(rfile='foo', wfile='bar') assert isinstance(p.tcp_handler, TCPHandler) assert p.tcp_handler.rfile == 'foo' assert p.tcp_handler.wfile == 'bar' class EchoHandler(tcp.BaseHandler): sni = None def handle(self): while True: v = self.rfile.safe_read(1) self.wfile.write(v) self.wfile.flush() class TestProtocol: @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface") @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface") def test_perform_connection_preface(self, mock_client_method, mock_server_method): protocol = HTTP2Protocol(is_server=False) protocol.connection_preface_performed = True protocol.perform_connection_preface() assert not mock_client_method.called assert not mock_server_method.called protocol.perform_connection_preface(force=True) assert mock_client_method.called assert not mock_server_method.called @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface") @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface") def test_perform_connection_preface_server(self, mock_client_method, mock_server_method): protocol = HTTP2Protocol(is_server=True) protocol.connection_preface_performed = True protocol.perform_connection_preface() assert not mock_client_method.called assert not mock_server_method.called protocol.perform_connection_preface(force=True) assert not mock_client_method.called assert mock_server_method.called class TestCheckALPNMatch(tservers.ServerTestBase): handler = EchoHandler ssl = dict( alpn_select=b'h2', ) if tcp.HAS_ALPN: def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl(alpn_protos=[b'h2']) protocol = HTTP2Protocol(c) assert protocol.check_alpn() class TestCheckALPNMismatch(tservers.ServerTestBase): handler = EchoHandler ssl = dict( alpn_select=None, ) if tcp.HAS_ALPN: def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl(alpn_protos=[b'h2']) protocol = HTTP2Protocol(c) with raises(NotImplementedError): protocol.check_alpn() class TestPerformServerConnectionPreface(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): # send magic self.wfile.write(codecs.decode('505249202a20485454502f322e300d0a0d0a534d0d0a0d0a', 'hex_codec')) self.wfile.flush() # send empty settings frame self.wfile.write(codecs.decode('000000040000000000', 'hex_codec')) self.wfile.flush() # check empty settings fr