diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/http2/test_protocol.py | 18 | ||||
-rw-r--r-- | test/test_http.py | 6 | ||||
-rw-r--r-- | test/test_tcp.py | 56 | ||||
-rw-r--r-- | test/test_websockets.py | 8 | ||||
-rw-r--r-- | test/tservers.py | 108 |
5 files changed, 152 insertions, 44 deletions
diff --git a/test/http2/test_protocol.py b/test/http2/test_protocol.py index 9b49acd3..5e2af34e 100644 --- a/test/http2/test_protocol.py +++ b/test/http2/test_protocol.py @@ -2,9 +2,9 @@ import OpenSSL from netlib import http2 from netlib import tcp -from netlib import test from netlib.http2.frame import * from test import tutils +from .. import tservers class EchoHandler(tcp.BaseHandler): @@ -17,7 +17,7 @@ class EchoHandler(tcp.BaseHandler): self.wfile.flush() -class TestCheckALPNMatch(test.ServerTestBase): +class TestCheckALPNMatch(tservers.ServerTestBase): handler = EchoHandler ssl = dict( alpn_select=http2.HTTP2Protocol.ALPN_PROTO_H2, @@ -33,7 +33,7 @@ class TestCheckALPNMatch(test.ServerTestBase): assert protocol.check_alpn() -class TestCheckALPNMismatch(test.ServerTestBase): +class TestCheckALPNMismatch(tservers.ServerTestBase): handler = EchoHandler ssl = dict( alpn_select=None, @@ -49,7 +49,7 @@ class TestCheckALPNMismatch(test.ServerTestBase): tutils.raises(NotImplementedError, protocol.check_alpn) -class TestPerformServerConnectionPreface(test.ServerTestBase): +class TestPerformServerConnectionPreface(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): @@ -81,7 +81,7 @@ class TestPerformServerConnectionPreface(test.ServerTestBase): protocol.perform_server_connection_preface() -class TestPerformClientConnectionPreface(test.ServerTestBase): +class TestPerformClientConnectionPreface(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): @@ -140,7 +140,7 @@ class TestServerStreamIds(): assert self.protocol.current_stream_id == 6 -class TestApplySettings(test.ServerTestBase): +class TestApplySettings(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): @@ -234,7 +234,7 @@ class TestCreateRequest(): '000006000100000001666f6f626172'.decode('hex') -class TestReadResponse(test.ServerTestBase): +class TestReadResponse(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): @@ -259,7 +259,7 @@ class TestReadResponse(test.ServerTestBase): assert body == b'foobar' -class TestReadEmptyResponse(test.ServerTestBase): +class TestReadEmptyResponse(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): @@ -282,7 +282,7 @@ class TestReadEmptyResponse(test.ServerTestBase): assert body == b'' -class TestReadRequest(test.ServerTestBase): +class TestReadRequest(tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): diff --git a/test/test_http.py b/test/test_http.py index 0a9e276f..2ad81d24 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -1,8 +1,8 @@ import cStringIO import textwrap import binascii -from netlib import http, odict, tcp, test -import tutils +from netlib import http, odict, tcp +from . import tutils, tservers def test_httperror(): @@ -284,7 +284,7 @@ class NoContentLengthHTTPHandler(tcp.BaseHandler): self.wfile.flush() -class TestReadResponseNoContentLength(test.ServerTestBase): +class TestReadResponseNoContentLength(tservers.ServerTestBase): handler = NoContentLengthHTTPHandler def test_no_content_length(self): diff --git a/test/test_tcp.py b/test/test_tcp.py index 122c1f0f..4253e073 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -10,8 +10,8 @@ import mock from OpenSSL import SSL import OpenSSL -from netlib import tcp, certutils, test, certffi -import tutils +from netlib import tcp, certutils, certffi +from . import tutils, tservers class EchoHandler(tcp.BaseHandler): @@ -53,7 +53,7 @@ class ALPNHandler(tcp.BaseHandler): self.wfile.flush() -class TestServer(test.ServerTestBase): +class TestServer(tservers.ServerTestBase): handler = EchoHandler def test_echo(self): @@ -74,7 +74,7 @@ class TestServer(test.ServerTestBase): self.test_echo() -class TestServerBind(test.ServerTestBase): +class TestServerBind(tservers.ServerTestBase): class handler(tcp.BaseHandler): @@ -97,7 +97,7 @@ class TestServerBind(test.ServerTestBase): pass -class TestServerIPv6(test.ServerTestBase): +class TestServerIPv6(tservers.ServerTestBase): handler = EchoHandler addr = tcp.Address(("localhost", 0), use_ipv6=True) @@ -110,7 +110,7 @@ class TestServerIPv6(test.ServerTestBase): assert c.rfile.readline() == testval -class TestEcho(test.ServerTestBase): +class TestEcho(tservers.ServerTestBase): handler = EchoHandler def test_echo(self): @@ -128,7 +128,7 @@ class HardDisconnectHandler(tcp.BaseHandler): self.connection.close() -class TestFinishFail(test.ServerTestBase): +class TestFinishFail(tservers.ServerTestBase): """ This tests a difficult-to-trigger exception in the .finish() method of @@ -144,7 +144,7 @@ class TestFinishFail(test.ServerTestBase): c.finish() -class TestServerSSL(test.ServerTestBase): +class TestServerSSL(tservers.ServerTestBase): handler = EchoHandler ssl = dict( cipher_list="AES256-SHA", @@ -170,7 +170,7 @@ class TestServerSSL(test.ServerTestBase): assert "AES" in ret[0] -class TestSSLv3Only(test.ServerTestBase): +class TestSSLv3Only(tservers.ServerTestBase): handler = EchoHandler ssl = dict( request_client_cert=False, @@ -183,7 +183,7 @@ class TestSSLv3Only(test.ServerTestBase): tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com") -class TestSSLUpstreamCertVerification(test.ServerTestBase): +class TestSSLUpstreamCertVerification(tservers.ServerTestBase): handler = EchoHandler ssl = dict( @@ -236,7 +236,7 @@ class TestSSLUpstreamCertVerification(test.ServerTestBase): assert c.rfile.readline() == testval -class TestSSLClientCert(test.ServerTestBase): +class TestSSLClientCert(tservers.ServerTestBase): class handler(tcp.BaseHandler): sni = None @@ -270,7 +270,7 @@ class TestSSLClientCert(test.ServerTestBase): ) -class TestSNI(test.ServerTestBase): +class TestSNI(tservers.ServerTestBase): class handler(tcp.BaseHandler): sni = None @@ -292,7 +292,7 @@ class TestSNI(test.ServerTestBase): assert c.rfile.readline() == "foo.com" -class TestServerCipherList(test.ServerTestBase): +class TestServerCipherList(tservers.ServerTestBase): handler = ClientCipherListHandler ssl = dict( cipher_list='RC4-SHA' @@ -305,7 +305,7 @@ class TestServerCipherList(test.ServerTestBase): assert c.rfile.readline() == "['RC4-SHA']" -class TestServerCurrentCipher(test.ServerTestBase): +class TestServerCurrentCipher(tservers.ServerTestBase): class handler(tcp.BaseHandler): sni = None @@ -325,7 +325,7 @@ class TestServerCurrentCipher(test.ServerTestBase): assert "RC4-SHA" in c.rfile.readline() -class TestServerCipherListError(test.ServerTestBase): +class TestServerCipherListError(tservers.ServerTestBase): handler = ClientCipherListHandler ssl = dict( cipher_list='bogus' @@ -337,7 +337,7 @@ class TestServerCipherListError(test.ServerTestBase): tutils.raises("handshake error", c.convert_to_ssl, sni="foo.com") -class TestClientCipherListError(test.ServerTestBase): +class TestClientCipherListError(tservers.ServerTestBase): handler = ClientCipherListHandler ssl = dict( cipher_list='RC4-SHA' @@ -353,7 +353,7 @@ class TestClientCipherListError(test.ServerTestBase): cipher_list="bogus") -class TestSSLDisconnect(test.ServerTestBase): +class TestSSLDisconnect(tservers.ServerTestBase): class handler(tcp.BaseHandler): @@ -373,7 +373,7 @@ class TestSSLDisconnect(test.ServerTestBase): tutils.raises(Queue.Empty, self.q.get_nowait) -class TestSSLHardDisconnect(test.ServerTestBase): +class TestSSLHardDisconnect(tservers.ServerTestBase): handler = HardDisconnectHandler ssl = True @@ -387,7 +387,7 @@ class TestSSLHardDisconnect(test.ServerTestBase): tutils.raises(tcp.NetLibDisconnect, c.wfile.write, "foo") -class TestDisconnect(test.ServerTestBase): +class TestDisconnect(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) @@ -398,7 +398,7 @@ class TestDisconnect(test.ServerTestBase): c.close() -class TestServerTimeOut(test.ServerTestBase): +class TestServerTimeOut(tservers.ServerTestBase): class handler(tcp.BaseHandler): @@ -417,7 +417,7 @@ class TestServerTimeOut(test.ServerTestBase): assert self.last_handler.timeout -class TestTimeOut(test.ServerTestBase): +class TestTimeOut(tservers.ServerTestBase): handler = HangHandler def test_timeout(self): @@ -428,7 +428,7 @@ class TestTimeOut(test.ServerTestBase): tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) -class TestALPNClient(test.ServerTestBase): +class TestALPNClient(tservers.ServerTestBase): handler = ALPNHandler ssl = dict( alpn_select="bar" @@ -457,7 +457,7 @@ class TestALPNClient(test.ServerTestBase): assert c.get_alpn_proto_negotiated() == "" assert c.rfile.readline() == "NONE" -class TestNoSSLNoALPNClient(test.ServerTestBase): +class TestNoSSLNoALPNClient(tservers.ServerTestBase): handler = ALPNHandler def test_no_ssl_no_alpn(self): @@ -467,7 +467,7 @@ class TestNoSSLNoALPNClient(test.ServerTestBase): assert c.rfile.readline().strip() == "NONE" -class TestSSLTimeOut(test.ServerTestBase): +class TestSSLTimeOut(tservers.ServerTestBase): handler = HangHandler ssl = True @@ -479,7 +479,7 @@ class TestSSLTimeOut(test.ServerTestBase): tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) -class TestDHParams(test.ServerTestBase): +class TestDHParams(tservers.ServerTestBase): handler = HangHandler ssl = dict( dhparams=certutils.CertStore.load_dhparam( @@ -502,7 +502,7 @@ class TestDHParams(test.ServerTestBase): assert os.path.exists(filename) -class TestPrivkeyGen(test.ServerTestBase): +class TestPrivkeyGen(tservers.ServerTestBase): class handler(tcp.BaseHandler): @@ -520,7 +520,7 @@ class TestPrivkeyGen(test.ServerTestBase): tutils.raises("bad record mac", c.convert_to_ssl) -class TestPrivkeyGenNoFlags(test.ServerTestBase): +class TestPrivkeyGenNoFlags(tservers.ServerTestBase): class handler(tcp.BaseHandler): @@ -684,7 +684,7 @@ class TestAddress: assert repr(a) -class TestSSLKeyLogger(test.ServerTestBase): +class TestSSLKeyLogger(tservers.ServerTestBase): handler = EchoHandler ssl = dict( cipher_list="AES256-SHA" diff --git a/test/test_websockets.py b/test/test_websockets.py index 8ed14708..9956543b 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -2,8 +2,8 @@ import os from nose.tools import raises -from netlib import tcp, test, websockets, http -import tutils +from netlib import tcp, websockets, http +from . import tutils, tservers class WebSocketsEchoHandler(tcp.BaseHandler): @@ -75,7 +75,7 @@ class WebSocketsClient(tcp.TCPClient): frame.to_file(self.wfile) -class TestWebSockets(test.ServerTestBase): +class TestWebSockets(tservers.ServerTestBase): handler = WebSocketsEchoHandler def random_bytes(self, n=100): @@ -155,7 +155,7 @@ class BadHandshakeHandler(WebSocketsEchoHandler): self.handshake_done = True -class TestBadHandshake(test.ServerTestBase): +class TestBadHandshake(tservers.ServerTestBase): """ Ensure that the client disconnects if the server handshake is malformed diff --git a/test/tservers.py b/test/tservers.py new file mode 100644 index 00000000..899b51bd --- /dev/null +++ b/test/tservers.py @@ -0,0 +1,108 @@ +from __future__ import (absolute_import, print_function, division) +import threading +import Queue +import cStringIO +import OpenSSL +from netlib import tcp, certutils +from . import tutils + + +class ServerThread(threading.Thread): + + def __init__(self, server): + self.server = server + threading.Thread.__init__(self) + + def run(self): + self.server.serve_forever() + + def shutdown(self): + self.server.shutdown() + + +class ServerTestBase(object): + ssl = None + handler = None + addr = ("localhost", 0) + + @classmethod + def setupAll(cls): + cls.q = Queue.Queue() + s = cls.makeserver() + cls.port = s.address.port + cls.server = ServerThread(s) + cls.server.start() + + @classmethod + def makeserver(cls): + return TServer(cls.ssl, cls.q, cls.handler, cls.addr) + + @classmethod + def teardownAll(cls): + cls.server.shutdown() + + @property + def last_handler(self): + return self.server.server.last_handler + + +class TServer(tcp.TCPServer): + + def __init__(self, ssl, q, handler_klass, addr): + """ + ssl: A dictionary of SSL parameters: + + cert, key, request_client_cert, cipher_list, + dhparams, v3_only + """ + tcp.TCPServer.__init__(self, addr) + + if ssl is True: + self.ssl = dict() + elif isinstance(ssl, dict): + self.ssl = ssl + else: + self.ssl = None + + self.q = q + self.handler_klass = handler_klass + self.last_handler = None + + def handle_client_connection(self, request, client_address): + h = self.handler_klass(request, client_address, self) + self.last_handler = h + if self.ssl is not None: + raw_cert = self.ssl.get( + "cert", + tutils.test_data.path("data/server.crt")) + cert = certutils.SSLCert.from_pem(open(raw_cert, "rb").read()) + raw_key = self.ssl.get( + "key", + tutils.test_data.path("data/server.key")) + key = OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, + open(raw_key, "rb").read()) + if self.ssl.get("v3_only", False): + method = tcp.SSLv3_METHOD + options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 + else: + method = tcp.SSLv23_METHOD + options = None + h.convert_to_ssl( + cert, key, + method=method, + options=options, + handle_sni=getattr(h, "handle_sni", None), + request_client_cert=self.ssl.get("request_client_cert", None), + cipher_list=self.ssl.get("cipher_list", None), + dhparams=self.ssl.get("dhparams", None), + chain_file=self.ssl.get("chain_file", None), + alpn_select=self.ssl.get("alpn_select", None) + ) + h.handle() + h.finish() + + def handle_error(self, connection, client_address, fp=None): + s = cStringIO.StringIO() + tcp.TCPServer.handle_error(self, connection, client_address, s) + self.q.put(s.getvalue()) |