From bda49dd178fee1361f3585bd7efad67883298e5a Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 1 Feb 2016 19:38:14 +0100 Subject: fix #113, make Reader.peek() work on Python 3 --- netlib/tcp.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) (limited to 'netlib/tcp.py') diff --git a/netlib/tcp.py b/netlib/tcp.py index 8902b9dc..57a9b737 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -25,6 +25,10 @@ from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, Tl version_check.check_pyopenssl_version() +if six.PY2: + socket_fileobject = socket._fileobject +else: + socket_fileobject = socket.SocketIO EINTR = 4 @@ -270,7 +274,7 @@ class Reader(_FileLike): TlsException if there was an error with pyOpenSSL. NotImplementedError if the underlying file object is not a (pyOpenSSL) socket """ - if isinstance(self.o, socket._fileobject): + if isinstance(self.o, socket_fileobject): try: return self.o._sock.recv(length, socket.MSG_PEEK) except socket.error as e: @@ -423,8 +427,17 @@ class _Connection(object): def __init__(self, connection): if connection: self.connection = connection - self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) - self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) + # Ideally, we would use the Buffered IO in Python 3 by default. + # Unfortunately, the implementation of .peek() is broken for n>1 bytes, + # as it may just return what's left in the buffer and not all the bytes we want. + # As a workaround, we just use unbuffered sockets directly. + # https://mail.python.org/pipermail/python-dev/2009-June/089986.html + if six.PY2: + self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) + else: + self.rfile = Reader(socket.SocketIO(self.connection, "rb")) + self.wfile = Writer(socket.SocketIO(self.connection, "wb")) else: self.connection = None self.rfile = None @@ -663,8 +676,15 @@ class TCPClient(_Connection): connection.connect(self.address()) if not self.source_address: self.source_address = Address(connection.getsockname()) - self.rfile = Reader(connection.makefile('rb', self.rbufsize)) - self.wfile = Writer(connection.makefile('wb', self.wbufsize)) + + # See _Connection.__init__ why we do this dance. + if six.PY2: + self.rfile = Reader(connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(connection.makefile('wb', self.wbufsize)) + else: + self.rfile = Reader(socket.SocketIO(connection, "rb")) + self.wfile = Writer(socket.SocketIO(connection, "wb")) + except (socket.error, IOError) as err: raise TcpException( 'Error connecting to "%s": %s' % -- cgit v1.2.3 From a3af0ce71d5b4368f1d9de8d17ff5e20086edcc4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 1 Feb 2016 20:10:18 +0100 Subject: tests++ --- netlib/tcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'netlib/tcp.py') diff --git a/netlib/tcp.py b/netlib/tcp.py index 57a9b737..1523370b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -272,7 +272,7 @@ class Reader(_FileLike): Raises: TcpException if there was an error with the socket TlsException if there was an error with pyOpenSSL. - NotImplementedError if the underlying file object is not a (pyOpenSSL) socket + NotImplementedError if the underlying file object is not a [pyOpenSSL] socket """ if isinstance(self.o, socket_fileobject): try: -- cgit v1.2.3 From 931b5459e92ec237914d7cca9034c5a348033bdb Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 1 Feb 2016 20:19:34 +0100 Subject: remove code duplication --- netlib/tcp.py | 38 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 20 deletions(-) (limited to 'netlib/tcp.py') diff --git a/netlib/tcp.py b/netlib/tcp.py index 1523370b..682db29a 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -424,20 +424,26 @@ class _Connection(object): rbufsize = -1 wbufsize = -1 + def _makefile(self): + """ + Set up .rfile and .wfile attributes from .connection + """ + # Ideally, we would use the Buffered IO in Python 3 by default. + # Unfortunately, the implementation of .peek() is broken for n>1 bytes, + # as it may just return what's left in the buffer and not all the bytes we want. + # As a workaround, we just use unbuffered sockets directly. + # https://mail.python.org/pipermail/python-dev/2009-June/089986.html + if six.PY2: + self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) + else: + self.rfile = Reader(socket.SocketIO(self.connection, "rb")) + self.wfile = Writer(socket.SocketIO(self.connection, "wb")) + def __init__(self, connection): if connection: self.connection = connection - # Ideally, we would use the Buffered IO in Python 3 by default. - # Unfortunately, the implementation of .peek() is broken for n>1 bytes, - # as it may just return what's left in the buffer and not all the bytes we want. - # As a workaround, we just use unbuffered sockets directly. - # https://mail.python.org/pipermail/python-dev/2009-June/089986.html - if six.PY2: - self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) - self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) - else: - self.rfile = Reader(socket.SocketIO(self.connection, "rb")) - self.wfile = Writer(socket.SocketIO(self.connection, "wb")) + self._makefile() else: self.connection = None self.rfile = None @@ -676,20 +682,12 @@ class TCPClient(_Connection): connection.connect(self.address()) if not self.source_address: self.source_address = Address(connection.getsockname()) - - # See _Connection.__init__ why we do this dance. - if six.PY2: - self.rfile = Reader(connection.makefile('rb', self.rbufsize)) - self.wfile = Writer(connection.makefile('wb', self.wbufsize)) - else: - self.rfile = Reader(socket.SocketIO(connection, "rb")) - self.wfile = Writer(socket.SocketIO(connection, "wb")) - except (socket.error, IOError) as err: raise TcpException( 'Error connecting to "%s": %s' % (self.address.host, err)) self.connection = connection + self._makefile() def settimeout(self, n): self.connection.settimeout(n) -- cgit v1.2.3