aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAldo Cortesi <aldo@nullcube.com>2014-02-07 10:50:23 +1300
committerAldo Cortesi <aldo@nullcube.com>2014-02-07 10:50:23 +1300
commit3d52d16e8d7b298363ef6d6f7279d75fb4b1a430 (patch)
tree5b793cc821b531a3fcedf511b9315346f6c7e014
parent404d4bbc69d9f2eb12664415ebca44a95ce96e56 (diff)
parent7fc544bc7ff8fd610ba9db92c0d3b59a0b040b5b (diff)
downloadmitmproxy-3d52d16e8d7b298363ef6d6f7279d75fb4b1a430.tar.gz
mitmproxy-3d52d16e8d7b298363ef6d6f7279d75fb4b1a430.tar.bz2
mitmproxy-3d52d16e8d7b298363ef6d6f7279d75fb4b1a430.zip
Merge branch 'tcp_proxy'
-rw-r--r--netlib/certutils.py2
-rw-r--r--netlib/odict.py4
-rw-r--r--netlib/tcp.py186
-rw-r--r--netlib/test.py11
-rw-r--r--netlib/wsgi.py15
-rw-r--r--test/test_http.py2
-rw-r--r--test/test_tcp.py37
7 files changed, 144 insertions, 113 deletions
diff --git a/netlib/certutils.py b/netlib/certutils.py
index 0349bec7..94294f6e 100644
--- a/netlib/certutils.py
+++ b/netlib/certutils.py
@@ -237,7 +237,7 @@ class SSLCert:
def get_remote_cert(host, port, sni):
- c = tcp.TCPClient(host, port)
+ c = tcp.TCPClient((host, port))
c.connect()
c.convert_to_ssl(sni=sni)
return c.cert
diff --git a/netlib/odict.py b/netlib/odict.py
index 0759a5bf..46b74e8e 100644
--- a/netlib/odict.py
+++ b/netlib/odict.py
@@ -1,5 +1,6 @@
import re, copy
+
def safe_subn(pattern, repl, target, *args, **kwargs):
"""
There are Unicode conversion problems with re.subn. We try to smooth
@@ -98,6 +99,9 @@ class ODict:
def _get_state(self):
return [tuple(i) for i in self.lst]
+ def _load_state(self, state):
+ self.list = [list(i) for i in state]
+
@classmethod
def _from_state(klass, state):
return klass([list(i) for i in state])
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 33f7ef3a..34e47999 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -173,16 +173,88 @@ class Reader(_FileLike):
return result
-class TCPClient:
+class Address(object):
+ """
+ This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information.
+ """
+ def __init__(self, address, use_ipv6=False):
+ self.address = tuple(address)
+ self.use_ipv6 = use_ipv6
+
+ @classmethod
+ def wrap(cls, t):
+ if isinstance(t, cls):
+ return t
+ else:
+ return cls(t)
+
+ def __call__(self):
+ return self.address
+
+ @property
+ def host(self):
+ return self.address[0]
+
+ @property
+ def port(self):
+ return self.address[1]
+
+ @property
+ def use_ipv6(self):
+ return self.family == socket.AF_INET6
+
+ @use_ipv6.setter
+ def use_ipv6(self, b):
+ self.family = socket.AF_INET6 if b else socket.AF_INET
+
+ def __eq__(self, other):
+ other = Address.wrap(other)
+ return (self.address, self.family) == (other.address, other.family)
+
+
+class SocketCloseMixin(object):
+ def finish(self):
+ self.finished = True
+ try:
+ if not getattr(self.wfile, "closed", False):
+ self.wfile.flush()
+ self.close()
+ self.wfile.close()
+ self.rfile.close()
+ except (socket.error, NetLibDisconnect):
+ # Remote has disconnected
+ pass
+
+ def close(self):
+ """
+ Does a hard close of the socket, i.e. a shutdown, followed by a close.
+ """
+ try:
+ if self.ssl_established:
+ self.connection.shutdown()
+ self.connection.sock_shutdown(socket.SHUT_WR)
+ else:
+ self.connection.shutdown(socket.SHUT_WR)
+ #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent.
+ #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html
+ while self.connection.recv(4096):
+ pass
+ self.connection.close()
+ except (socket.error, SSL.Error, IOError):
+ # Socket probably already closed
+ pass
+
+
+class TCPClient(SocketCloseMixin):
rbufsize = -1
wbufsize = -1
- def __init__(self, host, port, source_address=None, use_ipv6=False):
- self.host, self.port = host, port
- self.source_address = source_address
- self.use_ipv6 = use_ipv6
+ def __init__(self, address, source_address=None):
+ self.address = Address.wrap(address)
+ self.source_address = Address.wrap(source_address) if source_address else None
self.connection, self.rfile, self.wfile = None, None, None
self.cert = None
self.ssl_established = False
+ self.sni = None
def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None):
"""
@@ -200,6 +272,7 @@ class TCPClient:
self.connection = SSL.Connection(context, self.connection)
self.ssl_established = True
if sni:
+ self.sni = sni
self.connection.set_tlsext_host_name(sni)
self.connection.set_connect_state()
try:
@@ -212,14 +285,14 @@ class TCPClient:
def connect(self):
try:
- connection = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM)
+ connection = socket.socket(self.address.family, socket.SOCK_STREAM)
if self.source_address:
- connection.bind(self.source_address)
- connection.connect((self.host, self.port))
+ connection.bind(self.source_address())
+ connection.connect(self.address())
self.rfile = Reader(connection.makefile('rb', self.rbufsize))
self.wfile = Writer(connection.makefile('wb', self.wbufsize))
except (socket.error, IOError), err:
- raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
+ raise NetLibError('Error connecting to "%s": %s' % (self.address.host, err))
self.connection = connection
def settimeout(self, n):
@@ -228,43 +301,24 @@ class TCPClient:
def gettimeout(self):
return self.connection.gettimeout()
- def close(self):
- """
- Does a hard close of the socket, i.e. a shutdown, followed by a close.
- """
- try:
- if self.ssl_established:
- self.connection.shutdown()
- self.connection.sock_shutdown(socket.SHUT_WR)
- else:
- self.connection.shutdown(socket.SHUT_WR)
- #Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending readable data could lead to an immediate RST being sent.
- #http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html
- while self.connection.recv(4096):
- pass
- self.connection.close()
- except (socket.error, SSL.Error, IOError):
- # Socket probably already closed
- pass
-
-class BaseHandler:
+class BaseHandler(SocketCloseMixin):
"""
The instantiator is expected to call the handle() and finish() methods.
"""
rbufsize = -1
wbufsize = -1
- def __init__(self, connection, client_address, server):
+
+ def __init__(self, connection, address, server):
self.connection = connection
+ self.address = Address.wrap(address)
+ self.server = server
self.rfile = Reader(self.connection.makefile('rb', self.rbufsize))
self.wfile = Writer(self.connection.makefile('wb', self.wbufsize))
- self.client_address = client_address
- self.server = server
self.finished = False
self.ssl_established = False
-
self.clientcert = None
def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False, cipher_list=None):
@@ -318,66 +372,34 @@ class BaseHandler:
self.rfile.set_descriptor(self.connection)
self.wfile.set_descriptor(self.connection)
- def finish(self):
- self.finished = True
- try:
- if not getattr(self.wfile, "closed", False):
- self.wfile.flush()
- self.close()
- self.wfile.close()
- self.rfile.close()
- except (socket.error, NetLibDisconnect):
- # Remote has disconnected
- pass
-
def handle(self): # pragma: no cover
raise NotImplementedError
def settimeout(self, n):
self.connection.settimeout(n)
- def close(self):
- """
- Does a hard close of the socket, i.e. a shutdown, followed by a close.
- """
- try:
- if self.ssl_established:
- self.connection.shutdown()
- self.connection.sock_shutdown(socket.SHUT_WR)
- else:
- self.connection.shutdown(socket.SHUT_WR)
- # Section 4.2.2.13 of RFC 1122 tells us that a close() with any
- # pending readable data could lead to an immediate RST being sent.
- # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html
- while self.connection.recv(4096):
- pass
- except (socket.error, SSL.Error):
- # Socket probably already closed
- pass
- self.connection.close()
class TCPServer:
request_queue_size = 20
- def __init__(self, server_address, use_ipv6=False):
- self.server_address = server_address
- self.use_ipv6 = use_ipv6
+ def __init__(self, address):
+ self.address = Address.wrap(address)
self.__is_shut_down = threading.Event()
self.__shutdown_request = False
- self.socket = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM)
+ self.socket = socket.socket(self.address.family, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- self.socket.bind(self.server_address)
- self.server_address = self.socket.getsockname()
- self.port = self.server_address[1]
+ self.socket.bind(self.address())
+ self.address = Address.wrap(self.socket.getsockname())
self.socket.listen(self.request_queue_size)
- def request_thread(self, request, client_address):
+ def connection_thread(self, connection, client_address):
+ client_address = Address(client_address)
try:
- self.handle_connection(request, client_address)
- request.close()
+ self.handle_client_connection(connection, client_address)
except:
- self.handle_error(request, client_address)
- request.close()
+ self.handle_error(connection, client_address)
+ finally:
+ connection.close()
def serve_forever(self, poll_interval=0.1):
self.__is_shut_down.clear()
@@ -391,10 +413,10 @@ class TCPServer:
else:
raise
if self.socket in r:
- request, client_address = self.socket.accept()
+ connection, client_address = self.socket.accept()
t = threading.Thread(
- target = self.request_thread,
- args = (request, client_address)
+ target = self.connection_thread,
+ args = (connection, client_address)
)
t.setDaemon(1)
t.start()
@@ -410,18 +432,18 @@ class TCPServer:
def handle_error(self, request, client_address, fp=sys.stderr):
"""
- Called when handle_connection raises an exception.
+ Called when handle_client_connection raises an exception.
"""
# If a thread has persisted after interpreter exit, the module might be
# none.
if traceback:
exc = traceback.format_exc()
print >> fp, '-'*40
- print >> fp, "Error in processing of request from %s:%s"%client_address
+ print >> fp, "Error in processing of request from %s:%s" % (client_address.host, client_address.port)
print >> fp, exc
print >> fp, '-'*40
- def handle_connection(self, request, client_address): # pragma: no cover
+ def handle_client_connection(self, conn, client_address): # pragma: no cover
"""
Called after client connection.
"""
diff --git a/netlib/test.py b/netlib/test.py
index 85a56739..2f6a7107 100644
--- a/netlib/test.py
+++ b/netlib/test.py
@@ -17,19 +17,18 @@ class ServerTestBase:
ssl = None
handler = None
addr = ("localhost", 0)
- use_ipv6 = False
@classmethod
def setupAll(cls):
cls.q = Queue.Queue()
s = cls.makeserver()
- cls.port = s.port
+ cls.port = s.address.port
cls.server = ServerThread(s)
cls.server.start()
@classmethod
def makeserver(cls):
- return TServer(cls.ssl, cls.q, cls.handler, cls.addr, cls.use_ipv6)
+ return TServer(cls.ssl, cls.q, cls.handler, cls.addr)
@classmethod
def teardownAll(cls):
@@ -41,16 +40,16 @@ class ServerTestBase:
class TServer(tcp.TCPServer):
- def __init__(self, ssl, q, handler_klass, addr, use_ipv6):
+ def __init__(self, ssl, q, handler_klass, addr):
"""
ssl: A {cert, key, v3_only} dict.
"""
- tcp.TCPServer.__init__(self, addr, use_ipv6=use_ipv6)
+ tcp.TCPServer.__init__(self, addr)
self.ssl, self.q = ssl, q
self.handler_klass = handler_klass
self.last_handler = None
- def handle_connection(self, request, client_address):
+ def handle_client_connection(self, request, client_address):
h = self.handler_klass(request, client_address, self)
self.last_handler = h
if self.ssl:
diff --git a/netlib/wsgi.py b/netlib/wsgi.py
index 647cb899..b576bdff 100644
--- a/netlib/wsgi.py
+++ b/netlib/wsgi.py
@@ -1,17 +1,22 @@
import cStringIO, urllib, time, traceback
-import odict
+import odict, tcp
class ClientConn:
def __init__(self, address):
- self.address = address
+ self.address = tcp.Address.wrap(address)
+
+
+class Flow:
+ def __init__(self, client_conn):
+ self.client_conn = client_conn
class Request:
def __init__(self, client_conn, scheme, method, path, headers, content):
self.scheme, self.method, self.path = scheme, method, path
self.headers, self.content = headers, content
- self.client_conn = client_conn
+ self.flow = Flow(client_conn)
def date_time_string():
@@ -60,8 +65,8 @@ class WSGIAdaptor:
'SERVER_PROTOCOL': "HTTP/1.1",
}
environ.update(extra)
- if request.client_conn.address:
- environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.client_conn.address
+ if request.flow.client_conn.address:
+ environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.flow.client_conn.address()
for key, value in request.headers.items():
key = 'HTTP_' + key.upper().replace('-', '_')
diff --git a/test/test_http.py b/test/test_http.py
index a0386115..e80e4b8f 100644
--- a/test/test_http.py
+++ b/test/test_http.py
@@ -223,7 +223,7 @@ class TestReadResponseNoContentLength(test.ServerTestBase):
handler = NoContentLengthHTTPHandler
def test_no_content_length(self):
- c = tcp.TCPClient("127.0.0.1", self.port)
+ c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
httpversion, code, msg, headers, content = http.read_response(c.rfile, "GET", None)
assert content == "bar\r\n\r\n"
diff --git a/test/test_tcp.py b/test/test_tcp.py
index 7f2c21c4..525961d5 100644
--- a/test/test_tcp.py
+++ b/test/test_tcp.py
@@ -73,7 +73,7 @@ class TestServer(test.ServerTestBase):
handler = EchoHandler
def test_echo(self):
testval = "echo!\n"
- c = tcp.TCPClient("127.0.0.1", self.port)
+ c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.wfile.write(testval)
c.wfile.flush()
@@ -88,7 +88,7 @@ class TestServerBind(test.ServerTestBase):
for i in range(20):
random_port = random.randrange(1024, 65535)
try:
- c = tcp.TCPClient("127.0.0.1", self.port, source_address=("127.0.0.1", random_port))
+ c = tcp.TCPClient(("127.0.0.1", self.port), source_address=("127.0.0.1", random_port))
c.connect()
assert c.rfile.readline() == str(("127.0.0.1", random_port))
return
@@ -98,11 +98,11 @@ class TestServerBind(test.ServerTestBase):
class TestServerIPv6(test.ServerTestBase):
handler = EchoHandler
- use_ipv6 = True
+ addr = tcp.Address(("localhost", 0), use_ipv6=True)
def test_echo(self):
testval = "echo!\n"
- c = tcp.TCPClient("::1", self.port, use_ipv6=True)
+ c = tcp.TCPClient(tcp.Address(("::1", self.port), use_ipv6=True))
c.connect()
c.wfile.write(testval)
c.wfile.flush()
@@ -127,7 +127,7 @@ class TestFinishFail(test.ServerTestBase):
handler = FinishFailHandler
def test_disconnect_in_finish(self):
testval = "echo!\n"
- c = tcp.TCPClient("127.0.0.1", self.port)
+ c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.wfile.write("foo\n")
c.wfile.flush()
@@ -137,7 +137,7 @@ class TestDisconnect(test.ServerTestBase):
handler = EchoHandler
def test_echo(self):
testval = "echo!\n"
- c = tcp.TCPClient("127.0.0.1", self.port)
+ c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.wfile.write(testval)
c.wfile.flush()
@@ -153,7 +153,7 @@ class TestServerSSL(test.ServerTestBase):
v3_only = False
)
def test_echo(self):
- c = tcp.TCPClient("127.0.0.1", self.port)
+ c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl(sni="foo.com", options=tcp.OP_ALL)
testval = "echo!\n"
@@ -174,7 +174,7 @@ class TestSSLv3Only(test.ServerTestBase):
v3_only = True
)
def test_failure(self):
- c = tcp.TCPClient("127.0.0.1", self.port)
+ c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com", method=tcp.TLSv1_METHOD)
@@ -188,13 +188,13 @@ class TestSSLClientCert(test.ServerTestBase):
v3_only = False
)
def test_clientcert(self):
- c = tcp.TCPClient("127.0.0.1", self.port)
+ c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl(cert=tutils.test_data.path("data/clientcert/client.pem"))
assert c.rfile.readline().strip() == "1"
def test_clientcert_err(self):
- c = tcp.TCPClient("127.0.0.1", self.port)
+ c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
tutils.raises(
tcp.NetLibError,
@@ -212,9 +212,10 @@ class TestSNI(test.ServerTestBase):
v3_only = False
)
def test_echo(self):
- c = tcp.TCPClient("127.0.0.1", self.port)
+ c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl(sni="foo.com")
+ assert c.sni == "foo.com"
assert c.rfile.readline() == "foo.com"
@@ -228,7 +229,7 @@ class TestClientCipherList(test.ServerTestBase):
cipher_list = 'RC4-SHA'
)
def test_echo(self):
- c = tcp.TCPClient("127.0.0.1", self.port)
+ c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl(sni="foo.com")
assert c.rfile.readline() == "['RC4-SHA']"
@@ -243,7 +244,7 @@ class TestSSLDisconnect(test.ServerTestBase):
v3_only = False
)
def test_echo(self):
- c = tcp.TCPClient("127.0.0.1", self.port)
+ c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl()
# Excercise SSL.ZeroReturnError
@@ -255,7 +256,7 @@ class TestSSLDisconnect(test.ServerTestBase):
class TestDisconnect(test.ServerTestBase):
def test_echo(self):
- c = tcp.TCPClient("127.0.0.1", self.port)
+ c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.rfile.read(10)
c.wfile.write("foo")
@@ -266,7 +267,7 @@ class TestDisconnect(test.ServerTestBase):
class TestServerTimeOut(test.ServerTestBase):
handler = TimeoutHandler
def test_timeout(self):
- c = tcp.TCPClient("127.0.0.1", self.port)
+ c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
time.sleep(0.3)
assert self.last_handler.timeout
@@ -275,7 +276,7 @@ class TestServerTimeOut(test.ServerTestBase):
class TestTimeOut(test.ServerTestBase):
handler = HangHandler
def test_timeout(self):
- c = tcp.TCPClient("127.0.0.1", self.port)
+ c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.settimeout(0.1)
assert c.gettimeout() == 0.1
@@ -291,7 +292,7 @@ class TestSSLTimeOut(test.ServerTestBase):
v3_only = False
)
def test_timeout_client(self):
- c = tcp.TCPClient("127.0.0.1", self.port)
+ c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl()
c.settimeout(0.1)
@@ -300,7 +301,7 @@ class TestSSLTimeOut(test.ServerTestBase):
class TestTCPClient:
def test_conerr(self):
- c = tcp.TCPClient("127.0.0.1", 0)
+ c = tcp.TCPClient(("127.0.0.1", 0))
tutils.raises(tcp.NetLibError, c.connect)