aboutsummaryrefslogtreecommitdiffstats
path: root/libmproxy/proxy.py
diff options
context:
space:
mode:
authorAldo Cortesi <aldo@nullcube.com>2012-06-10 16:02:48 +1200
committerAldo Cortesi <aldo@nullcube.com>2012-06-10 16:02:48 +1200
commit1f659948cddab77d9203d4d3b979b10d8fa12b98 (patch)
tree480b44aa6d7b1a2aa56d31aa553577f4063342bf /libmproxy/proxy.py
parent236447c65fef4f81ebdd311d225ada0d0d544bac (diff)
downloadmitmproxy-1f659948cddab77d9203d4d3b979b10d8fa12b98.tar.gz
mitmproxy-1f659948cddab77d9203d4d3b979b10d8fa12b98.tar.bz2
mitmproxy-1f659948cddab77d9203d4d3b979b10d8fa12b98.zip
Refactor request processing at mitmproxy's core.
Gradually cleaning up towards a state machine model.
Diffstat (limited to 'libmproxy/proxy.py')
-rw-r--r--libmproxy/proxy.py139
1 files changed, 75 insertions, 64 deletions
diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py
index 7698a61f..dbe91e7e 100644
--- a/libmproxy/proxy.py
+++ b/libmproxy/proxy.py
@@ -104,7 +104,7 @@ def read_chunked(fp, limit):
return content
-def read_http_body(rfile, connection, headers, all, limit):
+def read_http_body(rfile, client_conn, headers, all, limit):
if 'transfer-encoding' in headers:
if not ",".join(headers["transfer-encoding"]).lower() == "chunked":
raise IOError('Invalid transfer-encoding')
@@ -121,7 +121,7 @@ def read_http_body(rfile, connection, headers, all, limit):
content = rfile.read(l)
elif all:
content = rfile.read(limit if limit else None)
- connection.close = True
+ client_conn.close = True
else:
content = ""
return content
@@ -203,6 +203,18 @@ def should_connection_close(httpversion, headers):
return True
+def read_http_body_request(rfile, wfile, client_conn, headers, httpversion, limit):
+ if "expect" in headers:
+ # FIXME: Should be forwarded upstream
+ expect = ",".join(headers['expect'])
+ if expect == "100-continue" and httpversion >= (1, 1):
+ wfile.write('HTTP/1.1 100 Continue\r\n')
+ wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION)
+ wfile.write('\r\n')
+ del headers['expect']
+ return read_http_body(rfile, client_conn, headers, False, limit)
+
+
class FileLike:
def __init__(self, o):
self.o = o
@@ -262,10 +274,10 @@ class RequestReplayThread(threading.Thread):
class ServerConnection:
def __init__(self, config, scheme, host, port):
self.config, self.scheme, self.host, self.port = config, scheme, host, port
- self.close = False
self.cert = None
self.sock, self.rfile, self.wfile = None, None, None
self.connect()
+ self.requestcount = 0
def connect(self):
try:
@@ -288,6 +300,7 @@ class ServerConnection:
self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb')
def send(self, request):
+ self.requestcount += 1
try:
d = request._assemble()
if not d:
@@ -336,29 +349,38 @@ class ServerConnection:
class ProxyHandler(SocketServer.StreamRequestHandler):
def __init__(self, config, request, client_address, server, q):
- self.config = config
self.mqueue = q
+ self.config = config
self.server_conn = None
+ self.proxy_connect_state = None
SocketServer.StreamRequestHandler.__init__(self, request, client_address, server)
def handle(self):
cc = flow.ClientConnect(self.client_address)
cc._send(self.mqueue)
- while not cc.close:
- self.handle_request(cc)
+ while self.handle_request(cc) and not cc.close:
+ pass
+ cc.close = True
cd = flow.ClientDisconnect(cc)
cd._send(self.mqueue)
self.finish()
+ def server_connect(self, scheme, host, port):
+ sc = self.server_conn
+ if sc and (scheme, host, port) != (sc.scheme, sc.host, sc.port):
+ sc.terminate()
+ self.server_conn = None
+ if not self.server_conn:
+ self.server_conn = ServerConnection(self.config, scheme, host, port)
+
def handle_request(self, cc):
- server_conn, request, err = None, None, None
try:
+ request, err = None, None
try:
request = self.read_request(cc)
except IOError, v:
raise IOError, "Reading request: %s"%v
if request is None:
- cc.close = True
return
cc.requestcount += 1
@@ -368,7 +390,6 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
else:
request = request._send(self.mqueue)
if request is None:
- cc.close = True
return
if isinstance(request, flow.Response):
@@ -380,31 +401,30 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
scheme, host, port = self.config.reverse_proxy
else:
scheme, host, port = request.scheme, request.host, request.port
- server_conn = ServerConnection(self.config, scheme, host, port)
- server_conn.send(request)
+ self.server_connect(scheme, host, port)
+ self.server_conn.send(request)
try:
- response = server_conn.read_response(request)
+ response = self.server_conn.read_response(request)
except IOError, v:
raise IOError, "Reading response: %s"%v
response = response._send(self.mqueue)
if response is None:
- server_conn.terminate()
+ self.server_conn.terminate()
if response is None:
- cc.close = True
return
self.send_response(response)
+ if should_connection_close(request.httpversion, request.headers):
+ return
except IOError, v:
cc.connection_error = v
- cc.close = True
except ProxyError, e:
- cc.close = True
cc.connection_error = "%s: %s"%(e.code, e.msg)
if request:
err = flow.Error(request, e.msg)
err._send(self.mqueue)
self.send_error(e.code, e.msg)
- if server_conn:
- server_conn.terminate()
+ else:
+ return True
def find_cert(self, host, port):
if self.config.certfile:
@@ -435,26 +455,6 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
self.rfile = FileLike(self.connection)
self.wfile = FileLike(self.connection)
- def read_contents(self, client_conn, headers, httpversion):
- if "expect" in headers:
- # FIXME: Should be forwarded upstream
- expect = ",".join(headers['expect'])
- if expect == "100-continue" and httpversion >= (1, 1):
- self.wfile.write('HTTP/1.1 100 Continue\r\n')
- self.wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION)
- self.wfile.write('\r\n')
- del headers['expect']
- if httpversion < (1, 1):
- client_conn.close = True
- if "connection" in headers:
- for value in ",".join(headers['connection']).split(","):
- value = value.strip()
- if value == "close":
- client_conn.close = True
- if value == "keep-alive":
- client_conn.close = False
- return read_http_body(self.rfile, client_conn, headers, False, self.config.body_size_limit)
-
def read_request(self, client_conn):
line = self.rfile.readline()
if line == "\r\n" or line == "\n": # Possible leftover from previous message
@@ -466,34 +466,45 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
scheme, host, port = self.config.reverse_proxy
method, path, httpversion = parse_init_http(line)
headers = read_headers(self.rfile)
- content = self.read_contents(client_conn, headers, httpversion)
+ content = read_http_body_request(
+ self.rfile, self.wfile, client_conn, headers, httpversion, self.config.body_size_limit
+ )
return flow.Request(client_conn, httpversion, host, port, "http", method, path, headers, content)
- elif line.startswith("CONNECT"):
- host, port, httpversion = parse_init_connect(line)
- # FIXME: Discard additional headers sent to the proxy. Should I expose
- # these to users?
- while 1:
- d = self.rfile.readline()
- if d == '\r\n' or d == '\n':
- break
- self.wfile.write(
- 'HTTP/1.1 200 Connection established\r\n' +
- ('Proxy-agent: %s\r\n'%version.NAMEVERSION) +
- '\r\n'
- )
- self.wfile.flush()
- certfile = self.find_cert(host, port)
- self.convert_to_ssl(certfile)
-
- method, path, httpversion = parse_init_http(self.rfile.readline(line))
- headers = read_headers(self.rfile)
- content = self.read_contents(client_conn, headers, httpversion)
- return flow.Request(client_conn, httpversion, host, port, "https", method, path, headers, content)
else:
- method, scheme, host, port, path, httpversion = parse_init_proxy(line)
- headers = read_headers(self.rfile)
- content = self.read_contents(client_conn, headers, httpversion)
- return flow.Request(client_conn, httpversion, host, port, scheme, method, path, headers, content)
+ if line.startswith("CONNECT"):
+ host, port, httpversion = parse_init_connect(line)
+ # FIXME: Discard additional headers sent to the proxy. Should I expose
+ # these to users?
+ while 1:
+ d = self.rfile.readline()
+ if d == '\r\n' or d == '\n':
+ break
+ self.wfile.write(
+ 'HTTP/1.1 200 Connection established\r\n' +
+ ('Proxy-agent: %s\r\n'%version.NAMEVERSION) +
+ '\r\n'
+ )
+ self.wfile.flush()
+ certfile = self.find_cert(host, port)
+ self.convert_to_ssl(certfile)
+ self.proxy_connect_state = (host, port, httpversion)
+ line = self.rfile.readline(line)
+
+ if self.proxy_connect_state:
+ host, port, httpversion = self.proxy_connect_state
+ method, path, httpversion = parse_init_http(line)
+ headers = read_headers(self.rfile)
+ content = read_http_body_request(
+ self.rfile, self.wfile, client_conn, headers, httpversion, self.config.body_size_limit
+ )
+ return flow.Request(client_conn, httpversion, host, port, "https", method, path, headers, content)
+ else:
+ method, scheme, host, port, path, httpversion = parse_init_proxy(line)
+ headers = read_headers(self.rfile)
+ content = read_http_body_request(
+ self.rfile, self.wfile, client_conn, headers, httpversion, self.config.body_size_limit
+ )
+ return flow.Request(client_conn, httpversion, host, port, scheme, method, path, headers, content)
def send_response(self, response):
d = response._assemble()