aboutsummaryrefslogtreecommitdiffstats
path: root/libmproxy/proxy.py
diff options
context:
space:
mode:
Diffstat (limited to 'libmproxy/proxy.py')
-rw-r--r--libmproxy/proxy.py168
1 files changed, 105 insertions, 63 deletions
diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py
index 5ac40e92..3297ab90 100644
--- a/libmproxy/proxy.py
+++ b/libmproxy/proxy.py
@@ -1,8 +1,6 @@
-import sys, os, string, socket, time
-import shutil, tempfile, threading
-import SocketServer
+import os, socket, time, threading
from OpenSSL import SSL
-from netlib import odict, tcp, http, certutils, http_status, http_auth
+from netlib import tcp, http, certutils, http_auth
import utils, flow, version, platform, controller, protocol
@@ -19,6 +17,11 @@ class ProxyError(Exception):
return "ProxyError(%s, %s)"%(self.code, self.msg)
+class Log:
+ def __init__(self, msg):
+ self.msg = msg
+
+
class ProxyConfig:
def __init__(self, certfile = None, cacert = None, clientcerts = None, no_upstream_cert=False, body_size_limit = None, reverse_proxy=None, forward_proxy=None, transparent_proxy=None, authenticator=None):
self.certfile = certfile
@@ -33,27 +36,55 @@ class ProxyConfig:
self.certstore = certutils.CertStore()
+class ClientConnection(tcp.BaseHandler):
+ def __init__(self, client_connection, host, port):
+ tcp.BaseHandler.__init__(self, client_connection)
+ self.host, self.port = host, port
+
+ self.timestamp_start = utils.timestamp()
+ self.timestamp_end = None
+ self.timestamp_ssl_setup = None
+
+ @property
+ def address(self):
+ return self.host, self.port
+
+ def convert_to_ssl(self, *args, **kwargs):
+ tcp.BaseHandler.convert_to_ssl(self, *args, **kwargs)
+ self.timestamp_ssl_setup = utils.timestamp()
+
+ def finish(self):
+ tcp.BaseHandler.finish(self)
+ self.timestamp_end = utils.timestamp()
+
+
class ServerConnection(tcp.TCPClient):
- def __init__(self, config, host, port, sni):
+ def __init__(self, host, port):
tcp.TCPClient.__init__(self, host, port)
- self.config = config
- self.sni = sni
- self.tcp_setup_timestamp = None
- self.ssl_setup_timestamp = None
+
+ self.timestamp_start = None
+ self.timestamp_end = None
+ self.timestamp_tcp_setup = None
+ self.timestamp_ssl_setup = None
+
+ @property
+ def address(self):
+ return self.host, self.port
def connect(self):
+ self.timestamp_start = utils.timestamp()
tcp.TCPClient.connect(self)
- self.tcp_setup_timestamp = time.time()
+ self.timestamp_tcp_setup = utils.timestamp()
- def establish_ssl(self):
+ def establish_ssl(self, clientcerts, sni):
clientcert = None
- if self.config.clientcerts:
- path = os.path.join(self.config.clientcerts, self.host.encode("idna")) + ".pem"
+ if clientcerts:
+ path = os.path.join(clientcerts, self.host.encode("idna")) + ".pem"
if os.path.exists(path):
clientcert = path
try:
- self.convert_to_ssl(cert=clientcert, sni=self.sni)
- self.ssl_setup_timestamp = time.time()
+ self.convert_to_ssl(cert=clientcert, sni=sni)
+ self.timestamp_ssl_setup = utils.timestamp()
except tcp.NetLibError, v:
raise ProxyError(400, str(v))
@@ -65,16 +96,12 @@ class ServerConnection(tcp.TCPClient):
self.wfile.write(d)
self.wfile.flush()
- def terminate(self):
- if self.connection:
- try:
- self.wfile.flush()
- except tcp.NetLibDisconnect: # pragma: no cover
- pass
- self.connection.close()
-
+ def finish(self):
+ tcp.TCPClient.finish(self)
+ self.timestamp_end = utils.timestamp()
+"""
class RequestReplayThread(threading.Thread):
def __init__(self, config, flow, masterq):
self.config, self.flow, self.channel = config, flow, controller.Channel(masterq)
@@ -98,14 +125,17 @@ class RequestReplayThread(threading.Thread):
except (ProxyError, http.HttpError, tcp.NetLibError), v:
err = flow.Error(self.flow.request, str(v))
self.channel.ask("error", err)
+"""
+
class ConnectionHandler:
def __init__(self, config, client_connection, client_address, server, channel, server_version):
self.config = config
- self.client_address, self.client_conn = client_address, tcp.BaseHandler(client_connection)
- self.server_address, self.server_conn = None, None
+ self.client_conn = ClientConnection(client_connection, *client_address)
+ self.server_conn = None
self.channel, self.server_version = channel, server_version
+ self.close = False
self.conntype = None
self.sni = None
@@ -117,7 +147,7 @@ class ConnectionHandler:
def del_server_connection(self):
if self.server_conn:
- self.server_conn.terminate()
+ self.server_conn.finish()
self.channel.tell("serverdisconnect", self)
self.server_conn = None
self.sni = None
@@ -126,81 +156,93 @@ class ConnectionHandler:
self.log("connect")
self.channel.ask("clientconnect", self)
- # Can we already identify the target server and connect to it?
- if self.config.forward_proxy:
- self.server_address = self.config.forward_proxy[1:]
- else:
- if self.config.reverse_proxy:
- self.server_address = self.config.reverse_proxy[1:]
- elif self.config.transparent_proxy:
- self.server_address = self.config.transparent_proxy["resolver"].original_addr(self.connection)
- if not self.server_address:
- raise ProxyError(502, "Transparent mode failure: could not resolve original destination.")
- self.log("transparent to %s:%s"%self.server_address)
-
- if self.server_address:
- self.establish_server_connection()
- self.handle_ssl()
-
- self.determine_conntype(self.mode, *self.server_address)
-
- while not self.close:
- try:
- protocol.handle_messages(self.conntype, self)
- except protocol.ConnectionTypeChange:
- continue
-
- self.del_server_connection()
+ try:
+ # Can we already identify the target server and connect to it?
+ server_address = None
+ if self.config.forward_proxy:
+ server_address = self.config.forward_proxy[1:]
+ else:
+ if self.config.reverse_proxy:
+ server_address = self.config.reverse_proxy[1:]
+ elif self.config.transparent_proxy:
+ server_address = self.config.transparent_proxy["resolver"].original_addr(self.client_conn.connection)
+ if not server_address:
+ raise ProxyError(502, "Transparent mode failure: could not resolve original destination.")
+ self.log("transparent to %s:%s"%server_address)
+
+ if server_address:
+ self.establish_server_connection(*server_address)
+ self.handle_ssl()
+
+ self.determine_conntype()
+
+ while not self.close:
+ try:
+ protocol.handle_messages(self.conntype, self)
+ except protocol.ConnectionTypeChange:
+ continue
+
+ self.del_server_connection()
+ except (ProxyError, protocol.ProtocolError), e:
+ self.log(str(e))
+ protocol.handle_error(self.conntype, self, e)
+ # FIXME: We need to persist errors
self.log("disconnect")
self.channel.tell("clientdisconnect", self)
- def determine_conntype(self, mode, host, port):
+ def finish(self):
+ self.client_conn.finish()
+
+ def determine_conntype(self):
#TODO: Add ruleset to select correct protocol depending on mode/target port etc.
self.conntype = "http"
- def establish_server_connection(self):
+ def establish_server_connection(self, host, port):
"""
- Establishes a new server connection to self.server_address.
+ Establishes a new server connection to the given server
If there is already an existing server connection, it will be killed.
"""
self.del_server_connection()
- self.server_conn = ServerConnection(self.config, *self.server_address, self.sni)
+ self.server_conn = ServerConnection(host, port)
+ self.server_conn.connect()
+ self.log("serverconnect", ["%s:%s"%(host, port)])
self.channel.tell("serverconnect", self)
def handle_ssl(self):
if self.config.transparent_proxy:
- client_ssl, server_ssl = (self.server_address[1] in self.config.transparent_proxy["sslports"])
+ client_ssl = server_ssl = (self.server_conn.port in self.config.transparent_proxy["sslports"])
elif self.config.reverse_proxy:
- client_ssl, server_ssl = (self.config.reverse_proxy[0] == "https")
+ client_ssl = server_ssl = (self.config.reverse_proxy[0] == "https")
# TODO: Make protocol generic (as with transparent proxies)
# TODO: Add SSL-terminating capatbility (SSL -> mitmproxy -> plain and vice versa)
else:
- client_ssl, server_ssl = True # In regular mode, this function will only be called on HTTP CONNECT
+ client_ssl = server_ssl = True # In regular mode, this function will only be called on HTTP CONNECT
# TODO: Implement SSL pass-through handling and change conntype
if server_ssl and not self.server_conn.ssl_established:
- self.server_conn.establish_ssl()
+ self.server_conn.establish_ssl(self.config.clientcerts, self.sni)
if client_ssl and not self.client_conn.ssl_established:
dummycert = self.find_cert()
- self.client_conn.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, handle_sni=self.handle_sni)
+ self.client_conn.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert,
+ handle_sni=self.handle_sni)
def log(self, msg, subs=()):
msg = [
- "%s:%s: "%(self.client_address, msg)
+ "%s:%s: %s" % (self.client_conn.host, self.client_conn.port, msg)
]
for i in subs:
msg.append(" -> "+i)
msg = "\n".join(msg)
- self.channel.tell("log", msg)
+ self.channel.tell("log", Log(msg))
def find_cert(self):
if self.config.certfile:
with open(self.config.certfile, "rb") as f:
return certutils.SSLCert.from_pem(f.read())
else:
- host = self.server_address[0]
+ host = self.server_conn.host
sans = []
if not self.config.no_upstream_cert or not self.server_conn.ssl_established:
upstream_cert = self.server_conn.cert