diff options
Diffstat (limited to 'netlib')
| -rw-r--r-- | netlib/tcp.py | 43 | 
1 files changed, 36 insertions, 7 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py index 914aa701..de12102e 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -6,6 +6,7 @@ import sys  import threading  import time  import traceback +import contextlib  import binascii  from six.moves import range @@ -577,6 +578,12 @@ class _Connection(object):          return context +@contextlib.contextmanager +def _closer(client): +    yield +    client.close() + +  class TCPClient(_Connection):      def __init__(self, address, source_address=None): @@ -708,6 +715,7 @@ class TCPClient(_Connection):          self.connection = connection          self.ip_address = Address(connection.getpeername())          self._makefile() +        return _closer(self)      def settimeout(self, n):          self.connection.settimeout(n) @@ -833,6 +841,25 @@ class BaseHandler(_Connection):              return b"" +class Counter: +    def __init__(self): +        self._count = 0 +        self._lock = threading.Lock() + +    @property +    def count(self): +        with self._lock: +            return self._count + +    def __enter__(self): +        with self._lock: +            self._count += 1 + +    def __exit__(self, *args): +        with self._lock: +            self._count -= 1 + +  class TCPServer(object):      request_queue_size = 20 @@ -845,15 +872,17 @@ class TCPServer(object):          self.socket.bind(self.address())          self.address = Address.wrap(self.socket.getsockname())          self.socket.listen(self.request_queue_size) +        self.handler_counter = Counter()      def connection_thread(self, connection, client_address): -        client_address = Address(client_address) -        try: -            self.handle_client_connection(connection, client_address) -        except: -            self.handle_error(connection, client_address) -        finally: -            close_socket(connection) +        with self.handler_counter: +            client_address = Address(client_address) +            try: +                self.handle_client_connection(connection, client_address) +            except: +                self.handle_error(connection, client_address) +            finally: +                close_socket(connection)      def serve_forever(self, poll_interval=0.1):          self.__is_shut_down.clear()  | 
