diff options
Diffstat (limited to 'libmproxy')
-rw-r--r-- | libmproxy/protocol.py | 284 | ||||
-rw-r--r-- | libmproxy/proxy.py | 168 |
2 files changed, 336 insertions, 116 deletions
diff --git a/libmproxy/protocol.py b/libmproxy/protocol.py index cd9b4ce5..866ac419 100644 --- a/libmproxy/protocol.py +++ b/libmproxy/protocol.py @@ -1,15 +1,37 @@ -from libmproxy.proxy import ProxyError, ConnectionHandler -from netlib import http +from libmproxy import flow +from libmproxy.utils import timestamp +from netlib import http, utils, tcp +from netlib.odict import ODictCaseless +KILL = 0 # FIXME: Remove duplication with proxy module +LEGACY = True -def handle_messages(conntype, connection_handler): +#FIXME: Combine with ProxyError? +class ProtocolError(Exception): + def __init__(self, code, msg, headers=None): + self.code, self.msg, self.headers = code, msg, headers + + def __str__(self): + return "ProtocolError(%s, %s)"%(self.code, self.msg) + + +def _handle(msg, conntype, connection_handler, *args, **kwargs): handler = None if conntype == "http": handler = HTTPHandler(connection_handler) else: raise NotImplementedError - return handler.handle_messages() + f = getattr(handler, "handle_" + msg) + return f(*args, **kwargs) + + +def handle_messages(conntype, connection_handler): + _handle("messages", conntype, connection_handler) + + +def handle_error(conntype, connection_handler, e): + _handle("error", conntype, connection_handler, e) class ConnectionTypeChange(Exception): @@ -21,6 +43,134 @@ class ProtocolHandler(object): self.c = c +class Flow(object): + def __init__(self, client_conn, server_conn, timestamp_start, timestamp_end): + self.client_conn, self.server_conn = client_conn, server_conn + self.timestamp_start, self.timestamp_end = timestamp_start, timestamp_end + + +class HTTPFlow(Flow): + def __init__(self, client_conn, server_conn, timestamp_start, timestamp_end, request, response): + Flow.__init__(self, client_conn, server_conn, + timestamp_start, timestamp_end) + self.request, self.response = request, response + + +class HTTPResponse(object): + def __init__(self, http_version, code, msg, headers, content, timestamp_start, timestamp_end): + self.http_version = http_version + self.code = code + self.msg = msg + self.headers = headers + self.content = content + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + + assert isinstance(headers, ODictCaseless) + + #FIXME: Legacy + @property + def request(self): + return False + + def _assemble(self): + response_line = 'HTTP/%s.%s %s %s'%(self.http_version[0], self.http_version[1], self.code, self.msg) + return '%s\r\n%s\r\n%s' % (response_line, str(self.headers), self.content) + + @classmethod + def from_stream(cls, rfile, request_method, include_content=True, body_size_limit=None): + """ + Parse an HTTP response from a file stream + """ + if not include_content: + raise NotImplementedError + + timestamp_start = timestamp() + http_version, code, msg, headers, content = http.read_response( + rfile, + request_method, + body_size_limit) + timestamp_end = timestamp() + return HTTPResponse(http_version, code, msg, headers, content, timestamp_start, timestamp_end) + +class HTTPRequest(object): + def __init__(self, form_in, method, scheme, host, port, path, http_version, headers, content, + timestamp_start, timestamp_end, form_out=None, ip=None): + self.form_in = form_in + self.method = method + self.scheme = scheme + self.host = host + self.port = port + self.path = path + self.http_version = http_version + self.headers = headers + self.content = content + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + + self.form_out = form_out or self.form_in + self.ip = ip # resolved ip address + assert isinstance(headers, ODictCaseless) + + #FIXME: Remove, legacy + def is_live(self): + return True + + def _assemble(self): + request_line = None + if self.form_out == "asterisk" or self.form_out == "origin": + request_line = '%s %s HTTP/%s.%s' % (self.method, self.path, self.http_version[0], self.http_version[1]) + else: + raise NotImplementedError + return '%s\r\n%s\r\n%s' % (request_line, str(self.headers), self.content) + + @classmethod + def from_stream(cls, rfile, include_content=True, body_size_limit=None): + """ + Parse an HTTP request from a file stream + """ + http_version, host, port, scheme, method, path, headers, content, timestamp_start, timestamp_end \ + = None, None, None, None, None, None, None, None, None, None + + timestamp_start = timestamp() + request_line = HTTPHandler.get_line(rfile) + + request_line_parts = http.parse_init(request_line) + if not request_line_parts: + raise ProtocolError(400, "Bad HTTP request line: %s"%repr(request_line)) + method, path, http_version = request_line_parts + + if path == '*': + form_in = "asterisk" + elif path.startswith("/"): + form_in = "origin" + if not utils.isascii(path): + raise ProtocolError(400, "Bad HTTP request line: %s"%repr(request_line)) + elif method.upper() == 'CONNECT': + form_in = "authority" + r = http.parse_init_connect(request_line) + if not r: + raise ProtocolError(400, "Bad HTTP request line: %s"%repr(request_line)) + host, port, _ = r + else: + form_in = "absolute" + r = http.parse_init_proxy(request_line) + if not r: + raise ProtocolError(400, "Bad HTTP request line: %s"%repr(request_line)) + _, scheme, host, port, path, _ = r + + headers = http.read_headers(rfile) + if headers is None: + raise ProtocolError(400, "Invalid headers") + + if include_content: + content = http.read_http_body(rfile, headers, body_size_limit, True) + timestamp_end = timestamp() + + return HTTPRequest(form_in, method, scheme, host, port, path, http_version, headers, content, + timestamp_start, timestamp_end) + + class HTTPHandler(ProtocolHandler): def handle_messages(self): @@ -28,68 +178,96 @@ class HTTPHandler(ProtocolHandler): pass self.c.close = True + def handle_error(self, e): + raise e # FIXME: Proper error handling + def handle_request(self): - request = self.read_request() - if request is None: - return - raise NotImplementedError + try: + flow = HTTPFlow(self.c.client_conn, self.c.server_conn, timestamp(), None, None, None) + flow.request = self.read_request() + request_reply = self.c.channel.ask("request" if LEGACY else "httprequest", flow.request) - def read_request(self): - self.c.client_conn.rfile.reset_timestamps() + if request_reply is None or request_reply == KILL: + return False + if isinstance(request_reply, HTTPResponse): + flow.response = request_reply + else: + flow.request = request_reply + raw = flow.request._assemble() + self.c.server_conn.wfile.write(raw) + self.c.server_conn.wfile.flush() + flow.response = self.read_response(flow) + response_reply = self.c.channel.ask("response" if LEGACY else "httpresponse", flow.response) + if response_reply is None or response_reply == KILL: + return False + else: + raw = flow.response._assemble() + self.c.client_conn.wfile.write(raw) + self.c.client_conn.wfile.flush() - request_line = self.get_line(self.c.client_conn.rfile) - method, path, httpversion = http.parse_init(request_line) - headers = self.read_headers(authenticate=True) + if (http.connection_close(flow.request.http_version, flow.request.headers) or + http.connection_close(flow.response.http_version, flow.response.headers)): + return False + + flow.timestamp_end = timestamp() + return flow + except tcp.NetLibDisconnect, e: + return False + + def read_request(self): + request = HTTPRequest.from_stream(self.c.client_conn.rfile, body_size_limit=self.c.config.body_size_limit) if self.c.mode == "regular": - if method == "CONNECT": - r = http.parse_init_connect(request_line) - if not r: - raise ProxyError(400, "Bad HTTP request line: %s"%repr(request_line)) - host, port, _ = r - if self.c.config.forward_proxy: - #FIXME: Treat as request, no custom handling - self.c.server_conn.wfile.write(request_line) - for key, value in headers.items(): - self.c.server_conn.wfile.write("%s: %s\r\n"%(key, value)) - self.c.server_conn.wfile.write("\r\n") - else: - self.c.server_address = (host, port) - self.c.establish_server_connection() + self.authenticate(request) + if request.form_in == "authority": + if not self.c.config.forward_proxy: + self.c.establish_server_connection(request.host, request.port) + self.c.client_conn.wfile.write( + 'HTTP/1.1 200 Connection established\r\n' + + ('Proxy-agent: %s\r\n'%self.c.server_version) + + '\r\n' + ) + self.c.client_conn.wfile.flush() self.c.handle_ssl() - self.c.determine_conntype("transparent", host, port) + self.c.mode = "transparent" + self.c.determine_conntype() + # FIXME: We need to persist the CONNECT request raise ConnectionTypeChange - else: - r = http.parse_init_proxy(request_line) - if not r: - raise ProxyError(400, "Bad HTTP request line: %s"%repr(request_line)) - method, scheme, host, port, path, httpversion = r + elif request.form_in == "absolute": if not self.c.config.forward_proxy: - if (not self.c.server_conn) or (self.c.server_address != (host, port)): - self.c.server_address = (host, port) - self.c.establish_server_connection() + request.form_out = "origin" + if ((not self.c.server_conn) or + (self.c.server_conn.address != (request.host, request.port))): + self.c.establish_server_connection(request.host, request.port) + else: + raise ProtocolError(400, "Invalid Request") + + return request - def get_line(self, fp): + def read_response(self, flow): + return HTTPResponse.from_stream(self.c.server_conn.rfile, flow.request.method, body_size_limit=self.c.config.body_size_limit) + + def authenticate(self, request): + if self.c.config.authenticator: + if self.c.config.authenticator.authenticate(request.headers): + self.c.config.authenticator.clean(request.headers) + else: + raise ProtocolError( + 407, + "Proxy Authentication Required", + self.c.config.authenticator.auth_challenge_headers() + ) + return request.headers + + @staticmethod + def get_line(fp): """ Get a line, possibly preceded by a blank. """ line = fp.readline() if line == "\r\n" or line == "\n": # Possible leftover from previous message line = fp.readline() - return line - - def read_headers(self, authenticate=False): - headers = http.read_headers(self.c.client_conn.rfile) - if headers is None: - raise ProxyError(400, "Invalid headers") - if authenticate and self.c.config.authenticator: - if self.c.config.authenticator.authenticate(headers): - self.c.config.authenticator.clean(headers) - else: - raise ProxyError( - 407, - "Proxy Authentication Required", - self.c.config.authenticator.auth_challenge_headers() - ) - return headers
\ No newline at end of file + if line == "": + raise tcp.NetLibDisconnect + return line
\ No newline at end of file diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 5ac40e92..3297ab90 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -1,8 +1,6 @@ -import sys, os, string, socket, time -import shutil, tempfile, threading -import SocketServer +import os, socket, time, threading from OpenSSL import SSL -from netlib import odict, tcp, http, certutils, http_status, http_auth +from netlib import tcp, http, certutils, http_auth import utils, flow, version, platform, controller, protocol @@ -19,6 +17,11 @@ class ProxyError(Exception): return "ProxyError(%s, %s)"%(self.code, self.msg) +class Log: + def __init__(self, msg): + self.msg = msg + + class ProxyConfig: def __init__(self, certfile = None, cacert = None, clientcerts = None, no_upstream_cert=False, body_size_limit = None, reverse_proxy=None, forward_proxy=None, transparent_proxy=None, authenticator=None): self.certfile = certfile @@ -33,27 +36,55 @@ class ProxyConfig: self.certstore = certutils.CertStore() +class ClientConnection(tcp.BaseHandler): + def __init__(self, client_connection, host, port): + tcp.BaseHandler.__init__(self, client_connection) + self.host, self.port = host, port + + self.timestamp_start = utils.timestamp() + self.timestamp_end = None + self.timestamp_ssl_setup = None + + @property + def address(self): + return self.host, self.port + + def convert_to_ssl(self, *args, **kwargs): + tcp.BaseHandler.convert_to_ssl(self, *args, **kwargs) + self.timestamp_ssl_setup = utils.timestamp() + + def finish(self): + tcp.BaseHandler.finish(self) + self.timestamp_end = utils.timestamp() + + class ServerConnection(tcp.TCPClient): - def __init__(self, config, host, port, sni): + def __init__(self, host, port): tcp.TCPClient.__init__(self, host, port) - self.config = config - self.sni = sni - self.tcp_setup_timestamp = None - self.ssl_setup_timestamp = None + + self.timestamp_start = None + self.timestamp_end = None + self.timestamp_tcp_setup = None + self.timestamp_ssl_setup = None + + @property + def address(self): + return self.host, self.port def connect(self): + self.timestamp_start = utils.timestamp() tcp.TCPClient.connect(self) - self.tcp_setup_timestamp = time.time() + self.timestamp_tcp_setup = utils.timestamp() - def establish_ssl(self): + def establish_ssl(self, clientcerts, sni): clientcert = None - if self.config.clientcerts: - path = os.path.join(self.config.clientcerts, self.host.encode("idna")) + ".pem" + if clientcerts: + path = os.path.join(clientcerts, self.host.encode("idna")) + ".pem" if os.path.exists(path): clientcert = path try: - self.convert_to_ssl(cert=clientcert, sni=self.sni) - self.ssl_setup_timestamp = time.time() + self.convert_to_ssl(cert=clientcert, sni=sni) + self.timestamp_ssl_setup = utils.timestamp() except tcp.NetLibError, v: raise ProxyError(400, str(v)) @@ -65,16 +96,12 @@ class ServerConnection(tcp.TCPClient): self.wfile.write(d) self.wfile.flush() - def terminate(self): - if self.connection: - try: - self.wfile.flush() - except tcp.NetLibDisconnect: # pragma: no cover - pass - self.connection.close() - + def finish(self): + tcp.TCPClient.finish(self) + self.timestamp_end = utils.timestamp() +""" class RequestReplayThread(threading.Thread): def __init__(self, config, flow, masterq): self.config, self.flow, self.channel = config, flow, controller.Channel(masterq) @@ -98,14 +125,17 @@ class RequestReplayThread(threading.Thread): except (ProxyError, http.HttpError, tcp.NetLibError), v: err = flow.Error(self.flow.request, str(v)) self.channel.ask("error", err) +""" + class ConnectionHandler: def __init__(self, config, client_connection, client_address, server, channel, server_version): self.config = config - self.client_address, self.client_conn = client_address, tcp.BaseHandler(client_connection) - self.server_address, self.server_conn = None, None + self.client_conn = ClientConnection(client_connection, *client_address) + self.server_conn = None self.channel, self.server_version = channel, server_version + self.close = False self.conntype = None self.sni = None @@ -117,7 +147,7 @@ class ConnectionHandler: def del_server_connection(self): if self.server_conn: - self.server_conn.terminate() + self.server_conn.finish() self.channel.tell("serverdisconnect", self) self.server_conn = None self.sni = None @@ -126,81 +156,93 @@ class ConnectionHandler: self.log("connect") self.channel.ask("clientconnect", self) - # Can we already identify the target server and connect to it? - if self.config.forward_proxy: - self.server_address = self.config.forward_proxy[1:] - else: - if self.config.reverse_proxy: - self.server_address = self.config.reverse_proxy[1:] - elif self.config.transparent_proxy: - self.server_address = self.config.transparent_proxy["resolver"].original_addr(self.connection) - if not self.server_address: - raise ProxyError(502, "Transparent mode failure: could not resolve original destination.") - self.log("transparent to %s:%s"%self.server_address) - - if self.server_address: - self.establish_server_connection() - self.handle_ssl() - - self.determine_conntype(self.mode, *self.server_address) - - while not self.close: - try: - protocol.handle_messages(self.conntype, self) - except protocol.ConnectionTypeChange: - continue - - self.del_server_connection() + try: + # Can we already identify the target server and connect to it? + server_address = None + if self.config.forward_proxy: + server_address = self.config.forward_proxy[1:] + else: + if self.config.reverse_proxy: + server_address = self.config.reverse_proxy[1:] + elif self.config.transparent_proxy: + server_address = self.config.transparent_proxy["resolver"].original_addr(self.client_conn.connection) + if not server_address: + raise ProxyError(502, "Transparent mode failure: could not resolve original destination.") + self.log("transparent to %s:%s"%server_address) + + if server_address: + self.establish_server_connection(*server_address) + self.handle_ssl() + + self.determine_conntype() + + while not self.close: + try: + protocol.handle_messages(self.conntype, self) + except protocol.ConnectionTypeChange: + continue + + self.del_server_connection() + except (ProxyError, protocol.ProtocolError), e: + self.log(str(e)) + protocol.handle_error(self.conntype, self, e) + # FIXME: We need to persist errors self.log("disconnect") self.channel.tell("clientdisconnect", self) - def determine_conntype(self, mode, host, port): + def finish(self): + self.client_conn.finish() + + def determine_conntype(self): #TODO: Add ruleset to select correct protocol depending on mode/target port etc. self.conntype = "http" - def establish_server_connection(self): + def establish_server_connection(self, host, port): """ - Establishes a new server connection to self.server_address. + Establishes a new server connection to the given server If there is already an existing server connection, it will be killed. """ self.del_server_connection() - self.server_conn = ServerConnection(self.config, *self.server_address, self.sni) + self.server_conn = ServerConnection(host, port) + self.server_conn.connect() + self.log("serverconnect", ["%s:%s"%(host, port)]) self.channel.tell("serverconnect", self) def handle_ssl(self): if self.config.transparent_proxy: - client_ssl, server_ssl = (self.server_address[1] in self.config.transparent_proxy["sslports"]) + client_ssl = server_ssl = (self.server_conn.port in self.config.transparent_proxy["sslports"]) elif self.config.reverse_proxy: - client_ssl, server_ssl = (self.config.reverse_proxy[0] == "https") + client_ssl = server_ssl = (self.config.reverse_proxy[0] == "https") # TODO: Make protocol generic (as with transparent proxies) # TODO: Add SSL-terminating capatbility (SSL -> mitmproxy -> plain and vice versa) else: - client_ssl, server_ssl = True # In regular mode, this function will only be called on HTTP CONNECT + client_ssl = server_ssl = True # In regular mode, this function will only be called on HTTP CONNECT # TODO: Implement SSL pass-through handling and change conntype if server_ssl and not self.server_conn.ssl_established: - self.server_conn.establish_ssl() + self.server_conn.establish_ssl(self.config.clientcerts, self.sni) if client_ssl and not self.client_conn.ssl_established: dummycert = self.find_cert() - self.client_conn.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, handle_sni=self.handle_sni) + self.client_conn.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, + handle_sni=self.handle_sni) def log(self, msg, subs=()): msg = [ - "%s:%s: "%(self.client_address, msg) + "%s:%s: %s" % (self.client_conn.host, self.client_conn.port, msg) ] for i in subs: msg.append(" -> "+i) msg = "\n".join(msg) - self.channel.tell("log", msg) + self.channel.tell("log", Log(msg)) def find_cert(self): if self.config.certfile: with open(self.config.certfile, "rb") as f: return certutils.SSLCert.from_pem(f.read()) else: - host = self.server_address[0] + host = self.server_conn.host sans = [] if not self.config.no_upstream_cert or not self.server_conn.ssl_established: upstream_cert = self.server_conn.cert |