diff options
| -rw-r--r-- | netlib/tcp.py | 23 | ||||
| -rw-r--r-- | test/test_tcp.py | 31 | 
2 files changed, 52 insertions, 2 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py index 276d3162..c8ffefdf 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -53,11 +53,13 @@ class TCPClient:          self.connection, self.rfile, self.wfile = None, None, None          self.cert = None -    def convert_to_ssl(self, clientcert=None): +    def convert_to_ssl(self, clientcert=None, sni=None):          context = SSL.Context(SSL.SSLv23_METHOD)          if clientcert:              context.use_certificate_file(self.clientcert)          self.connection = SSL.Connection(context, self.connection) +        if sni: +            self.connection.set_tlsext_host_name(sni)          self.connection.set_connect_state()          self.connection.do_handshake()          self.cert = self.connection.get_peer_certificate() @@ -92,10 +94,12 @@ class BaseHandler:      def convert_to_ssl(self, cert, key):          ctx = SSL.Context(SSL.SSLv23_METHOD) +        ctx.set_tlsext_servername_callback(self.handle_sni)          ctx.use_privatekey_file(key)          ctx.use_certificate_file(cert)          self.connection = SSL.Connection(ctx, self.connection)          self.connection.set_accept_state() +        # SNI callback happens during do_handshake()          self.connection.do_handshake()          self.rfile = FileLike(self.connection)          self.wfile = FileLike(self.connection) @@ -111,6 +115,23 @@ class BaseHandler:          except IOError: # pragma: no cover              pass +    def handle_sni(self, connection): +        """ +            Called if the client has given a server name indication. + +            Server name can be retrieved like this: + +                connection.get_servername() + +            And you can specify the connection keys as follows: + +                new_context = Context(TLSv1_METHOD) +                new_context.use_privatekey(key) +                new_context.use_certificate(cert) +                connection.set_context(new_context) +        """ +        pass +      def handle(self): # pragma: no cover          raise NotImplementedError diff --git a/test/test_tcp.py b/test/test_tcp.py index a81632e7..a2ee5e36 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -25,7 +25,21 @@ class ServerTestBase:          cls.server.shutdown() +class SNIHandler(tcp.BaseHandler): +    sni = None +    def handle_sni(self, connection): +        self.sni = connection.get_servername() + +    def handle(self): +        self.wfile.write(self.sni) +        self.wfile.flush() + +  class EchoHandler(tcp.BaseHandler): +    sni = None +    def handle_sni(self, connection): +        self.sni = connection.get_servername() +      def handle(self):          v = self.rfile.readline()          if v.startswith("echo"): @@ -90,13 +104,28 @@ class TestServerSSL(ServerTestBase):      def test_echo(self):          c = tcp.TCPClient("127.0.0.1", self.port)          c.connect() -        c.convert_to_ssl() +        c.convert_to_ssl(sni="foo.com")          testval = "echo!\n"          c.wfile.write(testval)          c.wfile.flush()          assert c.rfile.readline() == testval +class TestSNI(ServerTestBase): +    @classmethod +    def makeserver(cls): +        cls.q = Queue.Queue() +        s = TServer(("127.0.0.1", 0), True, cls.q, SNIHandler) +        cls.port = s.port +        return s + +    def test_echo(self): +        c = tcp.TCPClient("127.0.0.1", self.port) +        c.connect() +        c.convert_to_ssl(sni="foo.com") +        assert c.rfile.readline() == "foo.com" + +  class TestSSLDisconnect(ServerTestBase):      @classmethod      def makeserver(cls):  | 
