diff options
| -rw-r--r-- | netlib/tcp.py | 34 | ||||
| -rw-r--r-- | test/test_tcp.py | 34 | 
2 files changed, 64 insertions, 4 deletions
| diff --git a/netlib/tcp.py b/netlib/tcp.py index 6ba58d86..b7f2b3bc 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -66,7 +66,7 @@ class FileLike:          if v:              try:                  return self.o.sendall(v) -            except SSL.SysCallError: +            except SSL.Error:                  raise NetLibDisconnect()      def readline(self, size = None): @@ -125,6 +125,20 @@ class TCPClient:              raise NetLibError('Error connecting to "%s": %s' % (self.host, err))          self.connection = connection +    def close(self): +        """ +            Does a hard close of the socket, i.e. a shutdown, followed by a close. +        """ +        try: +            if self.ssl_established: +                self.connection.shutdown() +            else: +                self.connection.shutdown(socket.SHUT_RDWR) +            self.connection.close() +        except (socket.error, SSL.Error): +            # Socket probably already closed +            pass +  class BaseHandler:      """ @@ -170,7 +184,7 @@ class BaseHandler:                  self.wfile.flush()              self.wfile.close()              self.rfile.close() -            self.connection.close() +            self.close()          except socket.error:              # Remote has disconnected              pass @@ -195,6 +209,20 @@ class BaseHandler:      def handle(self): # pragma: no cover          raise NotImplementedError +    def close(self): +        """ +            Does a hard close of the socket, i.e. a shutdown, followed by a close. +        """ +        try: +            if self.ssl_established: +                self.connection.shutdown() +            else: +                self.connection.shutdown(socket.SHUT_RDWR) +            self.connection.close() +        except (socket.error, SSL.Error): +            # Socket probably already closed +            pass +  class TCPServer:      request_queue_size = 20 @@ -252,7 +280,7 @@ class TCPServer:              Called when handle_connection raises an exception.          """          # If a thread has persisted after interpreter exit, the module might be -        # none.  +        # none.          if traceback:              exc = traceback.format_exc()              print >> fp, '-'*40 diff --git a/test/test_tcp.py b/test/test_tcp.py index 359890d5..cb27c63b 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -54,7 +54,7 @@ class EchoHandler(tcp.BaseHandler):  class DisconnectHandler(tcp.BaseHandler):      def handle(self): -        self.finish() +        self.close()  class TServer(tcp.TCPServer): @@ -102,6 +102,20 @@ class TestServer(ServerTestBase):          assert c.rfile.readline() == testval +class TestDisconnect(ServerTestBase): +    @classmethod +    def makeserver(cls): +        return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler) + +    def test_echo(self): +        testval = "echo!\n" +        c = tcp.TCPClient("127.0.0.1", self.port) +        c.connect() +        c.wfile.write(testval) +        c.wfile.flush() +        assert c.rfile.readline() == testval + +  class TestServerSSL(ServerTestBase):      @classmethod      def makeserver(cls): @@ -154,6 +168,24 @@ class TestSSLDisconnect(ServerTestBase):          c.convert_to_ssl()          # Excercise SSL.ZeroReturnError          c.rfile.read(10) +        c.close() +        tutils.raises(tcp.NetLibDisconnect, c.wfile.write, "foo") +        tutils.raises(Queue.Empty, self.q.get_nowait) + + +class TestDisconnect(ServerTestBase): +    @classmethod +    def makeserver(cls): +        return TServer(("127.0.0.1", 0), False, cls.q, DisconnectHandler) + +    def test_echo(self): +        c = tcp.TCPClient("127.0.0.1", self.port) +        c.connect() +        # Excercise SSL.ZeroReturnError +        c.rfile.read(10) +        c.wfile.write("foo") +        c.close() +        c.close()  class TestTCPClient: | 
