aboutsummaryrefslogtreecommitdiffstats
path: root/netlib/tcp.py
diff options
context:
space:
mode:
Diffstat (limited to 'netlib/tcp.py')
-rw-r--r--netlib/tcp.py175
1 files changed, 100 insertions, 75 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py
index e72d5e48..6b7540aa 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -1,6 +1,13 @@
-import select, socket, threading, sys, time, traceback
+from __future__ import (absolute_import, print_function, division)
+import select
+import socket
+import sys
+import threading
+import time
+import traceback
from OpenSSL import SSL
-import certutils
+
+from . import certutils
EINTR = 4
@@ -10,32 +17,6 @@ SSLv3_METHOD = SSL.SSLv3_METHOD
SSLv23_METHOD = SSL.SSLv23_METHOD
TLSv1_METHOD = SSL.TLSv1_METHOD
-OP_ALL = SSL.OP_ALL
-OP_CIPHER_SERVER_PREFERENCE = SSL.OP_CIPHER_SERVER_PREFERENCE
-OP_COOKIE_EXCHANGE = SSL.OP_COOKIE_EXCHANGE
-OP_DONT_INSERT_EMPTY_FRAGMENTS = SSL.OP_DONT_INSERT_EMPTY_FRAGMENTS
-OP_EPHEMERAL_RSA = SSL.OP_EPHEMERAL_RSA
-OP_MICROSOFT_BIG_SSLV3_BUFFER = SSL.OP_MICROSOFT_BIG_SSLV3_BUFFER
-OP_MICROSOFT_SESS_ID_BUG = SSL.OP_MICROSOFT_SESS_ID_BUG
-OP_MSIE_SSLV2_RSA_PADDING = SSL.OP_MSIE_SSLV2_RSA_PADDING
-OP_NETSCAPE_CA_DN_BUG = SSL.OP_NETSCAPE_CA_DN_BUG
-OP_NETSCAPE_CHALLENGE_BUG = SSL.OP_NETSCAPE_CHALLENGE_BUG
-OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG = SSL.OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG
-OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG = SSL.OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG
-OP_NO_QUERY_MTU = SSL.OP_NO_QUERY_MTU
-OP_NO_SSLv2 = SSL.OP_NO_SSLv2
-OP_NO_SSLv3 = SSL.OP_NO_SSLv3
-OP_NO_TICKET = SSL.OP_NO_TICKET
-OP_NO_TLSv1 = SSL.OP_NO_TLSv1
-OP_PKCS1_CHECK_1 = SSL.OP_PKCS1_CHECK_1
-OP_PKCS1_CHECK_2 = SSL.OP_PKCS1_CHECK_2
-OP_SINGLE_DH_USE = SSL.OP_SINGLE_DH_USE
-OP_SSLEAY_080_CLIENT_DH_BUG = SSL.OP_SSLEAY_080_CLIENT_DH_BUG
-OP_SSLREF2_REUSE_CERT_TYPE_BUG = SSL.OP_SSLREF2_REUSE_CERT_TYPE_BUG
-OP_TLS_BLOCK_PADDING_BUG = SSL.OP_TLS_BLOCK_PADDING_BUG
-OP_TLS_D5_BUG = SSL.OP_TLS_D5_BUG
-OP_TLS_ROLLBACK_BUG = SSL.OP_TLS_ROLLBACK_BUG
-
class NetLibError(Exception): pass
class NetLibDisconnect(NetLibError): pass
@@ -212,10 +193,47 @@ class Address(object):
def use_ipv6(self, b):
self.family = socket.AF_INET6 if b else socket.AF_INET
+ def __repr__(self):
+ return repr(self.address)
+
def __eq__(self, other):
other = Address.wrap(other)
return (self.address, self.family) == (other.address, other.family)
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+def close_socket(sock):
+ """
+ Does a hard close of a socket, without emitting a RST.
+ """
+ try:
+ # We already indicate that we close our end.
+ # If we close RD, any further received bytes would result in a RST being set, which we want to avoid
+ # for our purposes
+ sock.shutdown(socket.SHUT_WR) # may raise "Transport endpoint is not connected" on Linux
+
+ # Section 4.2.2.13 of RFC 1122 tells us that a close() with any
+ # pending readable data could lead to an immediate RST being sent (which is the case on Windows).
+ # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html
+ #
+ # However, we cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above:
+ # Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that
+ # recv() would block infinitely.
+ # As a workaround, we set a timeout here even if we are in blocking mode.
+ # Please let us know if you have a better solution to this problem.
+
+ sock.settimeout(sock.gettimeout() or 20)
+ # may raise a timeout/disconnect exception.
+ while sock.recv(4096): # pragma: no cover
+ pass
+
+ except socket.error:
+ pass
+
+ sock.close()
+
class _Connection(object):
def get_current_cipher(self):
@@ -229,40 +247,39 @@ class _Connection(object):
def finish(self):
self.finished = True
- try:
+
+ # If we have an SSL connection, wfile.close == connection.close
+ # (We call _FileLike.set_descriptor(conn))
+ # Closing the socket is not our task, therefore we don't call close then.
+ if type(self.connection) != SSL.Connection:
if not getattr(self.wfile, "closed", False):
- self.wfile.flush()
- self.close()
+ try:
+ self.wfile.flush()
+ except NetLibDisconnect:
+ pass
+
self.wfile.close()
self.rfile.close()
- except (socket.error, NetLibDisconnect):
- # Remote has disconnected
- pass
-
- def close(self):
- """
- Does a hard close of the socket, i.e. a shutdown, followed by a close.
- """
- try:
- if self.ssl_established:
+ else:
+ try:
self.connection.shutdown()
- self.connection.sock_shutdown(socket.SHUT_WR)
- else:
- self.connection.shutdown(socket.SHUT_WR)
- #Section 4.2.2.13 of RFC 1122 tells us that a close() with any
- # pending readable data could lead to an immediate RST being sent.
- #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html
- while self.connection.recv(4096): # pragma: no cover
+ except SSL.Error:
pass
- self.connection.close()
- except (socket.error, SSL.Error, IOError):
- # Socket probably already closed
- pass
class TCPClient(_Connection):
rbufsize = -1
wbufsize = -1
+
+ def close(self):
+ # Make sure to close the real socket, not the SSL proxy.
+ # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection,
+ # it tries to renegotiate...
+ if type(self.connection) == SSL.Connection:
+ close_socket(self.connection._socket)
+ else:
+ close_socket(self.connection)
+
def __init__(self, address, source_address=None):
self.address = Address.wrap(address)
self.source_address = Address.wrap(source_address) if source_address else None
@@ -274,6 +291,8 @@ class TCPClient(_Connection):
def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None, cipher_list=None):
"""
cert: Path to a file containing both client cert and private key.
+
+ options: A bit field consisting of OpenSSL.SSL.OP_* values
"""
context = SSL.Context(method)
if cipher_list:
@@ -290,7 +309,6 @@ class TCPClient(_Connection):
except SSL.Error, v:
raise NetLibError("SSL client certificate error: %s"%str(v))
self.connection = SSL.Connection(context, self.connection)
- self.ssl_established = True
if sni:
self.sni = sni
self.connection.set_tlsext_host_name(sni)
@@ -298,7 +316,8 @@ class TCPClient(_Connection):
try:
self.connection.do_handshake()
except SSL.Error, v:
- raise NetLibError("SSL handshake error: %s"%str(v))
+ raise NetLibError("SSL handshake error: %s"%repr(v))
+ self.ssl_established = True
self.cert = certutils.SSLCert(self.connection.get_peer_certificate())
self.rfile.set_descriptor(self.connection)
self.wfile.set_descriptor(self.connection)
@@ -309,6 +328,8 @@ class TCPClient(_Connection):
if self.source_address:
connection.bind(self.source_address())
connection.connect(self.address())
+ if not self.source_address:
+ self.source_address = Address(connection.getsockname())
self.rfile = Reader(connection.makefile('rb', self.rbufsize))
self.wfile = Writer(connection.makefile('wb', self.wbufsize))
except (socket.error, IOError), err:
@@ -343,21 +364,25 @@ class BaseHandler(_Connection):
def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=None,
handle_sni=None, request_client_cert=None, cipher_list=None,
- dhparams=None ):
+ dhparams=None, chain_file=None):
"""
cert: A certutils.SSLCert object.
+
method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD
+
handle_sni: SNI handler, should take a connection object. Server
name can be retrieved like this:
- connection.get_servername()
+ connection.get_servername()
+
+ options: A bit field consisting of OpenSSL.SSL.OP_* values
- And you can specify the connection keys as follows:
+ And you can specify the connection keys as follows:
- new_context = Context(TLSv1_METHOD)
- new_context.use_privatekey(key)
- new_context.use_certificate(cert)
- connection.set_context(new_context)
+ new_context = Context(TLSv1_METHOD)
+ new_context.use_privatekey(key)
+ new_context.use_certificate(cert)
+ connection.set_context(new_context)
The request_client_cert argument requires some explanation. We're
supposed to be able to do this with no negative effects - if the
@@ -371,6 +396,8 @@ class BaseHandler(_Connection):
ctx = SSL.Context(method)
if not options is None:
ctx.set_options(options)
+ if chain_file:
+ ctx.load_verify_locations(chain_file)
if cipher_list:
try:
ctx.set_cipher_list(cipher_list)
@@ -398,12 +425,12 @@ class BaseHandler(_Connection):
"""
ctx = self._create_ssl_context(cert, key, **sslctx_kwargs)
self.connection = SSL.Connection(ctx, self.connection)
- self.ssl_established = True
self.connection.set_accept_state()
try:
self.connection.do_handshake()
except SSL.Error, v:
- raise NetLibError("SSL handshake error: %s"%str(v))
+ raise NetLibError("SSL handshake error: %s"%repr(v))
+ self.ssl_established = True
self.rfile.set_descriptor(self.connection)
self.wfile.set_descriptor(self.connection)
@@ -414,7 +441,6 @@ class BaseHandler(_Connection):
self.connection.settimeout(n)
-
class TCPServer(object):
request_queue_size = 20
def __init__(self, address):
@@ -434,11 +460,7 @@ class TCPServer(object):
except:
self.handle_error(connection, client_address)
finally:
- try:
- connection.shutdown(socket.SHUT_RDWR)
- except:
- pass
- connection.close()
+ close_socket(connection)
def serve_forever(self, poll_interval=0.1):
self.__is_shut_down.clear()
@@ -450,7 +472,7 @@ class TCPServer(object):
if ex[0] == EINTR:
continue
else:
- raise
+ raise
if self.socket in r:
connection, client_address = self.socket.accept()
t = threading.Thread(
@@ -472,7 +494,7 @@ class TCPServer(object):
self.socket.close()
self.handle_shutdown()
- def handle_error(self, request, client_address, fp=sys.stderr):
+ def handle_error(self, connection, client_address, fp=sys.stderr):
"""
Called when handle_client_connection raises an exception.
"""
@@ -480,10 +502,13 @@ class TCPServer(object):
# none.
if traceback:
exc = traceback.format_exc()
- print >> fp, '-'*40
- print >> fp, "Error in processing of request from %s:%s" % (client_address.host, client_address.port)
- print >> fp, exc
- print >> fp, '-'*40
+ print('-' * 40, file=fp)
+ print(
+ "Error in processing of request from %s:%s" % (
+ client_address.host, client_address.port
+ ), file=fp)
+ print(exc, file=fp)
+ print('-' * 40, file=fp)
def handle_client_connection(self, conn, client_address): # pragma: no cover
"""