aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--netlib/tcp.py3
-rw-r--r--test/test_tcp.py37
2 files changed, 36 insertions, 4 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py
index e1318435..414c1237 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -284,6 +284,9 @@ class BaseHandler:
def handle(self): # pragma: no cover
raise NotImplementedError
+ def settimeout(self, n):
+ self.connection.settimeout(n)
+
def close(self):
"""
Does a hard close of the socket, i.e. a shutdown, followed by a close.
diff --git a/test/test_tcp.py b/test/test_tcp.py
index 9d581939..c833ce07 100644
--- a/test/test_tcp.py
+++ b/test/test_tcp.py
@@ -28,6 +28,11 @@ class ServerTestBase:
cls.server.shutdown()
+ @property
+ def last_handler(self):
+ return self.server.server.last_handler
+
+
class SNIHandler(tcp.BaseHandler):
sni = None
def handle_sni(self, connection):
@@ -63,15 +68,27 @@ class HangHandler(tcp.BaseHandler):
time.sleep(1)
+class TimeoutHandler(tcp.BaseHandler):
+ def handle(self):
+ self.timeout = False
+ self.settimeout(0.01)
+ try:
+ self.rfile.read(10)
+ except tcp.NetLibTimeout:
+ self.timeout = True
+
+
class TServer(tcp.TCPServer):
- def __init__(self, addr, ssl, q, handler, v3_only=False):
+ def __init__(self, addr, ssl, q, handler_klass, v3_only=False):
tcp.TCPServer.__init__(self, addr)
self.ssl, self.q = ssl, q
self.v3_only = v3_only
- self.handler = handler
+ self.handler_klass = handler_klass
+ self.last_handler = None
def handle_connection(self, request, client_address):
- h = self.handler(request, client_address, self)
+ h = self.handler_klass(request, client_address, self)
+ self.last_handler = h
if self.ssl:
if self.v3_only:
method = tcp.SSLv3_METHOD
@@ -194,12 +211,24 @@ class TestDisconnect(ServerTestBase):
c.close()
+class TestServerTimeOut(ServerTestBase):
+ @classmethod
+ def makeserver(cls):
+ return TServer(("127.0.0.1", 0), False, cls.q, TimeoutHandler)
+
+ def test_timeout(self):
+ c = tcp.TCPClient("127.0.0.1", self.port)
+ c.connect()
+ time.sleep(0.3)
+ assert self.last_handler.timeout
+
+
class TestTimeOut(ServerTestBase):
@classmethod
def makeserver(cls):
return TServer(("127.0.0.1", 0), False, cls.q, HangHandler)
- def test_timeout_client(self):
+ def test_timeout(self):
c = tcp.TCPClient("127.0.0.1", self.port)
c.connect()
c.settimeout(0.1)