aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--netlib/tcp.py36
-rw-r--r--test/test_tcp.py33
2 files changed, 60 insertions, 9 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 3aee4c74..8771e789 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -1,4 +1,4 @@
-import select, socket, threading, traceback, sys
+import select, socket, threading, traceback, sys, time
from OpenSSL import SSL
import certutils
@@ -35,8 +35,8 @@ OP_TLS_ROLLBACK_BUG = SSL.OP_TLS_ROLLBACK_BUG
class NetLibError(Exception): pass
-
class NetLibDisconnect(Exception): pass
+class NetLibTimeout(Exception): pass
class FileLike:
@@ -47,15 +47,25 @@ class FileLike:
return getattr(self.o, attr)
def flush(self):
- pass
+ if hasattr(self.o, "flush"):
+ self.o.flush()
def read(self, length):
result = ''
+ start = time.time()
while length > 0:
try:
data = self.o.read(length)
except (SSL.ZeroReturnError, SSL.SysCallError):
break
+ except SSL.WantReadError:
+ if (time.time() - start) < self.o.gettimeout():
+ time.sleep(0.1)
+ continue
+ else:
+ raise NetLibTimeout
+ except socket.timeout:
+ raise NetLibTimeout
if not data:
break
result += data
@@ -65,7 +75,11 @@ class FileLike:
def write(self, v):
if v:
try:
- return self.o.sendall(v)
+ if hasattr(self.o, "sendall"):
+ return self.o.sendall(v)
+ else:
+ r = self.o.write(v)
+ return r
except SSL.Error:
raise NetLibDisconnect()
@@ -119,12 +133,18 @@ class TCPClient:
addr = socket.gethostbyname(self.host)
connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
connection.connect((addr, self.port))
- self.rfile = connection.makefile('rb', self.rbufsize)
- self.wfile = connection.makefile('wb', self.wbufsize)
+ self.rfile = FileLike(connection.makefile('rb', self.rbufsize))
+ self.wfile = FileLike(connection.makefile('wb', self.wbufsize))
except socket.error, err:
raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
self.connection = connection
+ def settimeout(self, n):
+ self.connection.settimeout(n)
+
+ def gettimeout(self):
+ self.connection.gettimeout()
+
def close(self):
"""
Does a hard close of the socket, i.e. a shutdown, followed by a close.
@@ -148,8 +168,8 @@ class BaseHandler:
wbufsize = -1
def __init__(self, connection, client_address, server):
self.connection = connection
- self.rfile = self.connection.makefile('rb', self.rbufsize)
- self.wfile = self.connection.makefile('wb', self.wbufsize)
+ self.rfile = FileLike(self.connection.makefile('rb', self.rbufsize))
+ self.wfile = FileLike(self.connection.makefile('wb', self.wbufsize))
self.client_address = client_address
self.server = server
diff --git a/test/test_tcp.py b/test/test_tcp.py
index cb27c63b..d6235b01 100644
--- a/test/test_tcp.py
+++ b/test/test_tcp.py
@@ -1,4 +1,4 @@
-import cStringIO, threading, Queue
+import cStringIO, threading, Queue, time
from netlib import tcp, certutils
import tutils
@@ -57,6 +57,12 @@ class DisconnectHandler(tcp.BaseHandler):
self.close()
+class HangHandler(tcp.BaseHandler):
+ def handle(self):
+ while 1:
+ time.sleep(1)
+
+
class TServer(tcp.TCPServer):
def __init__(self, addr, ssl, q, handler, v3_only=False):
tcp.TCPServer.__init__(self, addr)
@@ -188,6 +194,31 @@ class TestDisconnect(ServerTestBase):
c.close()
+class TestTimeOut(ServerTestBase):
+ @classmethod
+ def makeserver(cls):
+ return TServer(("127.0.0.1", 0), False, cls.q, HangHandler)
+
+ def test_timeout_client(self):
+ c = tcp.TCPClient("127.0.0.1", self.port)
+ c.connect()
+ c.settimeout(0.1)
+ tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10)
+
+
+class TestSSLTimeOut(ServerTestBase):
+ @classmethod
+ def makeserver(cls):
+ return TServer(("127.0.0.1", 0), True, cls.q, HangHandler)
+
+ def test_timeout_client(self):
+ c = tcp.TCPClient("127.0.0.1", self.port)
+ c.connect()
+ c.convert_to_ssl()
+ c.settimeout(0.1)
+ tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10)
+
+
class TestTCPClient:
def test_conerr(self):
c = tcp.TCPClient("127.0.0.1", 0)