From 6b585023fd4ef068df7452a77f52b0c2ff490fd5 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Tue, 16 Feb 2016 21:31:07 +0100 Subject: move tservers helper --- netlib/netlib/tservers.py | 109 ----------------------------- test/mitmproxy/test_protocol_http2.py | 2 +- test/netlib/http/http2/test_connections.py | 3 +- test/netlib/test_tcp.py | 3 +- test/netlib/tservers.py | 109 +++++++++++++++++++++++++++++ test/netlib/websockets/test_websockets.py | 4 +- 6 files changed, 116 insertions(+), 114 deletions(-) delete mode 100644 netlib/netlib/tservers.py create mode 100644 test/netlib/tservers.py diff --git a/netlib/netlib/tservers.py b/netlib/netlib/tservers.py deleted file mode 100644 index 44ef8063..00000000 --- a/netlib/netlib/tservers.py +++ /dev/null @@ -1,109 +0,0 @@ -from __future__ import (absolute_import, print_function, division) - -import threading -from six.moves import queue -from io import StringIO -import OpenSSL - -from netlib import tcp -from netlib import tutils - - -class ServerThread(threading.Thread): - - def __init__(self, server): - self.server = server - threading.Thread.__init__(self) - - def run(self): - self.server.serve_forever() - - def shutdown(self): - self.server.shutdown() - - -class ServerTestBase(object): - ssl = None - handler = None - addr = ("localhost", 0) - - @classmethod - def setup_class(cls): - cls.q = queue.Queue() - s = cls.makeserver() - cls.port = s.address.port - cls.server = ServerThread(s) - cls.server.start() - - @classmethod - def makeserver(cls): - return TServer(cls.ssl, cls.q, cls.handler, cls.addr) - - @classmethod - def teardown_class(cls): - cls.server.shutdown() - - @property - def last_handler(self): - return self.server.server.last_handler - - -class TServer(tcp.TCPServer): - - def __init__(self, ssl, q, handler_klass, addr): - """ - ssl: A dictionary of SSL parameters: - - cert, key, request_client_cert, cipher_list, - dhparams, v3_only - """ - tcp.TCPServer.__init__(self, addr) - - if ssl is True: - self.ssl = dict() - elif isinstance(ssl, dict): - self.ssl = ssl - else: - self.ssl = None - - self.q = q - self.handler_klass = handler_klass - self.last_handler = None - - def handle_client_connection(self, request, client_address): - h = self.handler_klass(request, client_address, self) - self.last_handler = h - if self.ssl is not None: - cert = self.ssl.get( - "cert", - tutils.test_data.path("data/server.crt")) - raw_key = self.ssl.get( - "key", - tutils.test_data.path("data/server.key")) - key = OpenSSL.crypto.load_privatekey( - OpenSSL.crypto.FILETYPE_PEM, - open(raw_key, "rb").read()) - if self.ssl.get("v3_only", False): - method = OpenSSL.SSL.SSLv3_METHOD - options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 - else: - method = OpenSSL.SSL.SSLv23_METHOD - options = None - h.convert_to_ssl( - cert, key, - method=method, - options=options, - handle_sni=getattr(h, "handle_sni", None), - request_client_cert=self.ssl.get("request_client_cert", None), - cipher_list=self.ssl.get("cipher_list", None), - dhparams=self.ssl.get("dhparams", None), - chain_file=self.ssl.get("chain_file", None), - alpn_select=self.ssl.get("alpn_select", None) - ) - h.handle() - h.finish() - - def handle_error(self, connection, client_address, fp=None): - s = StringIO() - tcp.TCPServer.handle_error(self, connection, client_address, s) - self.q.put(s.getvalue()) diff --git a/test/mitmproxy/test_protocol_http2.py b/test/mitmproxy/test_protocol_http2.py index fc27cc3f..1da140d8 100644 --- a/test/mitmproxy/test_protocol_http2.py +++ b/test/mitmproxy/test_protocol_http2.py @@ -18,7 +18,7 @@ logging.getLogger("PIL.Image").setLevel(logging.WARNING) logging.getLogger("PIL.PngImagePlugin").setLevel(logging.WARNING) import netlib -from netlib import tservers as netlib_tservers +from ..netlib import tservers as netlib_tservers from netlib.utils import http2_read_raw_frame import h2 diff --git a/test/netlib/http/http2/test_connections.py b/test/netlib/http/http2/test_connections.py index 8be127e4..c067d487 100644 --- a/test/netlib/http/http2/test_connections.py +++ b/test/netlib/http/http2/test_connections.py @@ -4,11 +4,12 @@ import codecs from hyperframe.frame import * -from netlib import tcp, http, utils, tservers +from netlib import tcp, http, utils from netlib.tutils import raises from netlib.exceptions import TcpDisconnect from netlib.http.http2.connections import HTTP2Protocol, TCPHandler +from ... import tservers class TestTCPHandlerWrapper: def test_wrapped(self): diff --git a/test/netlib/test_tcp.py b/test/netlib/test_tcp.py index 8ae3aa51..e65a2e2f 100644 --- a/test/netlib/test_tcp.py +++ b/test/netlib/test_tcp.py @@ -10,10 +10,11 @@ import mock from OpenSSL import SSL import OpenSSL -from netlib import tcp, certutils, tutils, tservers +from netlib import tcp, certutils, tutils from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \ TcpTimeout, TcpDisconnect, TcpException, NetlibException +from . import tservers class EchoHandler(tcp.BaseHandler): sni = None diff --git a/test/netlib/tservers.py b/test/netlib/tservers.py new file mode 100644 index 00000000..569745e6 --- /dev/null +++ b/test/netlib/tservers.py @@ -0,0 +1,109 @@ +from __future__ import (absolute_import, print_function, division) + +import threading +from six.moves import queue +from io import StringIO +import OpenSSL + +from netlib import tcp +from netlib import tutils + + +class _ServerThread(threading.Thread): + + def __init__(self, server): + self.server = server + threading.Thread.__init__(self) + + def run(self): + self.server.serve_forever() + + def shutdown(self): + self.server.shutdown() + + +class _TServer(tcp.TCPServer): + + def __init__(self, ssl, q, handler_klass, addr): + """ + ssl: A dictionary of SSL parameters: + + cert, key, request_client_cert, cipher_list, + dhparams, v3_only + """ + tcp.TCPServer.__init__(self, addr) + + if ssl is True: + self.ssl = dict() + elif isinstance(ssl, dict): + self.ssl = ssl + else: + self.ssl = None + + self.q = q + self.handler_klass = handler_klass + self.last_handler = None + + def handle_client_connection(self, request, client_address): + h = self.handler_klass(request, client_address, self) + self.last_handler = h + if self.ssl is not None: + cert = self.ssl.get( + "cert", + tutils.test_data.path("data/server.crt")) + raw_key = self.ssl.get( + "key", + tutils.test_data.path("data/server.key")) + key = OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, + open(raw_key, "rb").read()) + if self.ssl.get("v3_only", False): + method = OpenSSL.SSL.SSLv3_METHOD + options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 + else: + method = OpenSSL.SSL.SSLv23_METHOD + options = None + h.convert_to_ssl( + cert, key, + method=method, + options=options, + handle_sni=getattr(h, "handle_sni", None), + request_client_cert=self.ssl.get("request_client_cert", None), + cipher_list=self.ssl.get("cipher_list", None), + dhparams=self.ssl.get("dhparams", None), + chain_file=self.ssl.get("chain_file", None), + alpn_select=self.ssl.get("alpn_select", None) + ) + h.handle() + h.finish() + + def handle_error(self, connection, client_address, fp=None): + s = StringIO() + tcp.TCPServer.handle_error(self, connection, client_address, s) + self.q.put(s.getvalue()) + + +class ServerTestBase(object): + ssl = None + handler = None + addr = ("localhost", 0) + + @classmethod + def setup_class(cls): + cls.q = queue.Queue() + s = cls.makeserver() + cls.port = s.address.port + cls.server = _ServerThread(s) + cls.server.start() + + @classmethod + def makeserver(cls): + return _TServer(cls.ssl, cls.q, cls.handler, cls.addr) + + @classmethod + def teardown_class(cls): + cls.server.shutdown() + + @property + def last_handler(self): + return self.server.server.last_handler diff --git a/test/netlib/websockets/test_websockets.py b/test/netlib/websockets/test_websockets.py index d53f0d83..a7d782a4 100644 --- a/test/netlib/websockets/test_websockets.py +++ b/test/netlib/websockets/test_websockets.py @@ -2,12 +2,12 @@ import os from netlib.http.http1 import read_response, read_request -from netlib import tcp, websockets, http, tutils, tservers +from netlib import tcp, websockets, http, tutils from netlib.http import status_codes from netlib.tutils import treq - from netlib.exceptions import * +from .. import tservers class WebSocketsEchoHandler(tcp.BaseHandler): -- cgit v1.2.3