From 4e53f1ee908949c0dcafd822bf05f9523e00d189 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 16 Jun 2012 13:38:10 +1200 Subject: Rename our tcpserver to netlib, expand to include client network functions. --- libmproxy/netlib.py | 173 +++++++++++++++++++++++++++++++++++++++++++++++++ libmproxy/proxy.py | 120 ++++++++-------------------------- libmproxy/tcpserver.py | 88 ------------------------- test/test_netlib.py | 15 +++++ test/test_proxy.py | 11 ---- 5 files changed, 216 insertions(+), 191 deletions(-) create mode 100644 libmproxy/netlib.py delete mode 100644 libmproxy/tcpserver.py create mode 100644 test/test_netlib.py diff --git a/libmproxy/netlib.py b/libmproxy/netlib.py new file mode 100644 index 00000000..65dbee63 --- /dev/null +++ b/libmproxy/netlib.py @@ -0,0 +1,173 @@ +import select, socket, threading +from OpenSSL import SSL + + +class NetLibError(Exception): pass + + +class FileLike: + def __init__(self, o): + self.o = o + + def __getattr__(self, attr): + return getattr(self.o, attr) + + def flush(self): + pass + + def read(self, length): + result = '' + while len(result) < length: + try: + data = self.o.read(length) + except AttributeError: + break + except SSL.ZeroReturnError: + break + if not data: + break + result += data + return result + + def write(self, v): + self.o.sendall(v) + + def readline(self, size = None): + result = '' + bytes_read = 0 + while True: + if size is not None and bytes_read >= size: + break + ch = self.read(1) + bytes_read += 1 + if not ch: + break + else: + result += ch + if ch == '\n': + break + return result + + +class TCPClient: + def __init__(self, ssl, host, port, clientcert): + self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert + self.sock, self.rfile, self.wfile = None, None, None + self.cert = None + self.connect() + + def connect(self): + try: + addr = socket.gethostbyname(self.host) + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if self.ssl: + context = SSL.Context(SSL.SSLv23_METHOD) + if self.clientcert: + context.use_certificate_file(self.clientcert) + server = SSL.Connection(context, server) + server.connect((addr, self.port)) + if self.ssl: + self.cert = server.get_peer_certificate() + self.rfile, self.wfile = FileLike(server), FileLike(server) + else: + self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb') + except socket.error, err: + raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) + self.sock = server + + +class BaseHandler: + rbufsize = -1 + wbufsize = 0 + def __init__(self, connection, client_address, server): + self.connection = connection + self.rfile = self.connection.makefile('rb', self.rbufsize) + self.wfile = self.connection.makefile('wb', self.wbufsize) + + self.client_address = client_address + self.server = server + self.handle() + self.finish() + + def convert_to_ssl(self, cert, key): + ctx = SSL.Context(SSL.SSLv23_METHOD) + ctx.use_privatekey_file(key) + ctx.use_certificate_file(cert) + self.connection = SSL.Connection(ctx, self.connection) + self.connection.set_accept_state() + self.rfile = FileLike(self.connection) + self.wfile = FileLike(self.connection) + + def finish(self): + try: + if not getattr(self.wfile, "closed", False): + self.wfile.flush() + self.connection.close() + self.wfile.close() + self.rfile.close() + except IOError: + pass + + def handle(self): + raise NotImplementedError + + +class TCPServer: + request_queue_size = 20 + def __init__(self, server_address): + self.server_address = server_address + self.__is_shut_down = threading.Event() + self.__shutdown_request = False + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.socket.bind(self.server_address) + self.server_address = self.socket.getsockname() + self.socket.listen(self.request_queue_size) + + def fileno(self): + return self.socket.fileno() + + def request_thread(self, request, client_address): + try: + self.handle_connection(request, client_address) + request.close() + except: + self.handle_error(request, client_address) + request.close() + + def serve_forever(self, poll_interval=0.5): + self.__is_shut_down.clear() + try: + while not self.__shutdown_request: + r, w, e = select.select([self], [], [], poll_interval) + if self in r: + try: + request, client_address = self.socket.accept() + except socket.error: + return + try: + t = threading.Thread(target = self.request_thread, + args = (request, client_address)) + t.setDaemon (1) + t.start() + except: + self.handle_error(request, client_address) + request.close() + finally: + self.__shutdown_request = False + self.__is_shut_down.set() + + def shutdown(self): + self.__shutdown_request = True + self.__is_shut_down.wait() + + def handle_error(self, request, client_address): + print '-'*40 + print 'Exception happened during processing of request from', + print client_address + import traceback + traceback.print_exc() # XXX But this goes to stderr! + print '-'*40 + + def handle_connection(self, request, client_address): + raise NotImplementedError diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 89493e79..9febba72 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -21,7 +21,7 @@ import sys, os, string, socket, time import shutil, tempfile, threading import optparse, SocketServer -import utils, flow, certutils, version, wsgi, tcpserver +import utils, flow, certutils, version, wsgi, netlib from OpenSSL import SSL @@ -232,50 +232,6 @@ def read_http_body_request(rfile, wfile, headers, httpversion, limit): return read_http_body(rfile, headers, False, limit) -class FileLike: - def __init__(self, o): - self.o = o - - def __getattr__(self, attr): - return getattr(self.o, attr) - - def flush(self): - pass - - def read(self, length): - result = '' - while len(result) < length: - try: - data = self.o.read(length) - except AttributeError: - break - except SSL.ZeroReturnError: - break - if not data: - break - result += data - return result - - def write(self, v): - self.o.sendall(v) - - def readline(self, size = None): - result = '' - bytes_read = 0 - while True: - if size is not None and bytes_read >= size: - break - ch = self.read(1) - bytes_read += 1 - if not ch: - break - else: - result += ch - if ch == '\n': - break - return result - - class RequestReplayThread(threading.Thread): def __init__(self, config, flow, masterq): self.config, self.flow, self.masterq = config, flow, masterq @@ -291,41 +247,27 @@ class RequestReplayThread(threading.Thread): except ProxyError, v: err = flow.Error(self.flow.request, v.msg) err._send(self.masterq) + except netlib.NetLibError, v: + raise ProxyError(502, v) -class ServerConnection: +class ServerConnection(netlib.TCPClient): def __init__(self, config, scheme, host, port): - self.config, self.scheme, self.host, self.port = config, scheme, host, port - self.cert = None - self.sock, self.rfile, self.wfile = None, None, None - self.connect() + clientcert = None + if config.clientcerts: + path = os.path.join(config.clientcerts, self.host) + ".pem" + if os.path.exists(clientcert): + clientcert = path + netlib.TCPClient.__init__( + self, + True if scheme == "https" else False, + host, + port, + clientcert + ) + self.config, self.scheme = config, scheme self.requestcount = 0 - def connect(self): - try: - addr = socket.gethostbyname(self.host) - server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - if self.scheme == "https": - if self.config.clientcerts: - clientcert = os.path.join(self.config.clientcerts, self.host) + ".pem" - if not os.path.exists(clientcert): - clientcert = None - else: - clientcert = None - context = SSL.Context(SSL.SSLv23_METHOD) - if clientcert: - context.use_certificate_file(clientcert) - server = SSL.Connection(context, server) - server.connect((addr, self.port)) - if self.scheme == "https": - self.cert = server.get_peer_certificate() - self.rfile, self.wfile = FileLike(server), FileLike(server) - else: - self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb') - except socket.error, err: - raise ProxyError(502, 'Error connecting to "%s": %s' % (self.host, err)) - self.sock = server - def send(self, request): self.requestcount += 1 try: @@ -374,13 +316,13 @@ class ServerConnection: pass -class ProxyHandler(tcpserver.BaseHandler): +class ProxyHandler(netlib.BaseHandler): def __init__(self, config, connection, client_address, server, q): self.mqueue = q self.config = config self.server_conn = None self.proxy_connect_state = None - tcpserver.BaseHandler.__init__(self, connection, client_address, server) + netlib.BaseHandler.__init__(self, connection, client_address, server) def handle(self): cc = flow.ClientConnect(self.client_address) @@ -397,7 +339,10 @@ class ProxyHandler(tcpserver.BaseHandler): sc.terminate() self.server_conn = None if not self.server_conn: - self.server_conn = ServerConnection(self.config, scheme, host, port) + try: + self.server_conn = ServerConnection(self.config, scheme, host, port) + except netlib.NetLibError, v: + raise ProxyError(502, v) def handle_request(self, cc): try: @@ -473,15 +418,6 @@ class ProxyHandler(tcpserver.BaseHandler): raise ProxyError(502, "mitmproxy: Unable to generate dummy cert.") return ret - def convert_to_ssl(self, cert): - ctx = SSL.Context(SSL.SSLv23_METHOD) - ctx.use_privatekey_file(self.config.certfile or self.config.cacert) - ctx.use_certificate_file(cert) - self.connection = SSL.Connection(ctx, self.connection) - self.connection.set_accept_state() - self.rfile = FileLike(self.connection) - self.wfile = FileLike(self.connection) - def read_request(self, client_conn): line = self.rfile.readline() if line == "\r\n" or line == "\n": # Possible leftover from previous message @@ -494,7 +430,7 @@ class ProxyHandler(tcpserver.BaseHandler): if port in self.config.transparent_proxy["sslports"]: scheme = "https" certfile = self.find_cert(host, port) - self.convert_to_ssl(certfile) + self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert) else: scheme = "http" method, path, httpversion = parse_init_http(line) @@ -527,7 +463,7 @@ class ProxyHandler(tcpserver.BaseHandler): ) self.wfile.flush() certfile = self.find_cert(host, port) - self.convert_to_ssl(certfile) + self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert) self.proxy_connect_state = (host, port, httpversion) line = self.rfile.readline(line) if self.proxy_connect_state: @@ -572,7 +508,7 @@ class ProxyHandler(tcpserver.BaseHandler): class ProxyServerError(Exception): pass -class ProxyServer(tcpserver.TCPServer): +class ProxyServer(netlib.TCPServer): allow_reuse_address = True bound = True def __init__(self, config, port, address=''): @@ -581,7 +517,7 @@ class ProxyServer(tcpserver.TCPServer): """ self.config, self.port, self.address = config, port, address try: - tcpserver.TCPServer.__init__(self, (address, port)) + netlib.TCPServer.__init__(self, (address, port)) except socket.error, v: raise ProxyServerError('Error starting proxy server: ' + v.strerror) self.masterq = None @@ -600,7 +536,7 @@ class ProxyServer(tcpserver.TCPServer): ProxyHandler(self.config, request, client_address, self, self.masterq) def shutdown(self): - tcpserver.TCPServer.shutdown(self) + netlib.TCPServer.shutdown(self) try: shutil.rmtree(self.certdir) except OSError: diff --git a/libmproxy/tcpserver.py b/libmproxy/tcpserver.py deleted file mode 100644 index bf7ed0b4..00000000 --- a/libmproxy/tcpserver.py +++ /dev/null @@ -1,88 +0,0 @@ -import select, socket, threading - -class BaseHandler: - rbufsize = -1 - wbufsize = 0 - def __init__(self, connection, client_address, server): - self.connection = connection - self.rfile = self.connection.makefile('rb', self.rbufsize) - self.wfile = self.connection.makefile('wb', self.wbufsize) - - self.client_address = client_address - self.server = server - self.handle() - self.finish() - - def finish(self): - try: - if not getattr(self.wfile, "closed", False): - self.wfile.flush() - self.connection.close() - self.wfile.close() - self.rfile.close() - except IOError: - pass - - def handle(self): - raise NotImplementedError - - -class TCPServer: - request_queue_size = 20 - def __init__(self, server_address): - self.server_address = server_address - self.__is_shut_down = threading.Event() - self.__shutdown_request = False - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.bind(self.server_address) - self.server_address = self.socket.getsockname() - self.socket.listen(self.request_queue_size) - - def fileno(self): - return self.socket.fileno() - - def request_thread(self, request, client_address): - try: - self.handle_connection(request, client_address) - request.close() - except: - self.handle_error(request, client_address) - request.close() - - def serve_forever(self, poll_interval=0.5): - self.__is_shut_down.clear() - try: - while not self.__shutdown_request: - r, w, e = select.select([self], [], [], poll_interval) - if self in r: - try: - request, client_address = self.socket.accept() - except socket.error: - return - try: - t = threading.Thread(target = self.request_thread, - args = (request, client_address)) - t.setDaemon (1) - t.start() - except: - self.handle_error(request, client_address) - request.close() - finally: - self.__shutdown_request = False - self.__is_shut_down.set() - - def shutdown(self): - self.__shutdown_request = True - self.__is_shut_down.wait() - - def handle_error(self, request, client_address): - print '-'*40 - print 'Exception happened during processing of request from', - print client_address - import traceback - traceback.print_exc() # XXX But this goes to stderr! - print '-'*40 - - def handle_connection(self, request, client_address): - raise NotImplementedError diff --git a/test/test_netlib.py b/test/test_netlib.py new file mode 100644 index 00000000..2b76c9cf --- /dev/null +++ b/test/test_netlib.py @@ -0,0 +1,15 @@ +import cStringIO +from libmproxy import netlib + + +class TestFileLike: + def test_wrap(self): + s = cStringIO.StringIO("foobar\nfoobar") + s = netlib.FileLike(s) + s.flush() + assert s.readline() == "foobar\n" + assert s.readline() == "foobar" + # Test __getattr__ + assert s.isatty + + diff --git a/test/test_proxy.py b/test/test_proxy.py index 9d7239dd..5fab282c 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -60,17 +60,6 @@ def test_read_http_body(): assert len(proxy.read_http_body(s, h, True, 100)) == 7 -class TestFileLike: - def test_wrap(self): - s = cStringIO.StringIO("foobar\nfoobar") - s = proxy.FileLike(s) - s.flush() - assert s.readline() == "foobar\n" - assert s.readline() == "foobar" - # Test __getattr__ - assert s.isatty - - class TestProxyError: def test_simple(self): p = proxy.ProxyError(111, "msg") -- cgit v1.2.3