From 951f2d517fa2e464d654a54bebacbd983f944c62 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 9 Jan 2014 01:57:37 +0100 Subject: change parameter names to reflect changes --- netlib/tcp.py | 29 +++++++++++++---------------- netlib/test.py | 2 +- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 33f7ef3a..d35818bf 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -138,8 +138,8 @@ class Reader(_FileLike): raise NetLibTimeout except socket.timeout: raise NetLibTimeout - except socket.error: - raise NetLibDisconnect + except socket.error, v: + raise NetLibDisconnect(v[1]) except SSL.SysCallError: raise NetLibDisconnect except SSL.Error, v: @@ -255,16 +255,13 @@ class BaseHandler: """ rbufsize = -1 wbufsize = -1 - def __init__(self, connection, client_address, server): + def __init__(self, connection): self.connection = connection self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) - self.client_address = client_address - self.server = server self.finished = False self.ssl_established = False - self.clientcert = None def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False, cipher_list=None): @@ -371,13 +368,13 @@ class TCPServer: self.port = self.server_address[1] self.socket.listen(self.request_queue_size) - def request_thread(self, request, client_address): + def connection_thread(self, connection, client_address): try: - self.handle_connection(request, client_address) - request.close() + self.handle_client_connection(connection, client_address) except: - self.handle_error(request, client_address) - request.close() + self.handle_error(connection, client_address) + finally: + connection.close() def serve_forever(self, poll_interval=0.1): self.__is_shut_down.clear() @@ -391,10 +388,10 @@ class TCPServer: else: raise if self.socket in r: - request, client_address = self.socket.accept() + connection, client_address = self.socket.accept() t = threading.Thread( - target = self.request_thread, - args = (request, client_address) + target = self.connection_thread, + args = (connection, client_address) ) t.setDaemon(1) t.start() @@ -410,7 +407,7 @@ class TCPServer: def handle_error(self, request, client_address, fp=sys.stderr): """ - Called when handle_connection raises an exception. + Called when handle_client_connection raises an exception. """ # If a thread has persisted after interpreter exit, the module might be # none. @@ -421,7 +418,7 @@ class TCPServer: print >> fp, exc print >> fp, '-'*40 - def handle_connection(self, request, client_address): # pragma: no cover + def handle_client_connection(self, conn, client_address): # pragma: no cover """ Called after client connection. """ diff --git a/netlib/test.py b/netlib/test.py index cd1a3847..0c36da6a 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -50,7 +50,7 @@ class TServer(tcp.TCPServer): self.handler_klass = handler_klass self.last_handler = None - def handle_connection(self, request, client_address): + def handle_client_connection(self, request, client_address): h = self.handler_klass(request, client_address, self) self.last_handler = h if self.ssl: -- cgit v1.2.3 From d0a6d2e2545089893d3789e3c787e269645df852 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 9 Jan 2014 05:33:21 +0100 Subject: fix tests, remove duplicate code --- netlib/tcp.py | 91 ++++++++++++++++++++++++---------------------------------- netlib/test.py | 2 +- 2 files changed, 38 insertions(+), 55 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index d35818bf..e48f4f6b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -138,8 +138,8 @@ class Reader(_FileLike): raise NetLibTimeout except socket.timeout: raise NetLibTimeout - except socket.error, v: - raise NetLibDisconnect(v[1]) + except socket.error: + raise NetLibDisconnect except SSL.SysCallError: raise NetLibDisconnect except SSL.Error, v: @@ -173,7 +173,40 @@ class Reader(_FileLike): return result -class TCPClient: +class SocketCloseMixin: + def finish(self): + self.finished = True + try: + if not getattr(self.wfile, "closed", False): + self.wfile.flush() + self.close() + self.wfile.close() + self.rfile.close() + except (socket.error, NetLibDisconnect): + # Remote has disconnected + pass + + def close(self): + """ + Does a hard close of the socket, i.e. a shutdown, followed by a close. + """ + try: + if self.ssl_established: + self.connection.shutdown() + self.connection.sock_shutdown(socket.SHUT_WR) + else: + self.connection.shutdown(socket.SHUT_WR) + #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. + #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html + while self.connection.recv(4096): + pass + self.connection.close() + except (socket.error, SSL.Error, IOError): + # Socket probably already closed + pass + + +class TCPClient(SocketCloseMixin): rbufsize = -1 wbufsize = -1 def __init__(self, host, port, source_address=None, use_ipv6=False): @@ -228,27 +261,8 @@ class TCPClient: def gettimeout(self): return self.connection.gettimeout() - def close(self): - """ - Does a hard close of the socket, i.e. a shutdown, followed by a close. - """ - try: - if self.ssl_established: - self.connection.shutdown() - self.connection.sock_shutdown(socket.SHUT_WR) - else: - self.connection.shutdown(socket.SHUT_WR) - #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent. - #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html - while self.connection.recv(4096): - pass - self.connection.close() - except (socket.error, SSL.Error, IOError): - # Socket probably already closed - pass - -class BaseHandler: +class BaseHandler(SocketCloseMixin): """ The instantiator is expected to call the handle() and finish() methods. @@ -315,43 +329,12 @@ class BaseHandler: self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) - def finish(self): - self.finished = True - try: - if not getattr(self.wfile, "closed", False): - self.wfile.flush() - self.close() - self.wfile.close() - self.rfile.close() - except (socket.error, NetLibDisconnect): - # Remote has disconnected - pass - def handle(self): # pragma: no cover raise NotImplementedError def settimeout(self, n): self.connection.settimeout(n) - def close(self): - """ - Does a hard close of the socket, i.e. a shutdown, followed by a close. - """ - try: - if self.ssl_established: - self.connection.shutdown() - self.connection.sock_shutdown(socket.SHUT_WR) - else: - self.connection.shutdown(socket.SHUT_WR) - # Section 4.2.2.13 of RFC 1122 tells us that a close() with any - # pending readable data could lead to an immediate RST being sent. - # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html - while self.connection.recv(4096): - pass - except (socket.error, SSL.Error): - # Socket probably already closed - pass - self.connection.close() class TCPServer: diff --git a/netlib/test.py b/netlib/test.py index 2209ebc3..f5599082 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -51,7 +51,7 @@ class TServer(tcp.TCPServer): self.last_handler = None def handle_client_connection(self, request, client_address): - h = self.handler_klass(request, client_address, self) + h = self.handler_klass(request) self.last_handler = h if self.ssl: cert = certutils.SSLCert.from_pem( -- cgit v1.2.3 From 763cb90b66b23cd94b6e37df3d4c7b8e7f89492a Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 28 Jan 2014 17:26:35 +0100 Subject: add tcp.Address to unify ipv4/ipv6 address handling --- netlib/certutils.py | 2 +- netlib/tcp.py | 56 +++++++++++++++++++++++++++++++++++++++-------------- netlib/test.py | 11 +++++------ test/test_http.py | 2 +- test/test_tcp.py | 36 +++++++++++++++++----------------- 5 files changed, 67 insertions(+), 40 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index 0349bec7..94294f6e 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -237,7 +237,7 @@ class SSLCert: def get_remote_cert(host, port, sni): - c = tcp.TCPClient(host, port) + c = tcp.TCPClient((host, port)) c.connect() c.convert_to_ssl(sni=sni) return c.cert diff --git a/netlib/tcp.py b/netlib/tcp.py index e48f4f6b..bad166d0 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -173,6 +173,35 @@ class Reader(_FileLike): return result +class Address(tuple): + """ + This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. + """ + def __new__(cls, address, use_ipv6=False): + a = super(Address, cls).__new__(cls, tuple(address)) + a.family = socket.AF_INET6 if use_ipv6 else socket.AF_INET + return a + + @classmethod + def wrap(cls, t): + if isinstance(t, cls): + return t + else: + return cls(t) + + @property + def host(self): + return self[0] + + @property + def port(self): + return self[1] + + @property + def is_ipv6(self): + return self.family == socket.AF_INET6 + + class SocketCloseMixin: def finish(self): self.finished = True @@ -209,10 +238,9 @@ class SocketCloseMixin: class TCPClient(SocketCloseMixin): rbufsize = -1 wbufsize = -1 - def __init__(self, host, port, source_address=None, use_ipv6=False): - self.host, self.port = host, port + def __init__(self, address, source_address=None): + self.address = Address.wrap(address) self.source_address = source_address - self.use_ipv6 = use_ipv6 self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.ssl_established = False @@ -245,14 +273,14 @@ class TCPClient(SocketCloseMixin): def connect(self): try: - connection = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM) + connection = socket.socket(self.address.family, socket.SOCK_STREAM) if self.source_address: connection.bind(self.source_address) - connection.connect((self.host, self.port)) + connection.connect(self.address) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError), err: - raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) + raise NetLibError('Error connecting to "%s": %s' % (self.address[0], err)) self.connection = connection def settimeout(self, n): @@ -269,8 +297,9 @@ class BaseHandler(SocketCloseMixin): """ rbufsize = -1 wbufsize = -1 - def __init__(self, connection): + def __init__(self, connection, address): self.connection = connection + self.address = Address.wrap(address) self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) @@ -339,19 +368,18 @@ class BaseHandler(SocketCloseMixin): class TCPServer: request_queue_size = 20 - def __init__(self, server_address, use_ipv6=False): - self.server_address = server_address - self.use_ipv6 = use_ipv6 + def __init__(self, address): + self.address = Address.wrap(address) self.__is_shut_down = threading.Event() self.__shutdown_request = False - self.socket = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM) + self.socket = socket.socket(self.address.family, socket.SOCK_STREAM) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.bind(self.server_address) - self.server_address = self.socket.getsockname() - self.port = self.server_address[1] + self.socket.bind(self.address) + self.address = Address.wrap(self.socket.getsockname()) self.socket.listen(self.request_queue_size) def connection_thread(self, connection, client_address): + client_address = Address(client_address) try: self.handle_client_connection(connection, client_address) except: diff --git a/netlib/test.py b/netlib/test.py index f5599082..565b97cd 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -17,19 +17,18 @@ class ServerTestBase: ssl = None handler = None addr = ("localhost", 0) - use_ipv6 = False @classmethod def setupAll(cls): cls.q = Queue.Queue() s = cls.makeserver() - cls.port = s.port + cls.port = s.address.port cls.server = ServerThread(s) cls.server.start() @classmethod def makeserver(cls): - return TServer(cls.ssl, cls.q, cls.handler, cls.addr, cls.use_ipv6) + return TServer(cls.ssl, cls.q, cls.handler, cls.addr) @classmethod def teardownAll(cls): @@ -41,17 +40,17 @@ class ServerTestBase: class TServer(tcp.TCPServer): - def __init__(self, ssl, q, handler_klass, addr, use_ipv6): + def __init__(self, ssl, q, handler_klass, addr): """ ssl: A {cert, key, v3_only} dict. """ - tcp.TCPServer.__init__(self, addr, use_ipv6=use_ipv6) + tcp.TCPServer.__init__(self, addr) self.ssl, self.q = ssl, q self.handler_klass = handler_klass self.last_handler = None def handle_client_connection(self, request, client_address): - h = self.handler_klass(request) + h = self.handler_klass(request, client_address) self.last_handler = h if self.ssl: cert = certutils.SSLCert.from_pem( diff --git a/test/test_http.py b/test/test_http.py index a0386115..e80e4b8f 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -223,7 +223,7 @@ class TestReadResponseNoContentLength(test.ServerTestBase): handler = NoContentLengthHTTPHandler def test_no_content_length(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() httpversion, code, msg, headers, content = http.read_response(c.rfile, "GET", None) assert content == "bar\r\n\r\n" diff --git a/test/test_tcp.py b/test/test_tcp.py index 7f2c21c4..49e20635 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -73,7 +73,7 @@ class TestServer(test.ServerTestBase): handler = EchoHandler def test_echo(self): testval = "echo!\n" - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.wfile.write(testval) c.wfile.flush() @@ -88,7 +88,7 @@ class TestServerBind(test.ServerTestBase): for i in range(20): random_port = random.randrange(1024, 65535) try: - c = tcp.TCPClient("127.0.0.1", self.port, source_address=("127.0.0.1", random_port)) + c = tcp.TCPClient(("127.0.0.1", self.port), source_address=("127.0.0.1", random_port)) c.connect() assert c.rfile.readline() == str(("127.0.0.1", random_port)) return @@ -98,11 +98,11 @@ class TestServerBind(test.ServerTestBase): class TestServerIPv6(test.ServerTestBase): handler = EchoHandler - use_ipv6 = True + addr = tcp.Address(("localhost", 0), use_ipv6=True) def test_echo(self): testval = "echo!\n" - c = tcp.TCPClient("::1", self.port, use_ipv6=True) + c = tcp.TCPClient(tcp.Address(("::1", self.port), use_ipv6=True)) c.connect() c.wfile.write(testval) c.wfile.flush() @@ -127,7 +127,7 @@ class TestFinishFail(test.ServerTestBase): handler = FinishFailHandler def test_disconnect_in_finish(self): testval = "echo!\n" - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.wfile.write("foo\n") c.wfile.flush() @@ -137,7 +137,7 @@ class TestDisconnect(test.ServerTestBase): handler = EchoHandler def test_echo(self): testval = "echo!\n" - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.wfile.write(testval) c.wfile.flush() @@ -153,7 +153,7 @@ class TestServerSSL(test.ServerTestBase): v3_only = False ) def test_echo(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl(sni="foo.com", options=tcp.OP_ALL) testval = "echo!\n" @@ -174,7 +174,7 @@ class TestSSLv3Only(test.ServerTestBase): v3_only = True ) def test_failure(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com", method=tcp.TLSv1_METHOD) @@ -188,13 +188,13 @@ class TestSSLClientCert(test.ServerTestBase): v3_only = False ) def test_clientcert(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl(cert=tutils.test_data.path("data/clientcert/client.pem")) assert c.rfile.readline().strip() == "1" def test_clientcert_err(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() tutils.raises( tcp.NetLibError, @@ -212,7 +212,7 @@ class TestSNI(test.ServerTestBase): v3_only = False ) def test_echo(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl(sni="foo.com") assert c.rfile.readline() == "foo.com" @@ -228,7 +228,7 @@ class TestClientCipherList(test.ServerTestBase): cipher_list = 'RC4-SHA' ) def test_echo(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl(sni="foo.com") assert c.rfile.readline() == "['RC4-SHA']" @@ -243,7 +243,7 @@ class TestSSLDisconnect(test.ServerTestBase): v3_only = False ) def test_echo(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() # Excercise SSL.ZeroReturnError @@ -255,7 +255,7 @@ class TestSSLDisconnect(test.ServerTestBase): class TestDisconnect(test.ServerTestBase): def test_echo(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.rfile.read(10) c.wfile.write("foo") @@ -266,7 +266,7 @@ class TestDisconnect(test.ServerTestBase): class TestServerTimeOut(test.ServerTestBase): handler = TimeoutHandler def test_timeout(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() time.sleep(0.3) assert self.last_handler.timeout @@ -275,7 +275,7 @@ class TestServerTimeOut(test.ServerTestBase): class TestTimeOut(test.ServerTestBase): handler = HangHandler def test_timeout(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.settimeout(0.1) assert c.gettimeout() == 0.1 @@ -291,7 +291,7 @@ class TestSSLTimeOut(test.ServerTestBase): v3_only = False ) def test_timeout_client(self): - c = tcp.TCPClient("127.0.0.1", self.port) + c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() c.settimeout(0.1) @@ -300,7 +300,7 @@ class TestSSLTimeOut(test.ServerTestBase): class TestTCPClient: def test_conerr(self): - c = tcp.TCPClient("127.0.0.1", 0) + c = tcp.TCPClient(("127.0.0.1", 0)) tutils.raises(tcp.NetLibError, c.connect) -- cgit v1.2.3 From e18ac4b672e8645388dc8057801092ce417f1511 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 28 Jan 2014 20:30:16 +0100 Subject: re-add server attribute to BaseHandler --- netlib/tcp.py | 4 +++- netlib/test.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index bad166d0..729e513e 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -297,9 +297,11 @@ class BaseHandler(SocketCloseMixin): """ rbufsize = -1 wbufsize = -1 - def __init__(self, connection, address): + + def __init__(self, connection, address, server): self.connection = connection self.address = Address.wrap(address) + self.server = server self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) diff --git a/netlib/test.py b/netlib/test.py index 565b97cd..2f6a7107 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -50,7 +50,7 @@ class TServer(tcp.TCPServer): self.last_handler = None def handle_client_connection(self, request, client_address): - h = self.handler_klass(request, client_address) + h = self.handler_klass(request, client_address, self) self.last_handler = h if self.ssl: cert = certutils.SSLCert.from_pem( -- cgit v1.2.3 From ff9656be80192ac837cf98997f9fe6c00c9c5a32 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 30 Jan 2014 20:07:30 +0100 Subject: remove subclassing of tuple in tcp.Address, move StateObject into netlib --- netlib/certutils.py | 12 +++++++- netlib/odict.py | 7 ++++- netlib/stateobject.py | 80 +++++++++++++++++++++++++++++++++++++++++++++++++++ netlib/tcp.py | 45 ++++++++++++++++++++--------- 4 files changed, 128 insertions(+), 16 deletions(-) create mode 100644 netlib/stateobject.py diff --git a/netlib/certutils.py b/netlib/certutils.py index 94294f6e..139203b9 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -3,6 +3,7 @@ from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL +from netlib.stateobject import StateObject import tcp default_exp = 62208000 # =24 * 60 * 60 * 720 @@ -152,13 +153,22 @@ class _GeneralNames(univ.SequenceOf): sizeSpec = univ.SequenceOf.sizeSpec + constraint.ValueSizeConstraint(1, 1024) -class SSLCert: +class SSLCert(StateObject): def __init__(self, cert): """ Returns a (common name, [subject alternative names]) tuple. """ self.x509 = cert + def _get_state(self): + return self.to_pem() + + def _load_state(self, state): + self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state) + + def _from_state(cls, state): + return cls.from_pem(state) + @classmethod def from_pem(klass, txt): x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) diff --git a/netlib/odict.py b/netlib/odict.py index 0759a5bf..8e195afc 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,4 +1,6 @@ import re, copy +from netlib.stateobject import StateObject + def safe_subn(pattern, repl, target, *args, **kwargs): """ @@ -9,7 +11,7 @@ def safe_subn(pattern, repl, target, *args, **kwargs): return re.subn(str(pattern), str(repl), target, *args, **kwargs) -class ODict: +class ODict(StateObject): """ A dictionary-like object for managing ordered (key, value) data. """ @@ -98,6 +100,9 @@ class ODict: def _get_state(self): return [tuple(i) for i in self.lst] + def _load_state(self, state): + self.list = [list(i) for i in state] + @classmethod def _from_state(klass, state): return klass([list(i) for i in state]) diff --git a/netlib/stateobject.py b/netlib/stateobject.py new file mode 100644 index 00000000..c2ef2cd4 --- /dev/null +++ b/netlib/stateobject.py @@ -0,0 +1,80 @@ +from types import ClassType + + +class StateObject: + def _get_state(self): + raise NotImplementedError + + def _load_state(self, state): + raise NotImplementedError + + @classmethod + def _from_state(cls, state): + raise NotImplementedError + + def __eq__(self, other): + try: + return self._get_state() == other._get_state() + except AttributeError: # we may compare with something that's not a StateObject + return False + + +class SimpleStateObject(StateObject): + """ + A StateObject with opionated conventions that tries to keep everything DRY. + + Simply put, you agree on a list of attributes and their type. + Attributes can either be primitive types(str, tuple, bool, ...) or StateObject instances themselves. + SimpleStateObject uses this information for the default _get_state(), _from_state(s) and _load_state(s) methods. + Overriding _get_state or _load_state to add custom adjustments is always possible. + """ + + _stateobject_attributes = None # none by default to raise an exception if definition was forgotten + """ + An attribute-name -> class-or-type dict containing all attributes that should be serialized + If the attribute is a class, this class must be a subclass of StateObject. + """ + + def _get_state(self): + return {attr: self.__get_state_attr(attr, cls) + for attr, cls in self._stateobject_attributes.iteritems()} + + def __get_state_attr(self, attr, cls): + """ + helper for _get_state. + returns the value of the given attribute + """ + if getattr(self, attr) is None: + return None + if isinstance(cls, ClassType): + return getattr(self, attr)._get_state() + else: + return getattr(self, attr) + + def _load_state(self, state): + for attr, cls in self._stateobject_attributes.iteritems(): + self.__load_state_attr(attr, cls, state) + + def __load_state_attr(self, attr, cls, state): + """ + helper for _load_state. + loads the given attribute from the state. + """ + if state[attr] is not None: # First, catch None as value. + if isinstance(cls, ClassType): # Is the attribute a StateObject itself? + assert issubclass(cls, StateObject) + curr = getattr(self, attr) + if curr: # if the attribute is already present, delegate to the objects ._load_state method. + curr._load_state(state[attr]) + else: # otherwise, create a new object. + setattr(self, attr, cls._from_state(state[attr])) + else: + setattr(self, attr, cls(state[attr])) + else: + setattr(self, attr, None) + + @classmethod + def _from_state(cls, state): + f = cls() # the default implementation assumes an empty constructor. Override accordingly. + f._load_state(state) + return f \ No newline at end of file diff --git a/netlib/tcp.py b/netlib/tcp.py index 729e513e..c26d1191 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,6 +1,7 @@ import select, socket, threading, sys, time, traceback from OpenSSL import SSL import certutils +from netlib.stateobject import StateObject SSLv2_METHOD = SSL.SSLv2_METHOD SSLv3_METHOD = SSL.SSLv3_METHOD @@ -173,14 +174,13 @@ class Reader(_FileLike): return result -class Address(tuple): +class Address(StateObject): """ This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. """ - def __new__(cls, address, use_ipv6=False): - a = super(Address, cls).__new__(cls, tuple(address)) - a.family = socket.AF_INET6 if use_ipv6 else socket.AF_INET - return a + def __init__(self, address, use_ipv6=False): + self.address = address + self.family = socket.AF_INET6 if use_ipv6 else socket.AF_INET @classmethod def wrap(cls, t): @@ -189,18 +189,35 @@ class Address(tuple): else: return cls(t) + def __call__(self): + return self.address + @property def host(self): - return self[0] + return self.address[0] @property def port(self): - return self[1] + return self.address[1] @property - def is_ipv6(self): + def use_ipv6(self): return self.family == socket.AF_INET6 + def _load_state(self, state): + self.address = state["address"] + self.family = socket.AF_INET6 if state["use_ipv6"] else socket.AF_INET + + def _get_state(self): + return dict( + address=self.address, + use_ipv6=self.use_ipv6 + ) + + @classmethod + def _from_state(cls, state): + return cls(**state) + class SocketCloseMixin: def finish(self): @@ -240,7 +257,7 @@ class TCPClient(SocketCloseMixin): wbufsize = -1 def __init__(self, address, source_address=None): self.address = Address.wrap(address) - self.source_address = source_address + self.source_address = Address.wrap(source_address) if source_address else None self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.ssl_established = False @@ -275,12 +292,12 @@ class TCPClient(SocketCloseMixin): try: connection = socket.socket(self.address.family, socket.SOCK_STREAM) if self.source_address: - connection.bind(self.source_address) - connection.connect(self.address) + connection.bind(self.source_address()) + connection.connect(self.address()) self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError), err: - raise NetLibError('Error connecting to "%s": %s' % (self.address[0], err)) + raise NetLibError('Error connecting to "%s": %s' % (self.address.host, err)) self.connection = connection def settimeout(self, n): @@ -376,7 +393,7 @@ class TCPServer: self.__shutdown_request = False self.socket = socket.socket(self.address.family, socket.SOCK_STREAM) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.bind(self.address) + self.socket.bind(self.address()) self.address = Address.wrap(self.socket.getsockname()) self.socket.listen(self.request_queue_size) @@ -427,7 +444,7 @@ class TCPServer: if traceback: exc = traceback.format_exc() print >> fp, '-'*40 - print >> fp, "Error in processing of request from %s:%s"%client_address + print >> fp, "Error in processing of request from %s:%s" % (client_address.host, client_address.port) print >> fp, exc print >> fp, '-'*40 -- cgit v1.2.3 From dc45b4bf19bff5edc0b72ccb68fad04d479aff83 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 31 Jan 2014 01:06:53 +0100 Subject: move StateObject back into libmproxy --- netlib/certutils.py | 12 +------- netlib/odict.py | 3 +- netlib/stateobject.py | 80 --------------------------------------------------- netlib/tcp.py | 21 ++++---------- 4 files changed, 7 insertions(+), 109 deletions(-) delete mode 100644 netlib/stateobject.py diff --git a/netlib/certutils.py b/netlib/certutils.py index 139203b9..94294f6e 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -3,7 +3,6 @@ from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL -from netlib.stateobject import StateObject import tcp default_exp = 62208000 # =24 * 60 * 60 * 720 @@ -153,22 +152,13 @@ class _GeneralNames(univ.SequenceOf): sizeSpec = univ.SequenceOf.sizeSpec + constraint.ValueSizeConstraint(1, 1024) -class SSLCert(StateObject): +class SSLCert: def __init__(self, cert): """ Returns a (common name, [subject alternative names]) tuple. """ self.x509 = cert - def _get_state(self): - return self.to_pem() - - def _load_state(self, state): - self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state) - - def _from_state(cls, state): - return cls.from_pem(state) - @classmethod def from_pem(klass, txt): x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) diff --git a/netlib/odict.py b/netlib/odict.py index 8e195afc..46b74e8e 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,5 +1,4 @@ import re, copy -from netlib.stateobject import StateObject def safe_subn(pattern, repl, target, *args, **kwargs): @@ -11,7 +10,7 @@ def safe_subn(pattern, repl, target, *args, **kwargs): return re.subn(str(pattern), str(repl), target, *args, **kwargs) -class ODict(StateObject): +class ODict: """ A dictionary-like object for managing ordered (key, value) data. """ diff --git a/netlib/stateobject.py b/netlib/stateobject.py deleted file mode 100644 index c2ef2cd4..00000000 --- a/netlib/stateobject.py +++ /dev/null @@ -1,80 +0,0 @@ -from types import ClassType - - -class StateObject: - def _get_state(self): - raise NotImplementedError - - def _load_state(self, state): - raise NotImplementedError - - @classmethod - def _from_state(cls, state): - raise NotImplementedError - - def __eq__(self, other): - try: - return self._get_state() == other._get_state() - except AttributeError: # we may compare with something that's not a StateObject - return False - - -class SimpleStateObject(StateObject): - """ - A StateObject with opionated conventions that tries to keep everything DRY. - - Simply put, you agree on a list of attributes and their type. - Attributes can either be primitive types(str, tuple, bool, ...) or StateObject instances themselves. - SimpleStateObject uses this information for the default _get_state(), _from_state(s) and _load_state(s) methods. - Overriding _get_state or _load_state to add custom adjustments is always possible. - """ - - _stateobject_attributes = None # none by default to raise an exception if definition was forgotten - """ - An attribute-name -> class-or-type dict containing all attributes that should be serialized - If the attribute is a class, this class must be a subclass of StateObject. - """ - - def _get_state(self): - return {attr: self.__get_state_attr(attr, cls) - for attr, cls in self._stateobject_attributes.iteritems()} - - def __get_state_attr(self, attr, cls): - """ - helper for _get_state. - returns the value of the given attribute - """ - if getattr(self, attr) is None: - return None - if isinstance(cls, ClassType): - return getattr(self, attr)._get_state() - else: - return getattr(self, attr) - - def _load_state(self, state): - for attr, cls in self._stateobject_attributes.iteritems(): - self.__load_state_attr(attr, cls, state) - - def __load_state_attr(self, attr, cls, state): - """ - helper for _load_state. - loads the given attribute from the state. - """ - if state[attr] is not None: # First, catch None as value. - if isinstance(cls, ClassType): # Is the attribute a StateObject itself? - assert issubclass(cls, StateObject) - curr = getattr(self, attr) - if curr: # if the attribute is already present, delegate to the objects ._load_state method. - curr._load_state(state[attr]) - else: # otherwise, create a new object. - setattr(self, attr, cls._from_state(state[attr])) - else: - setattr(self, attr, cls(state[attr])) - else: - setattr(self, attr, None) - - @classmethod - def _from_state(cls, state): - f = cls() # the default implementation assumes an empty constructor. Override accordingly. - f._load_state(state) - return f \ No newline at end of file diff --git a/netlib/tcp.py b/netlib/tcp.py index c26d1191..346bc053 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,7 +1,6 @@ import select, socket, threading, sys, time, traceback from OpenSSL import SSL import certutils -from netlib.stateobject import StateObject SSLv2_METHOD = SSL.SSLv2_METHOD SSLv3_METHOD = SSL.SSLv3_METHOD @@ -174,13 +173,13 @@ class Reader(_FileLike): return result -class Address(StateObject): +class Address(object): """ This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. """ def __init__(self, address, use_ipv6=False): self.address = address - self.family = socket.AF_INET6 if use_ipv6 else socket.AF_INET + self.use_ipv6 = use_ipv6 @classmethod def wrap(cls, t): @@ -204,19 +203,9 @@ class Address(StateObject): def use_ipv6(self): return self.family == socket.AF_INET6 - def _load_state(self, state): - self.address = state["address"] - self.family = socket.AF_INET6 if state["use_ipv6"] else socket.AF_INET - - def _get_state(self): - return dict( - address=self.address, - use_ipv6=self.use_ipv6 - ) - - @classmethod - def _from_state(cls, state): - return cls(**state) + @use_ipv6.setter + def use_ipv6(self, b): + self.family = socket.AF_INET6 if b else socket.AF_INET class SocketCloseMixin: -- cgit v1.2.3 From 0bbc40dc33dd7bd3729e639874882dd6dd7ea818 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 4 Feb 2014 04:51:41 +0100 Subject: store used sni in TCPClient, add equality check for tcp.Address --- netlib/tcp.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 346bc053..94ea8806 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -207,8 +207,12 @@ class Address(object): def use_ipv6(self, b): self.family = socket.AF_INET6 if b else socket.AF_INET + def __eq__(self, other): + other = Address.wrap(other) + return (self.address, self.family) == (other.address, other.family) -class SocketCloseMixin: + +class SocketCloseMixin(object): def finish(self): self.finished = True try: @@ -250,6 +254,7 @@ class TCPClient(SocketCloseMixin): self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.ssl_established = False + self.sni = None def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None): """ @@ -267,6 +272,7 @@ class TCPClient(SocketCloseMixin): self.connection = SSL.Connection(context, self.connection) self.ssl_established = True if sni: + self.sni = sni self.connection.set_tlsext_host_name(sni) self.connection.set_connect_state() try: -- cgit v1.2.3 From 7fc544bc7ff8fd610ba9db92c0d3b59a0b040b5b Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 5 Feb 2014 21:34:14 +0100 Subject: adjust netlib.wsgi to reflect changes in mitmproxys flow format --- netlib/tcp.py | 2 +- netlib/wsgi.py | 15 ++++++++++----- test/test_tcp.py | 1 + 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 94ea8806..34e47999 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -178,7 +178,7 @@ class Address(object): This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. """ def __init__(self, address, use_ipv6=False): - self.address = address + self.address = tuple(address) self.use_ipv6 = use_ipv6 @classmethod diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 647cb899..b576bdff 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,17 +1,22 @@ import cStringIO, urllib, time, traceback -import odict +import odict, tcp class ClientConn: def __init__(self, address): - self.address = address + self.address = tcp.Address.wrap(address) + + +class Flow: + def __init__(self, client_conn): + self.client_conn = client_conn class Request: def __init__(self, client_conn, scheme, method, path, headers, content): self.scheme, self.method, self.path = scheme, method, path self.headers, self.content = headers, content - self.client_conn = client_conn + self.flow = Flow(client_conn) def date_time_string(): @@ -60,8 +65,8 @@ class WSGIAdaptor: 'SERVER_PROTOCOL': "HTTP/1.1", } environ.update(extra) - if request.client_conn.address: - environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.client_conn.address + if request.flow.client_conn.address: + environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.flow.client_conn.address() for key, value in request.headers.items(): key = 'HTTP_' + key.upper().replace('-', '_') diff --git a/test/test_tcp.py b/test/test_tcp.py index 49e20635..525961d5 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -215,6 +215,7 @@ class TestSNI(test.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl(sni="foo.com") + assert c.sni == "foo.com" assert c.rfile.readline() == "foo.com" -- cgit v1.2.3