aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--netlib/tcp.py30
-rw-r--r--pathod/pathoc.py67
-rw-r--r--test/pathod/test_pathoc.py8
3 files changed, 62 insertions, 43 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py
index acd67cad..a8a68139 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -6,7 +6,6 @@ import sys
import threading
import time
import traceback
-import contextlib
import binascii
from six.moves import range
@@ -582,12 +581,24 @@ class _Connection(object):
return context
-@contextlib.contextmanager
-def _closer(client):
- try:
- yield
- finally:
- client.close()
+class ConnectionCloser(object):
+ def __init__(self, conn):
+ self.conn = conn
+ self._canceled = False
+
+ def pop(self):
+ """
+ Cancel the current closer, and return a fresh one.
+ """
+ self._canceled = True
+ return ConnectionCloser(self.conn)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args):
+ if not self._canceled:
+ self.conn.close()
class TCPClient(_Connection):
@@ -717,11 +728,12 @@ class TCPClient(_Connection):
except (socket.error, IOError) as err:
raise exceptions.TcpException(
'Error connecting to "%s": %s' %
- (self.address.host, err))
+ (self.address.host, err)
+ )
self.connection = connection
self.ip_address = Address(connection.getpeername())
self._makefile()
- return _closer(self)
+ return ConnectionCloser(self)
def settimeout(self, n):
self.connection.settimeout(n)
diff --git a/pathod/pathoc.py b/pathod/pathoc.py
index b2563988..21fc9845 100644
--- a/pathod/pathoc.py
+++ b/pathod/pathoc.py
@@ -291,44 +291,45 @@ class Pathoc(tcp.TCPClient):
if self.use_http2 and not self.ssl:
raise NotImplementedError("HTTP2 without SSL is not supported.")
- ret = tcp.TCPClient.connect(self)
- if connect_to:
- self.http_connect(connect_to)
+ with tcp.TCPClient.connect(self) as closer:
+ if connect_to:
+ self.http_connect(connect_to)
- self.sslinfo = None
- if self.ssl:
- try:
- alpn_protos = [b'http/1.1']
- if self.use_http2:
- alpn_protos.append(b'h2')
-
- self.convert_to_ssl(
- sni=self.sni,
- cert=self.clientcert,
- method=self.ssl_version,
- options=self.ssl_options,
- cipher_list=self.ciphers,
- alpn_protos=alpn_protos
+ self.sslinfo = None
+ if self.ssl:
+ try:
+ alpn_protos = [b'http/1.1']
+ if self.use_http2:
+ alpn_protos.append(b'h2')
+
+ self.convert_to_ssl(
+ sni=self.sni,
+ cert=self.clientcert,
+ method=self.ssl_version,
+ options=self.ssl_options,
+ cipher_list=self.ciphers,
+ alpn_protos=alpn_protos
+ )
+ except exceptions.TlsException as v:
+ raise PathocError(str(v))
+
+ self.sslinfo = SSLInfo(
+ self.connection.get_peer_cert_chain(),
+ self.get_current_cipher(),
+ self.get_alpn_proto_negotiated()
)
- except exceptions.TlsException as v:
- raise PathocError(str(v))
+ if showssl:
+ print(str(self.sslinfo), file=fp)
- self.sslinfo = SSLInfo(
- self.connection.get_peer_cert_chain(),
- self.get_current_cipher(),
- self.get_alpn_proto_negotiated()
- )
- if showssl:
- print(str(self.sslinfo), file=fp)
+ if self.use_http2:
+ self.protocol.check_alpn()
+ if not self.http2_skip_connection_preface:
+ self.protocol.perform_client_connection_preface()
- if self.use_http2:
- self.protocol.check_alpn()
- if not self.http2_skip_connection_preface:
- self.protocol.perform_client_connection_preface()
+ if self.timeout:
+ self.settimeout(self.timeout)
- if self.timeout:
- self.settimeout(self.timeout)
- return ret
+ return closer.pop()
def stop(self):
if self.ws_framereader:
diff --git a/test/pathod/test_pathoc.py b/test/pathod/test_pathoc.py
index dc5011c0..77d4721c 100644
--- a/test/pathod/test_pathoc.py
+++ b/test/pathod/test_pathoc.py
@@ -83,7 +83,13 @@ class TestDaemon(PathocTestDaemon):
def test_ssl_error(self):
c = pathoc.Pathoc(("127.0.0.1", self.d.port), ssl=True, fp=None)
- tutils.raises("ssl handshake", c.connect)
+ try:
+ with c.connect():
+ pass
+ except Exception as e:
+ assert "SSL" in str(e)
+ else:
+ raise AssertionError("No exception raised.")
def test_showssl(self):
assert "certificate chain" not in self.tval(