aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAldo Cortesi <aldo@nullcube.com>2012-06-16 13:38:10 +1200
committerAldo Cortesi <aldo@nullcube.com>2012-06-16 13:38:10 +1200
commit4e53f1ee908949c0dcafd822bf05f9523e00d189 (patch)
tree856844675103ae220209be8911dd8d1cb7358b53
parent8ae64337ed64b0dc85aeba92ed23d038466ff6f7 (diff)
downloadmitmproxy-4e53f1ee908949c0dcafd822bf05f9523e00d189.tar.gz
mitmproxy-4e53f1ee908949c0dcafd822bf05f9523e00d189.tar.bz2
mitmproxy-4e53f1ee908949c0dcafd822bf05f9523e00d189.zip
Rename our tcpserver to netlib, expand to include client network functions.
-rw-r--r--libmproxy/netlib.py (renamed from libmproxy/tcpserver.py)85
-rw-r--r--libmproxy/proxy.py120
-rw-r--r--test/test_netlib.py15
-rw-r--r--test/test_proxy.py11
4 files changed, 128 insertions, 103 deletions
diff --git a/libmproxy/tcpserver.py b/libmproxy/netlib.py
index bf7ed0b4..65dbee63 100644
--- a/libmproxy/tcpserver.py
+++ b/libmproxy/netlib.py
@@ -1,4 +1,80 @@
import select, socket, threading
+from OpenSSL import SSL
+
+
+class NetLibError(Exception): pass
+
+
+class FileLike:
+ def __init__(self, o):
+ self.o = o
+
+ def __getattr__(self, attr):
+ return getattr(self.o, attr)
+
+ def flush(self):
+ pass
+
+ def read(self, length):
+ result = ''
+ while len(result) < length:
+ try:
+ data = self.o.read(length)
+ except AttributeError:
+ break
+ except SSL.ZeroReturnError:
+ break
+ if not data:
+ break
+ result += data
+ return result
+
+ def write(self, v):
+ self.o.sendall(v)
+
+ def readline(self, size = None):
+ result = ''
+ bytes_read = 0
+ while True:
+ if size is not None and bytes_read >= size:
+ break
+ ch = self.read(1)
+ bytes_read += 1
+ if not ch:
+ break
+ else:
+ result += ch
+ if ch == '\n':
+ break
+ return result
+
+
+class TCPClient:
+ def __init__(self, ssl, host, port, clientcert):
+ self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert
+ self.sock, self.rfile, self.wfile = None, None, None
+ self.cert = None
+ self.connect()
+
+ def connect(self):
+ try:
+ addr = socket.gethostbyname(self.host)
+ server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ if self.ssl:
+ context = SSL.Context(SSL.SSLv23_METHOD)
+ if self.clientcert:
+ context.use_certificate_file(self.clientcert)
+ server = SSL.Connection(context, server)
+ server.connect((addr, self.port))
+ if self.ssl:
+ self.cert = server.get_peer_certificate()
+ self.rfile, self.wfile = FileLike(server), FileLike(server)
+ else:
+ self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb')
+ except socket.error, err:
+ raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
+ self.sock = server
+
class BaseHandler:
rbufsize = -1
@@ -13,6 +89,15 @@ class BaseHandler:
self.handle()
self.finish()
+ def convert_to_ssl(self, cert, key):
+ ctx = SSL.Context(SSL.SSLv23_METHOD)
+ ctx.use_privatekey_file(key)
+ ctx.use_certificate_file(cert)
+ self.connection = SSL.Connection(ctx, self.connection)
+ self.connection.set_accept_state()
+ self.rfile = FileLike(self.connection)
+ self.wfile = FileLike(self.connection)
+
def finish(self):
try:
if not getattr(self.wfile, "closed", False):
diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py
index 89493e79..9febba72 100644
--- a/libmproxy/proxy.py
+++ b/libmproxy/proxy.py
@@ -21,7 +21,7 @@
import sys, os, string, socket, time
import shutil, tempfile, threading
import optparse, SocketServer
-import utils, flow, certutils, version, wsgi, tcpserver
+import utils, flow, certutils, version, wsgi, netlib
from OpenSSL import SSL
@@ -232,50 +232,6 @@ def read_http_body_request(rfile, wfile, headers, httpversion, limit):
return read_http_body(rfile, headers, False, limit)
-class FileLike:
- def __init__(self, o):
- self.o = o
-
- def __getattr__(self, attr):
- return getattr(self.o, attr)
-
- def flush(self):
- pass
-
- def read(self, length):
- result = ''
- while len(result) < length:
- try:
- data = self.o.read(length)
- except AttributeError:
- break
- except SSL.ZeroReturnError:
- break
- if not data:
- break
- result += data
- return result
-
- def write(self, v):
- self.o.sendall(v)
-
- def readline(self, size = None):
- result = ''
- bytes_read = 0
- while True:
- if size is not None and bytes_read >= size:
- break
- ch = self.read(1)
- bytes_read += 1
- if not ch:
- break
- else:
- result += ch
- if ch == '\n':
- break
- return result
-
-
class RequestReplayThread(threading.Thread):
def __init__(self, config, flow, masterq):
self.config, self.flow, self.masterq = config, flow, masterq
@@ -291,41 +247,27 @@ class RequestReplayThread(threading.Thread):
except ProxyError, v:
err = flow.Error(self.flow.request, v.msg)
err._send(self.masterq)
+ except netlib.NetLibError, v:
+ raise ProxyError(502, v)
-class ServerConnection:
+class ServerConnection(netlib.TCPClient):
def __init__(self, config, scheme, host, port):
- self.config, self.scheme, self.host, self.port = config, scheme, host, port
- self.cert = None
- self.sock, self.rfile, self.wfile = None, None, None
- self.connect()
+ clientcert = None
+ if config.clientcerts:
+ path = os.path.join(config.clientcerts, self.host) + ".pem"
+ if os.path.exists(clientcert):
+ clientcert = path
+ netlib.TCPClient.__init__(
+ self,
+ True if scheme == "https" else False,
+ host,
+ port,
+ clientcert
+ )
+ self.config, self.scheme = config, scheme
self.requestcount = 0
- def connect(self):
- try:
- addr = socket.gethostbyname(self.host)
- server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- if self.scheme == "https":
- if self.config.clientcerts:
- clientcert = os.path.join(self.config.clientcerts, self.host) + ".pem"
- if not os.path.exists(clientcert):
- clientcert = None
- else:
- clientcert = None
- context = SSL.Context(SSL.SSLv23_METHOD)
- if clientcert:
- context.use_certificate_file(clientcert)
- server = SSL.Connection(context, server)
- server.connect((addr, self.port))
- if self.scheme == "https":
- self.cert = server.get_peer_certificate()
- self.rfile, self.wfile = FileLike(server), FileLike(server)
- else:
- self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb')
- except socket.error, err:
- raise ProxyError(502, 'Error connecting to "%s": %s' % (self.host, err))
- self.sock = server
-
def send(self, request):
self.requestcount += 1
try:
@@ -374,13 +316,13 @@ class ServerConnection:
pass
-class ProxyHandler(tcpserver.BaseHandler):
+class ProxyHandler(netlib.BaseHandler):
def __init__(self, config, connection, client_address, server, q):
self.mqueue = q
self.config = config
self.server_conn = None
self.proxy_connect_state = None
- tcpserver.BaseHandler.__init__(self, connection, client_address, server)
+ netlib.BaseHandler.__init__(self, connection, client_address, server)
def handle(self):
cc = flow.ClientConnect(self.client_address)
@@ -397,7 +339,10 @@ class ProxyHandler(tcpserver.BaseHandler):
sc.terminate()
self.server_conn = None
if not self.server_conn:
- self.server_conn = ServerConnection(self.config, scheme, host, port)
+ try:
+ self.server_conn = ServerConnection(self.config, scheme, host, port)
+ except netlib.NetLibError, v:
+ raise ProxyError(502, v)
def handle_request(self, cc):
try:
@@ -473,15 +418,6 @@ class ProxyHandler(tcpserver.BaseHandler):
raise ProxyError(502, "mitmproxy: Unable to generate dummy cert.")
return ret
- def convert_to_ssl(self, cert):
- ctx = SSL.Context(SSL.SSLv23_METHOD)
- ctx.use_privatekey_file(self.config.certfile or self.config.cacert)
- ctx.use_certificate_file(cert)
- self.connection = SSL.Connection(ctx, self.connection)
- self.connection.set_accept_state()
- self.rfile = FileLike(self.connection)
- self.wfile = FileLike(self.connection)
-
def read_request(self, client_conn):
line = self.rfile.readline()
if line == "\r\n" or line == "\n": # Possible leftover from previous message
@@ -494,7 +430,7 @@ class ProxyHandler(tcpserver.BaseHandler):
if port in self.config.transparent_proxy["sslports"]:
scheme = "https"
certfile = self.find_cert(host, port)
- self.convert_to_ssl(certfile)
+ self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert)
else:
scheme = "http"
method, path, httpversion = parse_init_http(line)
@@ -527,7 +463,7 @@ class ProxyHandler(tcpserver.BaseHandler):
)
self.wfile.flush()
certfile = self.find_cert(host, port)
- self.convert_to_ssl(certfile)
+ self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert)
self.proxy_connect_state = (host, port, httpversion)
line = self.rfile.readline(line)
if self.proxy_connect_state:
@@ -572,7 +508,7 @@ class ProxyHandler(tcpserver.BaseHandler):
class ProxyServerError(Exception): pass
-class ProxyServer(tcpserver.TCPServer):
+class ProxyServer(netlib.TCPServer):
allow_reuse_address = True
bound = True
def __init__(self, config, port, address=''):
@@ -581,7 +517,7 @@ class ProxyServer(tcpserver.TCPServer):
"""
self.config, self.port, self.address = config, port, address
try:
- tcpserver.TCPServer.__init__(self, (address, port))
+ netlib.TCPServer.__init__(self, (address, port))
except socket.error, v:
raise ProxyServerError('Error starting proxy server: ' + v.strerror)
self.masterq = None
@@ -600,7 +536,7 @@ class ProxyServer(tcpserver.TCPServer):
ProxyHandler(self.config, request, client_address, self, self.masterq)
def shutdown(self):
- tcpserver.TCPServer.shutdown(self)
+ netlib.TCPServer.shutdown(self)
try:
shutil.rmtree(self.certdir)
except OSError:
diff --git a/test/test_netlib.py b/test/test_netlib.py
new file mode 100644
index 00000000..2b76c9cf
--- /dev/null
+++ b/test/test_netlib.py
@@ -0,0 +1,15 @@
+import cStringIO
+from libmproxy import netlib
+
+
+class TestFileLike:
+ def test_wrap(self):
+ s = cStringIO.StringIO("foobar\nfoobar")
+ s = netlib.FileLike(s)
+ s.flush()
+ assert s.readline() == "foobar\n"
+ assert s.readline() == "foobar"
+ # Test __getattr__
+ assert s.isatty
+
+
diff --git a/test/test_proxy.py b/test/test_proxy.py
index 9d7239dd..5fab282c 100644
--- a/test/test_proxy.py
+++ b/test/test_proxy.py
@@ -60,17 +60,6 @@ def test_read_http_body():
assert len(proxy.read_http_body(s, h, True, 100)) == 7
-class TestFileLike:
- def test_wrap(self):
- s = cStringIO.StringIO("foobar\nfoobar")
- s = proxy.FileLike(s)
- s.flush()
- assert s.readline() == "foobar\n"
- assert s.readline() == "foobar"
- # Test __getattr__
- assert s.isatty
-
-
class TestProxyError:
def test_simple(self):
p = proxy.ProxyError(111, "msg")