aboutsummaryrefslogtreecommitdiffstats
path: root/test/netlib/tservers.py
blob: b344e25c951a7c0ea82faadd1114e549feebde05 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import threading
import queue
import io
import OpenSSL

from netlib import tcp
from netlib import tutils


class _ServerThread(threading.Thread):
    """Thread that drives a server's ``serve_forever`` loop in the background.

    ``shutdown`` is forwarded to the wrapped server, so the test thread can
    stop the loop through this object.
    """

    def __init__(self, server):
        super().__init__()
        # The wrapped server stays reachable as ``thread.server``; the test
        # base class reads ``thread.server.last_handler`` through it.
        self.server = server

    def run(self):
        srv = self.server
        srv.serve_forever()

    def shutdown(self):
        srv = self.server
        srv.shutdown()


class _TServer(tcp.TCPServer):
    """TCP server used by the netlib test-suite.

    Each accepted connection is wrapped in ``handler_klass``. It is optionally
    upgraded to TLS, then handled and finished. Error reports produced by
    ``tcp.TCPServer.handle_error`` are captured as strings and pushed onto
    ``q`` so the test thread can inspect them.
    """

    def __init__(self, ssl, q, handler_klass, addr, **kwargs):
        """
            ssl: ``True`` enables TLS with default settings. A dictionary
                 enables TLS with the given SSL parameters. Any other value
                 (e.g. ``None``) disables TLS. Recognised dictionary keys:

                    cert, key, request_client_cert, cipher_list,
                    dhparams, v3_only, chain_file, alpn_select

            q: Queue that receives the formatted error reports from
               handle_error.
            handler_klass: Handler class instantiated once per connection,
               or None. ``kwargs`` are stored on it as the class attribute
               ``kwargs``. That attribute is shared by every instance of
               the class.
            addr: Address passed to ``tcp.TCPServer``.
        """
        tcp.TCPServer.__init__(self, addr)

        if ssl is True:
            self.ssl = dict()
        elif isinstance(ssl, dict):
            self.ssl = ssl
        else:
            self.ssl = None

        self.q = q
        self.handler_klass = handler_klass
        if self.handler_klass is not None:
            self.handler_klass.kwargs = kwargs
        self.last_handler = None

    def handle_client_connection(self, request, client_address):
        """Create a handler for the connection, optionally start TLS, then run it."""
        h = self.handler_klass(request, client_address, self)
        # Recorded before the TLS upgrade, so it is already set even if
        # convert_to_ssl raises.
        self.last_handler = h
        if self.ssl is not None:
            cert = self.ssl.get(
                "cert",
                tutils.test_data.path("data/server.crt"))
            raw_key = self.ssl.get(
                "key",
                tutils.test_data.path("data/server.key"))
            # Use a context manager so the key file is closed immediately
            # rather than leaking a handle until garbage collection.
            with open(raw_key, "rb") as key_file:
                key_pem = key_file.read()
            key = OpenSSL.crypto.load_privatekey(
                OpenSSL.crypto.FILETYPE_PEM,
                key_pem)
            if self.ssl.get("v3_only", False):
                # Force SSLv3 and disable SSLv2/TLSv1 negotiation.
                method = OpenSSL.SSL.SSLv3_METHOD
                options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1
            else:
                method = OpenSSL.SSL.SSLv23_METHOD
                options = None
            h.convert_to_ssl(
                cert, key,
                method=method,
                options=options,
                handle_sni=getattr(h, "handle_sni", None),
                request_client_cert=self.ssl.get("request_client_cert", None),
                cipher_list=self.ssl.get("cipher_list", None),
                dhparams=self.ssl.get("dhparams", None),
                chain_file=self.ssl.get("chain_file", None),
                alpn_select=self.ssl.get("alpn_select", None)
            )
        h.handle()
        h.finish()

    def handle_error(self, connection, client_address, fp=None):
        """Capture the base class's error report and put it on ``self.q``.

        ``fp`` is ignored. The report is always written to an in-memory
        text buffer, and the buffer's contents are queued.
        """
        s = io.StringIO()
        tcp.TCPServer.handle_error(self, connection, client_address, s)
        self.q.put(s.getvalue())


class ServerTestBase:
    """Base class for test suites that need a live ``_TServer``.

    Subclasses configure the server through class attributes:

    * ``ssl``     -- forwarded to ``_TServer`` (``None`` disables TLS).
    * ``handler`` -- handler class instantiated for each connection.
    * ``addr``    -- bind address. Port 0 requests an ephemeral port; the
      real port is read back into ``cls.port``.
    """
    ssl = None
    handler = None
    addr = ("localhost", 0)

    @classmethod
    def setup_class(cls, **kwargs):
        """Build the server and start it in a background thread.

        Keyword arguments are forwarded to ``makeserver``.
        """
        cls.q = queue.Queue()
        srv = cls.makeserver(**kwargs)
        cls.port = srv.address.port
        thread = _ServerThread(srv)
        cls.server = thread
        thread.start()

    @classmethod
    def makeserver(cls, **kwargs):
        """Create the ``_TServer``.

        An ``ssl`` keyword overrides the class-level ``ssl`` setting.
        """
        ssl_config = kwargs.pop('ssl', cls.ssl)
        return _TServer(ssl_config, cls.q, cls.handler, cls.addr, **kwargs)

    @classmethod
    def teardown_class(cls):
        """Stop the background server loop."""
        cls.server.shutdown()

    def teardown(self):
        """Wait for the server to report silence before the next test."""
        wrapped = self.server.server
        wrapped.wait_for_silence()

    @property
    def last_handler(self):
        """Most recent connection handler created by the server."""
        wrapped = self.server.server
        return wrapped.last_handler