diff options
| -rw-r--r-- | netlib/tcp.py | 2 | ||||
| -rw-r--r-- | test/test_tcp.py | 42 | 
2 files changed, 34 insertions, 10 deletions
| diff --git a/netlib/tcp.py b/netlib/tcp.py index 3c5c89b7..276d3162 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -59,6 +59,7 @@ class TCPClient:              context.use_certificate_file(self.clientcert)          self.connection = SSL.Connection(context, self.connection)          self.connection.set_connect_state() +        self.connection.do_handshake()          self.cert = self.connection.get_peer_certificate()          self.rfile = FileLike(self.connection)          self.wfile = FileLike(self.connection) @@ -95,6 +96,7 @@ class BaseHandler:          ctx.use_certificate_file(cert)          self.connection = SSL.Connection(ctx, self.connection)          self.connection.set_accept_state() +        self.connection.do_handshake()          self.rfile = FileLike(self.connection)          self.wfile = FileLike(self.connection) diff --git a/test/test_tcp.py b/test/test_tcp.py index 26286bc4..a81632e7 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -25,13 +25,8 @@ class ServerTestBase:          cls.server.shutdown() -class THandler(tcp.BaseHandler): +class EchoHandler(tcp.BaseHandler):      def handle(self): -        if self.server.ssl: -            self.convert_to_ssl( -                tutils.test_data.path("data/server.crt"), -                tutils.test_data.path("data/server.key"), -            )          v = self.rfile.readline()          if v.startswith("echo"):              self.wfile.write(v) @@ -40,13 +35,24 @@ class THandler(tcp.BaseHandler):          self.wfile.flush() +class DisconnectHandler(tcp.BaseHandler): +    def handle(self): +        self.finish() + +  class TServer(tcp.TCPServer): -    def __init__(self, addr, ssl, q): +    def __init__(self, addr, ssl, q, handler):          tcp.TCPServer.__init__(self, addr)          self.ssl, self.q = ssl, q +        self.handler = handler      def handle_connection(self, request, client_address): -        h = THandler(request, client_address, self) +        h = self.handler(request, client_address, self) +        if self.ssl: +            h.convert_to_ssl( +                tutils.test_data.path("data/server.crt"), +                tutils.test_data.path("data/server.key"), +            )          h.handle()          h.finish() @@ -60,7 +66,7 @@ class TestServer(ServerTestBase):      @classmethod      def makeserver(cls):          cls.q = Queue.Queue() -        s = TServer(("127.0.0.1", 0), False, cls.q) +        s = TServer(("127.0.0.1", 0), False, cls.q, EchoHandler)          cls.port = s.port          return s @@ -77,7 +83,7 @@ class TestServerSSL(ServerTestBase):      @classmethod      def makeserver(cls):          cls.q = Queue.Queue() -        s = TServer(("127.0.0.1", 0), True, cls.q) +        s = TServer(("127.0.0.1", 0), True, cls.q, EchoHandler)          cls.port = s.port          return s @@ -91,6 +97,22 @@ class TestServerSSL(ServerTestBase):          assert c.rfile.readline() == testval +class TestSSLDisconnect(ServerTestBase): +    @classmethod +    def makeserver(cls): +        cls.q = Queue.Queue() +        s = TServer(("127.0.0.1", 0), True, cls.q, DisconnectHandler) +        cls.port = s.port +        return s + +    def test_echo(self): +        c = tcp.TCPClient("127.0.0.1", self.port) +        c.connect() +        c.convert_to_ssl() +        # Excercise SSL.ZeroReturnError +        c.rfile.read(10) + +  class TestTCPClient:      def test_conerr(self):          c = tcp.TCPClient("127.0.0.1", 0) | 
