From e60860e65d06d2b4452b7ea94902d79eed11d78c Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 3 Jun 2016 12:06:36 +1200 Subject: Make tcp.Client.connect return a context manager that closes the connection --- netlib/tcp.py | 8 ++++++++ pathod/pathoc.py | 3 ++- test/pathod/tutils.py | 36 ++++++++++++++++++------------------ 3 files changed, 28 insertions(+), 19 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index bb0c93a9..61209d64 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -6,6 +6,7 @@ import sys import threading import time import traceback +import contextlib import binascii from six.moves import range @@ -577,6 +578,12 @@ class _Connection(object): return context +@contextlib.contextmanager +def _closer(client): + yield + client.close() + + class TCPClient(_Connection): def __init__(self, address, source_address=None): @@ -708,6 +715,7 @@ class TCPClient(_Connection): self.connection = connection self.ip_address = Address(connection.getpeername()) self._makefile() + return _closer(self) def settimeout(self, n): self.connection.settimeout(n) diff --git a/pathod/pathoc.py b/pathod/pathoc.py index 2b7d053c..5cfb4591 100644 --- a/pathod/pathoc.py +++ b/pathod/pathoc.py @@ -286,7 +286,7 @@ class Pathoc(tcp.TCPClient): if self.use_http2 and not self.ssl: raise NotImplementedError("HTTP2 without SSL is not supported.") - tcp.TCPClient.connect(self) + ret = tcp.TCPClient.connect(self) if connect_to: self.http_connect(connect_to) @@ -324,6 +324,7 @@ class Pathoc(tcp.TCPClient): if self.timeout: self.settimeout(self.timeout) + return ret def stop(self): if self.ws_framereader: diff --git a/test/pathod/tutils.py b/test/pathod/tutils.py index e674812b..b9f38d86 100644 --- a/test/pathod/tutils.py +++ b/test/pathod/tutils.py @@ -88,11 +88,11 @@ class DaemonTests(object): ssl=self.ssl, fp=logfp, ) - c.connect() - if params: - path = path + "?" + urllib.urlencode(params) - resp = c.request("get:%s" % path) - return resp + with c.connect(): + if params: + path = path + "?" + urllib.urlencode(params) + resp = c.request("get:%s" % path) + return resp def get(self, spec): logfp = StringIO() @@ -101,9 +101,9 @@ class DaemonTests(object): ssl=self.ssl, fp=logfp, ) - c.connect() - resp = c.request("get:/p/%s" % urllib.quote(spec).encode("string_escape")) - return resp + with c.connect(): + resp = c.request("get:/p/%s" % urllib.quote(spec).encode("string_escape")) + return resp def pathoc( self, @@ -128,16 +128,16 @@ class DaemonTests(object): fp=logfp, use_http2=use_http2, ) - c.connect(connect_to) - ret = [] - for i in specs: - resp = c.request(i) - if resp: - ret.append(resp) - for frm in c.wait(): - ret.append(frm) - c.stop() - return ret, logfp.getvalue() + with c.connect(connect_to): + ret = [] + for i in specs: + resp = c.request(i) + if resp: + ret.append(resp) + for frm in c.wait(): + ret.append(frm) + c.stop() + return ret, logfp.getvalue() tmpdir = tutils.tmpdir -- cgit v1.2.3