aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--netlib/tcp.py30
-rw-r--r--test/test_tcp.py35
2 files changed, 58 insertions, 7 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 8902b9dc..682db29a 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -25,6 +25,10 @@ from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, Tl
version_check.check_pyopenssl_version()
+if six.PY2:
+ socket_fileobject = socket._fileobject
+else:
+ socket_fileobject = socket.SocketIO
EINTR = 4
@@ -268,9 +272,9 @@ class Reader(_FileLike):
Raises:
TcpException if there was an error with the socket
TlsException if there was an error with pyOpenSSL.
- NotImplementedError if the underlying file object is not a (pyOpenSSL) socket
+ NotImplementedError if the underlying file object is not a [pyOpenSSL] socket
"""
- if isinstance(self.o, socket._fileobject):
+ if isinstance(self.o, socket_fileobject):
try:
return self.o._sock.recv(length, socket.MSG_PEEK)
except socket.error as e:
@@ -420,11 +424,26 @@ class _Connection(object):
rbufsize = -1
wbufsize = -1
+ def _makefile(self):
+ """
+ Set up .rfile and .wfile attributes from .connection
+ """
+ # Ideally, we would use the Buffered IO in Python 3 by default.
+ # Unfortunately, the implementation of .peek() is broken for n>1 bytes,
+ # as it may just return what's left in the buffer and not all the bytes we want.
+ # As a workaround, we just use unbuffered sockets directly.
+ # https://mail.python.org/pipermail/python-dev/2009-June/089986.html
+ if six.PY2:
+ self.rfile = Reader(self.connection.makefile('rb', self.rbufsize))
+ self.wfile = Writer(self.connection.makefile('wb', self.wbufsize))
+ else:
+ self.rfile = Reader(socket.SocketIO(self.connection, "rb"))
+ self.wfile = Writer(socket.SocketIO(self.connection, "wb"))
+
def __init__(self, connection):
if connection:
self.connection = connection
- self.rfile = Reader(self.connection.makefile('rb', self.rbufsize))
- self.wfile = Writer(self.connection.makefile('wb', self.wbufsize))
+ self._makefile()
else:
self.connection = None
self.rfile = None
@@ -663,13 +682,12 @@ class TCPClient(_Connection):
connection.connect(self.address())
if not self.source_address:
self.source_address = Address(connection.getsockname())
- self.rfile = Reader(connection.makefile('rb', self.rbufsize))
- self.wfile = Writer(connection.makefile('wb', self.wbufsize))
except (socket.error, IOError) as err:
raise TcpException(
'Error connecting to "%s": %s' %
(self.address.host, err))
self.connection = connection
+ self._makefile()
def settimeout(self, n):
self.connection.settimeout(n)
diff --git a/test/test_tcp.py b/test/test_tcp.py
index 738fb2eb..2b091ef0 100644
--- a/test/test_tcp.py
+++ b/test/test_tcp.py
@@ -12,7 +12,7 @@ import OpenSSL
from netlib import tcp, certutils, tutils, tservers
from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \
- TcpTimeout, TcpDisconnect, TcpException
+ TcpTimeout, TcpDisconnect, TcpException, NetlibException
class EchoHandler(tcp.BaseHandler):
@@ -713,6 +713,39 @@ class TestFileLike:
tutils.raises(TcpReadIncomplete, s.safe_read, 10)
+class TestPeek(tservers.ServerTestBase):
+ handler = EchoHandler
+
+ def _connect(self, c):
+ c.connect()
+
+ def test_peek(self):
+ testval = b"peek!\n"
+ c = tcp.TCPClient(("127.0.0.1", self.port))
+ self._connect(c)
+ c.wfile.write(testval)
+ c.wfile.flush()
+
+ assert c.rfile.peek(4) == b"peek"
+ assert c.rfile.peek(6) == b"peek!\n"
+ assert c.rfile.readline() == testval
+
+ c.close()
+ with tutils.raises(NetlibException):
+ if c.rfile.peek(1) == b"":
+ # Workaround for Python 2 on Unix:
+ # Peeking a closed connection does not raise an exception here.
+ raise NetlibException()
+
+
+class TestPeekSSL(TestPeek):
+ ssl = True
+
+ def _connect(self, c):
+ c.connect()
+ c.convert_to_ssl()
+
+
class TestAddress:
def test_simple(self):