diff options
Diffstat (limited to 'netlib/tcp.py')
-rw-r--r-- | netlib/tcp.py | 175 |
1 files changed, 100 insertions, 75 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py index e72d5e48..6b7540aa 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,6 +1,13 @@ -import select, socket, threading, sys, time, traceback +from __future__ import (absolute_import, print_function, division) +import select +import socket +import sys +import threading +import time +import traceback from OpenSSL import SSL -import certutils + +from . import certutils EINTR = 4 @@ -10,32 +17,6 @@ SSLv3_METHOD = SSL.SSLv3_METHOD SSLv23_METHOD = SSL.SSLv23_METHOD TLSv1_METHOD = SSL.TLSv1_METHOD -OP_ALL = SSL.OP_ALL -OP_CIPHER_SERVER_PREFERENCE = SSL.OP_CIPHER_SERVER_PREFERENCE -OP_COOKIE_EXCHANGE = SSL.OP_COOKIE_EXCHANGE -OP_DONT_INSERT_EMPTY_FRAGMENTS = SSL.OP_DONT_INSERT_EMPTY_FRAGMENTS -OP_EPHEMERAL_RSA = SSL.OP_EPHEMERAL_RSA -OP_MICROSOFT_BIG_SSLV3_BUFFER = SSL.OP_MICROSOFT_BIG_SSLV3_BUFFER -OP_MICROSOFT_SESS_ID_BUG = SSL.OP_MICROSOFT_SESS_ID_BUG -OP_MSIE_SSLV2_RSA_PADDING = SSL.OP_MSIE_SSLV2_RSA_PADDING -OP_NETSCAPE_CA_DN_BUG = SSL.OP_NETSCAPE_CA_DN_BUG -OP_NETSCAPE_CHALLENGE_BUG = SSL.OP_NETSCAPE_CHALLENGE_BUG -OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG = SSL.OP_NETSCAPE_DEMO_CIPHER_CHANGE_BUG -OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG = SSL.OP_NETSCAPE_REUSE_CIPHER_CHANGE_BUG -OP_NO_QUERY_MTU = SSL.OP_NO_QUERY_MTU -OP_NO_SSLv2 = SSL.OP_NO_SSLv2 -OP_NO_SSLv3 = SSL.OP_NO_SSLv3 -OP_NO_TICKET = SSL.OP_NO_TICKET -OP_NO_TLSv1 = SSL.OP_NO_TLSv1 -OP_PKCS1_CHECK_1 = SSL.OP_PKCS1_CHECK_1 -OP_PKCS1_CHECK_2 = SSL.OP_PKCS1_CHECK_2 -OP_SINGLE_DH_USE = SSL.OP_SINGLE_DH_USE -OP_SSLEAY_080_CLIENT_DH_BUG = SSL.OP_SSLEAY_080_CLIENT_DH_BUG -OP_SSLREF2_REUSE_CERT_TYPE_BUG = SSL.OP_SSLREF2_REUSE_CERT_TYPE_BUG -OP_TLS_BLOCK_PADDING_BUG = SSL.OP_TLS_BLOCK_PADDING_BUG -OP_TLS_D5_BUG = SSL.OP_TLS_D5_BUG -OP_TLS_ROLLBACK_BUG = SSL.OP_TLS_ROLLBACK_BUG - class NetLibError(Exception): pass class NetLibDisconnect(NetLibError): pass @@ -212,10 +193,47 @@ class Address(object): def use_ipv6(self, b): self.family = socket.AF_INET6 if b else socket.AF_INET + def __repr__(self): + return repr(self.address) + def __eq__(self, other): other = Address.wrap(other) return (self.address, self.family) == (other.address, other.family) + def __ne__(self, other): + return not self.__eq__(other) + + +def close_socket(sock): + """ + Does a hard close of a socket, without emitting a RST. + """ + try: + # We already indicate that we close our end. + # If we close RD, any further received bytes would result in a RST being set, which we want to avoid + # for our purposes + sock.shutdown(socket.SHUT_WR) # may raise "Transport endpoint is not connected" on Linux + + # Section 4.2.2.13 of RFC 1122 tells us that a close() with any + # pending readable data could lead to an immediate RST being sent (which is the case on Windows). + # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html + # + # However, we cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above: + # Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that + # recv() would block infinitely. + # As a workaround, we set a timeout here even if we are in blocking mode. + # Please let us know if you have a better solution to this problem. + + sock.settimeout(sock.gettimeout() or 20) + # may raise a timeout/disconnect exception. + while sock.recv(4096): # pragma: no cover + pass + + except socket.error: + pass + + sock.close() + class _Connection(object): def get_current_cipher(self): @@ -229,40 +247,39 @@ class _Connection(object): def finish(self): self.finished = True - try: + + # If we have an SSL connection, wfile.close == connection.close + # (We call _FileLike.set_descriptor(conn)) + # Closing the socket is not our task, therefore we don't call close then. + if type(self.connection) != SSL.Connection: if not getattr(self.wfile, "closed", False): - self.wfile.flush() - self.close() + try: + self.wfile.flush() + except NetLibDisconnect: + pass + self.wfile.close() self.rfile.close() - except (socket.error, NetLibDisconnect): - # Remote has disconnected - pass - - def close(self): - """ - Does a hard close of the socket, i.e. a shutdown, followed by a close. - """ - try: - if self.ssl_established: + else: + try: self.connection.shutdown() - self.connection.sock_shutdown(socket.SHUT_WR) - else: - self.connection.shutdown(socket.SHUT_WR) - #Section 4.2.2.13 of RFC 1122 tells us that a close() with any - # pending readable data could lead to an immediate RST being sent. - #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html - while self.connection.recv(4096): # pragma: no cover + except SSL.Error: pass - self.connection.close() - except (socket.error, SSL.Error, IOError): - # Socket probably already closed - pass class TCPClient(_Connection): rbufsize = -1 wbufsize = -1 + + def close(self): + # Make sure to close the real socket, not the SSL proxy. + # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, + # it tries to renegotiate... + if type(self.connection) == SSL.Connection: + close_socket(self.connection._socket) + else: + close_socket(self.connection) + def __init__(self, address, source_address=None): self.address = Address.wrap(address) self.source_address = Address.wrap(source_address) if source_address else None @@ -274,6 +291,8 @@ class TCPClient(_Connection): def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None, cipher_list=None): """ cert: Path to a file containing both client cert and private key. + + options: A bit field consisting of OpenSSL.SSL.OP_* values """ context = SSL.Context(method) if cipher_list: @@ -290,7 +309,6 @@ class TCPClient(_Connection): except SSL.Error, v: raise NetLibError("SSL client certificate error: %s"%str(v)) self.connection = SSL.Connection(context, self.connection) - self.ssl_established = True if sni: self.sni = sni self.connection.set_tlsext_host_name(sni) @@ -298,7 +316,8 @@ class TCPClient(_Connection): try: self.connection.do_handshake() except SSL.Error, v: - raise NetLibError("SSL handshake error: %s"%str(v)) + raise NetLibError("SSL handshake error: %s"%repr(v)) + self.ssl_established = True self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) @@ -309,6 +328,8 @@ class TCPClient(_Connection): if self.source_address: connection.bind(self.source_address()) connection.connect(self.address()) + if not self.source_address: + self.source_address = Address(connection.getsockname()) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError), err: @@ -343,21 +364,25 @@ class BaseHandler(_Connection): def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=None, cipher_list=None, - dhparams=None ): + dhparams=None, chain_file=None): """ cert: A certutils.SSLCert object. + method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, or TLSv1_METHOD + handle_sni: SNI handler, should take a connection object. Server name can be retrieved like this: - connection.get_servername() + connection.get_servername() + + options: A bit field consisting of OpenSSL.SSL.OP_* values - And you can specify the connection keys as follows: + And you can specify the connection keys as follows: - new_context = Context(TLSv1_METHOD) - new_context.use_privatekey(key) - new_context.use_certificate(cert) - connection.set_context(new_context) + new_context = Context(TLSv1_METHOD) + new_context.use_privatekey(key) + new_context.use_certificate(cert) + connection.set_context(new_context) The request_client_cert argument requires some explanation. We're supposed to be able to do this with no negative effects - if the @@ -371,6 +396,8 @@ class BaseHandler(_Connection): ctx = SSL.Context(method) if not options is None: ctx.set_options(options) + if chain_file: + ctx.load_verify_locations(chain_file) if cipher_list: try: ctx.set_cipher_list(cipher_list) @@ -398,12 +425,12 @@ class BaseHandler(_Connection): """ ctx = self._create_ssl_context(cert, key, **sslctx_kwargs) self.connection = SSL.Connection(ctx, self.connection) - self.ssl_established = True self.connection.set_accept_state() try: self.connection.do_handshake() except SSL.Error, v: - raise NetLibError("SSL handshake error: %s"%str(v)) + raise NetLibError("SSL handshake error: %s"%repr(v)) + self.ssl_established = True self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) @@ -414,7 +441,6 @@ class BaseHandler(_Connection): self.connection.settimeout(n) - class TCPServer(object): request_queue_size = 20 def __init__(self, address): @@ -434,11 +460,7 @@ class TCPServer(object): except: self.handle_error(connection, client_address) finally: - try: - connection.shutdown(socket.SHUT_RDWR) - except: - pass - connection.close() + close_socket(connection) def serve_forever(self, poll_interval=0.1): self.__is_shut_down.clear() @@ -450,7 +472,7 @@ class TCPServer(object): if ex[0] == EINTR: continue else: - raise + raise if self.socket in r: connection, client_address = self.socket.accept() t = threading.Thread( @@ -472,7 +494,7 @@ class TCPServer(object): self.socket.close() self.handle_shutdown() - def handle_error(self, request, client_address, fp=sys.stderr): + def handle_error(self, connection, client_address, fp=sys.stderr): """ Called when handle_client_connection raises an exception. """ @@ -480,10 +502,13 @@ class TCPServer(object): # none. if traceback: exc = traceback.format_exc() - print >> fp, '-'*40 - print >> fp, "Error in processing of request from %s:%s" % (client_address.host, client_address.port) - print >> fp, exc - print >> fp, '-'*40 + print('-' * 40, file=fp) + print( + "Error in processing of request from %s:%s" % ( + client_address.host, client_address.port + ), file=fp) + print(exc, file=fp) + print('-' * 40, file=fp) def handle_client_connection(self, conn, client_address): # pragma: no cover """ |