diff options
35 files changed, 2530 insertions, 1804 deletions
diff --git a/doc-src/proxy-flowchart/proxy-flowchart.pdf b/doc-src/proxy-flowchart/proxy-flowchart.pdf Binary files differnew file mode 100644 index 00000000..ae98eea3 --- /dev/null +++ b/doc-src/proxy-flowchart/proxy-flowchart.pdf diff --git a/doc-src/proxy-flowchart/proxy-flowchart.vsdx b/doc-src/proxy-flowchart/proxy-flowchart.vsdx Binary files differnew file mode 100644 index 00000000..4d75f49f --- /dev/null +++ b/doc-src/proxy-flowchart/proxy-flowchart.vsdx diff --git a/libmproxy/app.py b/libmproxy/app.py index b0692cf2..b046f712 100644 --- a/libmproxy/app.py +++ b/libmproxy/app.py @@ -4,9 +4,11 @@ import os.path mapp = flask.Flask(__name__) mapp.debug = True + def master(): return flask.request.environ["mitmproxy.master"] + @mapp.route("/") def index(): return flask.render_template("index.html", section="home") @@ -16,12 +18,12 @@ def index(): def certs_pem(): capath = master().server.config.cacert p = os.path.splitext(capath)[0] + "-cert.pem" - return flask.Response(open(p).read(), mimetype='application/x-x509-ca-cert') + return flask.Response(open(p, "rb").read(), mimetype='application/x-x509-ca-cert') @mapp.route("/cert/p12") def certs_p12(): capath = master().server.config.cacert p = os.path.splitext(capath)[0] + "-cert.p12" - return flask.Response(open(p).read(), mimetype='application/x-pkcs12') + return flask.Response(open(p, "rb").read(), mimetype='application/x-pkcs12') diff --git a/libmproxy/console/__init__.py b/libmproxy/console/__init__.py index a316602c..d92561f2 100644 --- a/libmproxy/console/__init__.py +++ b/libmproxy/console/__init__.py @@ -197,7 +197,7 @@ class StatusBar(common.WWrap): ] if self.master.server.bound: - boundaddr = "[%s:%s]"%(self.master.server.address or "*", self.master.server.port) + boundaddr = "[%s:%s]"%(self.master.server.address.host or "*", self.master.server.address.port) else: boundaddr = "" t.extend(self.get_status()) @@ -1008,7 +1008,7 @@ class ConsoleMaster(flow.FlowMaster): self.statusbar.refresh_flow(c) def process_flow(self, f, r): - if self.state.intercept and f.match(self.state.intercept) and not f.request.is_replay(): + if self.state.intercept and f.match(self.state.intercept) and not f.request.is_replay: f.intercept() else: r.reply() diff --git a/libmproxy/console/common.py b/libmproxy/console/common.py index 951d2c2a..715bed80 100644 --- a/libmproxy/console/common.py +++ b/libmproxy/console/common.py @@ -172,7 +172,7 @@ def format_flow(f, focus, extended=False, hostheader=False, padding=2): intercepting = f.intercepting, req_timestamp = f.request.timestamp_start, - req_is_replay = f.request.is_replay(), + req_is_replay = f.request.is_replay, req_method = f.request.method, req_acked = f.request.reply.acked, req_url = f.request.get_url(hostheader=hostheader), @@ -189,12 +189,12 @@ def format_flow(f, focus, extended=False, hostheader=False, padding=2): contentdesc = "[no content]" delta = f.response.timestamp_end - f.response.timestamp_start - size = len(f.response.content) + f.response.get_header_size() + size = f.response.size() rate = utils.pretty_size(size / ( delta if delta > 0 else 1 ) ) d.update(dict( resp_code = f.response.code, - resp_is_replay = f.response.is_replay(), + resp_is_replay = f.response.is_replay, resp_acked = f.response.reply.acked, resp_clen = contentdesc, resp_rate = "{0}/s".format(rate), diff --git a/libmproxy/console/flowdetailview.py b/libmproxy/console/flowdetailview.py index a26e5308..436d8f07 100644 --- a/libmproxy/console/flowdetailview.py +++ b/libmproxy/console/flowdetailview.py @@ -1,5 +1,6 @@ import urwid import common +from .. import utils footer = [ ('heading_key', "q"), ":back ", @@ -33,8 +34,17 @@ class FlowDetailsView(urwid.ListBox): title = urwid.AttrWrap(title, "heading") text.append(title) - if self.flow.response: - c = self.flow.response.cert + if self.flow.server_conn: + text.append(urwid.Text([("head", "Server Connection:")])) + sc = self.flow.server_conn + parts = [ + ["Address", "%s:%s" % sc.peername], + ["Start time", utils.format_timestamp(sc.timestamp_start)], + ["End time", utils.format_timestamp(sc.timestamp_end) if sc.timestamp_end else "active"], + ] + text.extend(common.format_keyvals(parts, key="key", val="text", indent=4)) + + c = self.flow.server_conn.cert if c: text.append(urwid.Text([("head", "Server Certificate:")])) parts = [ @@ -43,19 +53,13 @@ class FlowDetailsView(urwid.ListBox): ["Valid to", str(c.notafter)], ["Valid from", str(c.notbefore)], ["Serial", str(c.serial)], - ] - - parts.append( [ "Subject", urwid.BoxAdapter( urwid.ListBox(common.format_keyvals(c.subject, key="highlight", val="text")), len(c.subject) ) - ] - ) - - parts.append( + ], [ "Issuer", urwid.BoxAdapter( @@ -63,7 +67,7 @@ class FlowDetailsView(urwid.ListBox): len(c.issuer) ) ] - ) + ] if c.altnames: parts.append( @@ -74,13 +78,14 @@ class FlowDetailsView(urwid.ListBox): ) text.extend(common.format_keyvals(parts, key="key", val="text", indent=4)) - if self.flow.request.client_conn: + if self.flow.client_conn: text.append(urwid.Text([("head", "Client Connection:")])) - cc = self.flow.request.client_conn + cc = self.flow.client_conn parts = [ - ["Address", "%s:%s"%tuple(cc.address)], - ["Requests", "%s"%cc.requestcount], - ["Closed", "%s"%cc.close], + ["Address", "%s:%s" % cc.address()], + ["Start time", utils.format_timestamp(cc.timestamp_start)], + # ["Requests", "%s"%cc.requestcount], + ["End time", utils.format_timestamp(cc.timestamp_end) if cc.timestamp_end else "active"], ] text.extend(common.format_keyvals(parts, key="key", val="text", indent=4)) diff --git a/libmproxy/controller.py b/libmproxy/controller.py index b662b6d5..470d88fc 100644 --- a/libmproxy/controller.py +++ b/libmproxy/controller.py @@ -72,6 +72,7 @@ class Slave(threading.Thread): self.channel, self.server = channel, server self.server.set_channel(channel) threading.Thread.__init__(self) + self.name = "SlaveThread (%s:%s)" % (self.server.address.host, self.server.address.port) def run(self): self.server.serve_forever() diff --git a/libmproxy/dump.py b/libmproxy/dump.py index 8bd29ae5..6cf5e688 100644 --- a/libmproxy/dump.py +++ b/libmproxy/dump.py @@ -42,14 +42,14 @@ class Options(object): def str_response(resp): r = "%s %s"%(resp.code, resp.msg) - if resp.is_replay(): + if resp.is_replay: r = "[replay] " + r return r def str_request(req, showhost): - if req.client_conn: - c = req.client_conn.address[0] + if req.flow.client_conn: + c = req.flow.client_conn.address.host else: c = "[replay]" r = "%s %s %s"%(c, req.method, req.get_url(showhost)) diff --git a/libmproxy/filt.py b/libmproxy/filt.py index 6a0c3075..95076eed 100644 --- a/libmproxy/filt.py +++ b/libmproxy/filt.py @@ -198,7 +198,7 @@ class FDomain(_Rex): code = "d" help = "Domain" def __call__(self, f): - return bool(re.search(self.expr, f.request.host, re.IGNORECASE)) + return bool(re.search(self.expr, f.request.get_host(), re.IGNORECASE)) class FUrl(_Rex): diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 76ca4f47..40786631 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -2,16 +2,19 @@ This module provides more sophisticated flow tracking. These match requests with their responses, and provide filtering and interception facilities. """ -import hashlib, Cookie, cookielib, copy, re, urlparse, threading -import time, urllib -import tnetstring, filt, script, utils, encoding, proxy -from email.utils import parsedate_tz, formatdate, mktime_tz -from netlib import odict, http, certutils, wsgi -import controller, version +import base64 +import hashlib, Cookie, cookielib, re, threading +import os +from flask import request +import requests +import tnetstring, filt, script +from netlib import odict, wsgi +from .proxy import ClientConnection, ServerConnection # FIXME: remove circular dependency +import controller, version, protocol import app - -HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" -CONTENT_MISSING = 0 +from .protocol import KILL +from .protocol.http import HTTPResponse, CONTENT_MISSING +from .proxy import RequestReplayThread ODict = odict.ODict ODictCaseless = odict.ODictCaseless @@ -32,11 +35,11 @@ class AppRegistry: """ Returns an WSGIAdaptor instance if request matches an app, or None. """ - if (request.host, request.port) in self.apps: - return self.apps[(request.host, request.port)] + if (request.get_host(), request.get_port()) in self.apps: + return self.apps[(request.get_host(), request.get_port())] if "host" in request.headers: host = request.headers["host"][0] - return self.apps.get((host, request.port), None) + return self.apps.get((host, request.get_port()), None) class ReplaceHooks: @@ -143,769 +146,6 @@ class SetHeaders: f.request.headers.add(header, value) -class decoded(object): - """ - - A context manager that decodes a request, response or error, and then - re-encodes it with the same encoding after execution of the block. - - Example: - - with decoded(request): - request.content = request.content.replace("foo", "bar") - """ - def __init__(self, o): - self.o = o - ce = o.headers.get_first("content-encoding") - if ce in encoding.ENCODINGS: - self.ce = ce - else: - self.ce = None - - def __enter__(self): - if self.ce: - self.o.decode() - - def __exit__(self, type, value, tb): - if self.ce: - self.o.encode(self.ce) - - -class StateObject: - def __eq__(self, other): - try: - return self._get_state() == other._get_state() - except AttributeError: - return False - - -class HTTPMsg(StateObject): - def get_decoded_content(self): - """ - Returns the decoded content based on the current Content-Encoding header. - Doesn't change the message iteself or its headers. - """ - ce = self.headers.get_first("content-encoding") - if not self.content or ce not in encoding.ENCODINGS: - return self.content - return encoding.decode(ce, self.content) - - def decode(self): - """ - Decodes content based on the current Content-Encoding header, then - removes the header. If there is no Content-Encoding header, no - action is taken. - - Returns True if decoding succeeded, False otherwise. - """ - ce = self.headers.get_first("content-encoding") - if not self.content or ce not in encoding.ENCODINGS: - return False - data = encoding.decode( - ce, - self.content - ) - if data is None: - return False - self.content = data - del self.headers["content-encoding"] - return True - - def encode(self, e): - """ - Encodes content with the encoding e, where e is "gzip", "deflate" - or "identity". - """ - # FIXME: Error if there's an existing encoding header? - self.content = encoding.encode(e, self.content) - self.headers["content-encoding"] = [e] - - def size(self, **kwargs): - """ - Size in bytes of a fully rendered message, including headers and - HTTP lead-in. - """ - hl = len(self._assemble_head(**kwargs)) - if self.content: - return hl + len(self.content) - else: - return hl - - def get_content_type(self): - return self.headers.get_first("content-type") - - def get_transmitted_size(self): - # FIXME: this is inprecise in case chunking is used - # (we should count the chunking headers) - if not self.content: - return 0 - return len(self.content) - - -class Request(HTTPMsg): - """ - An HTTP request. - - Exposes the following attributes: - - client_conn: ClientConnect object, or None if this is a replay. - - headers: ODictCaseless object - - content: Content of the request, None, or CONTENT_MISSING if there - is content associated, but not present. CONTENT_MISSING evaluates - to False to make checking for the presence of content natural. - - scheme: URL scheme (http/https) - - host: Host portion of the URL - - port: Destination port - - path: Path portion of the URL - - timestamp_start: Seconds since the epoch signifying request transmission started - - method: HTTP method - - timestamp_end: Seconds since the epoch signifying request transmission ended - - tcp_setup_timestamp: Seconds since the epoch signifying remote TCP connection setup completion time - (or None, if request didn't results TCP setup) - - ssl_setup_timestamp: Seconds since the epoch signifying remote SSL encryption setup completion time - (or None, if request didn't results SSL setup) - - """ - def __init__( - self, client_conn, httpversion, host, port, - scheme, method, path, headers, content, timestamp_start=None, - timestamp_end=None, tcp_setup_timestamp=None, - ssl_setup_timestamp=None, ip=None): - assert isinstance(headers, ODictCaseless) - self.client_conn = client_conn - self.httpversion = httpversion - self.host, self.port, self.scheme = host, port, scheme - self.method, self.path, self.headers, self.content = method, path, headers, content - self.timestamp_start = timestamp_start or utils.timestamp() - self.timestamp_end = max(timestamp_end or utils.timestamp(), timestamp_start) - self.close = False - self.tcp_setup_timestamp = tcp_setup_timestamp - self.ssl_setup_timestamp = ssl_setup_timestamp - self.ip = ip - - # Have this request's cookies been modified by sticky cookies or auth? - self.stickycookie = False - self.stickyauth = False - - # Live attributes - not serialized - self.wfile, self.rfile = None, None - - def set_live(self, rfile, wfile): - self.wfile, self.rfile = wfile, rfile - - def is_live(self): - return bool(self.wfile) - - def anticache(self): - """ - Modifies this request to remove headers that might produce a cached - response. That is, we remove ETags and If-Modified-Since headers. - """ - delheaders = [ - "if-modified-since", - "if-none-match", - ] - for i in delheaders: - del self.headers[i] - - def anticomp(self): - """ - Modifies this request to remove headers that will compress the - resource's data. - """ - self.headers["accept-encoding"] = ["identity"] - - def constrain_encoding(self): - """ - Limits the permissible Accept-Encoding values, based on what we can - decode appropriately. - """ - if self.headers["accept-encoding"]: - self.headers["accept-encoding"] = [', '.join( - e for e in encoding.ENCODINGS if e in self.headers["accept-encoding"][0] - )] - - def _set_replay(self): - self.client_conn = None - - def is_replay(self): - """ - Is this request a replay? - """ - if self.client_conn: - return False - else: - return True - - def _load_state(self, state): - if state["client_conn"]: - if self.client_conn: - self.client_conn._load_state(state["client_conn"]) - else: - self.client_conn = ClientConnect._from_state(state["client_conn"]) - else: - self.client_conn = None - self.host = state["host"] - self.port = state["port"] - self.scheme = state["scheme"] - self.method = state["method"] - self.path = state["path"] - self.headers = ODictCaseless._from_state(state["headers"]) - self.content = state["content"] - self.timestamp_start = state["timestamp_start"] - self.timestamp_end = state["timestamp_end"] - self.tcp_setup_timestamp = state["tcp_setup_timestamp"] - self.ssl_setup_timestamp = state["ssl_setup_timestamp"] - self.ip = state["ip"] - - def _get_state(self): - return dict( - client_conn = self.client_conn._get_state() if self.client_conn else None, - httpversion = self.httpversion, - host = self.host, - port = self.port, - scheme = self.scheme, - method = self.method, - path = self.path, - headers = self.headers._get_state(), - content = self.content, - timestamp_start = self.timestamp_start, - timestamp_end = self.timestamp_end, - tcp_setup_timestamp = self.tcp_setup_timestamp, - ssl_setup_timestamp = self.ssl_setup_timestamp, - ip = self.ip - ) - - @classmethod - def _from_state(klass, state): - return klass( - ClientConnect._from_state(state["client_conn"]), - tuple(state["httpversion"]), - str(state["host"]), - state["port"], - str(state["scheme"]), - str(state["method"]), - str(state["path"]), - ODictCaseless._from_state(state["headers"]), - state["content"], - state["timestamp_start"], - state["timestamp_end"], - state["tcp_setup_timestamp"], - state["ssl_setup_timestamp"], - state["ip"] - ) - - def __hash__(self): - return id(self) - - def copy(self): - c = copy.copy(self) - c.headers = self.headers.copy() - return c - - def get_form_urlencoded(self): - """ - Retrieves the URL-encoded form data, returning an ODict object. - Returns an empty ODict if there is no data or the content-type - indicates non-form data. - """ - if self.content and self.headers.in_any("content-type", HDR_FORM_URLENCODED, True): - return ODict(utils.urldecode(self.content)) - return ODict([]) - - def set_form_urlencoded(self, odict): - """ - Sets the body to the URL-encoded form data, and adds the - appropriate content-type header. Note that this will destory the - existing body if there is one. - """ - # FIXME: If there's an existing content-type header indicating a - # url-encoded form, leave it alone. - self.headers["Content-Type"] = [HDR_FORM_URLENCODED] - self.content = utils.urlencode(odict.lst) - - def get_path_components(self): - """ - Returns the path components of the URL as a list of strings. - - Components are unquoted. - """ - _, _, path, _, _, _ = urlparse.urlparse(self.get_url()) - return [urllib.unquote(i) for i in path.split("/") if i] - - def set_path_components(self, lst): - """ - Takes a list of strings, and sets the path component of the URL. - - Components are quoted. - """ - lst = [urllib.quote(i, safe="") for i in lst] - path = "/" + "/".join(lst) - scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.get_url()) - self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment])) - - def get_query(self): - """ - Gets the request query string. Returns an ODict object. - """ - _, _, _, _, query, _ = urlparse.urlparse(self.get_url()) - if query: - return ODict(utils.urldecode(query)) - return ODict([]) - - def set_query(self, odict): - """ - Takes an ODict object, and sets the request query string. - """ - scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.get_url()) - query = utils.urlencode(odict.lst) - self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment])) - - def get_url(self, hostheader=False): - """ - Returns a URL string, constructed from the Request's URL compnents. - - If hostheader is True, we use the value specified in the request - Host header to construct the URL. - """ - if hostheader: - host = self.headers.get_first("host") or self.host - else: - host = self.host - host = host.encode("idna") - return utils.unparse_url(self.scheme, host, self.port, self.path).encode('ascii') - - def set_url(self, url): - """ - Parses a URL specification, and updates the Request's information - accordingly. - - Returns False if the URL was invalid, True if the request succeeded. - """ - parts = http.parse_url(url) - if not parts: - return False - self.scheme, self.host, self.port, self.path = parts - return True - - def get_cookies(self): - cookie_headers = self.headers.get("cookie") - if not cookie_headers: - return None - - cookies = [] - for header in cookie_headers: - pairs = [pair.partition("=") for pair in header.split(';')] - cookies.extend((pair[0],(pair[2],{})) for pair in pairs) - return dict(cookies) - - def get_header_size(self): - FMT = '%s %s HTTP/%s.%s\r\n%s\r\n' - assembled_header = FMT % ( - self.method, - self.path, - self.httpversion[0], - self.httpversion[1], - str(self.headers) - ) - return len(assembled_header) - - def _assemble_head(self, proxy=False): - FMT = '%s %s HTTP/%s.%s\r\n%s\r\n' - FMT_PROXY = '%s %s://%s:%s%s HTTP/%s.%s\r\n%s\r\n' - - headers = self.headers.copy() - utils.del_all( - headers, - [ - 'proxy-connection', - 'keep-alive', - 'connection', - 'transfer-encoding' - ] - ) - if not 'host' in headers: - headers["host"] = [utils.hostport(self.scheme, self.host, self.port)] - content = self.content - if content: - headers["Content-Length"] = [str(len(content))] - else: - content = "" - if self.close: - headers["connection"] = ["close"] - if not proxy: - return FMT % ( - self.method, - self.path, - self.httpversion[0], - self.httpversion[1], - str(headers) - ) - else: - return FMT_PROXY % ( - self.method, - self.scheme, - self.host, - self.port, - self.path, - self.httpversion[0], - self.httpversion[1], - str(headers) - ) - - def _assemble(self, _proxy = False): - """ - Assembles the request for transmission to the server. We make some - modifications to make sure interception works properly. - - Returns None if the request cannot be assembled. - """ - if self.content == CONTENT_MISSING: - return None - head = self._assemble_head(_proxy) - if self.content: - return head + self.content - else: - return head - - def replace(self, pattern, repl, *args, **kwargs): - """ - Replaces a regular expression pattern with repl in both the headers - and the body of the request. Encoded content will be decoded before - replacement, and re-encoded afterwards. - - Returns the number of replacements made. - """ - with decoded(self): - self.content, c = utils.safe_subn(pattern, repl, self.content, *args, **kwargs) - self.path, pc = utils.safe_subn(pattern, repl, self.path, *args, **kwargs) - c += pc - c += self.headers.replace(pattern, repl, *args, **kwargs) - return c - - -class Response(HTTPMsg): - """ - An HTTP response. - - Exposes the following attributes: - - request: Request object. - - code: HTTP response code - - msg: HTTP response message - - headers: ODict object - - content: Content of the request, None, or CONTENT_MISSING if there - is content associated, but not present. CONTENT_MISSING evaluates - to False to make checking for the presence of content natural. - - timestamp_start: Seconds since the epoch signifying response transmission started - - timestamp_end: Seconds since the epoch signifying response transmission ended - """ - def __init__(self, request, httpversion, code, msg, headers, content, cert, timestamp_start=None, timestamp_end=None): - assert isinstance(headers, ODictCaseless) - self.request = request - self.httpversion, self.code, self.msg = httpversion, code, msg - self.headers, self.content = headers, content - self.cert = cert - self.timestamp_start = timestamp_start or utils.timestamp() - self.timestamp_end = timestamp_end or utils.timestamp() - self.replay = False - - def _refresh_cookie(self, c, delta): - """ - Takes a cookie string c and a time delta in seconds, and returns - a refreshed cookie string. - """ - c = Cookie.SimpleCookie(str(c)) - for i in c.values(): - if "expires" in i: - d = parsedate_tz(i["expires"]) - if d: - d = mktime_tz(d) + delta - i["expires"] = formatdate(d) - else: - # This can happen when the expires tag is invalid. - # reddit.com sends a an expires tag like this: "Thu, 31 Dec - # 2037 23:59:59 GMT", which is valid RFC 1123, but not - # strictly correct according tot he cookie spec. Browsers - # appear to parse this tolerantly - maybe we should too. - # For now, we just ignore this. - del i["expires"] - return c.output(header="").strip() - - def refresh(self, now=None): - """ - This fairly complex and heuristic function refreshes a server - response for replay. - - - It adjusts date, expires and last-modified headers. - - It adjusts cookie expiration. - """ - if not now: - now = time.time() - delta = now - self.timestamp_start - refresh_headers = [ - "date", - "expires", - "last-modified", - ] - for i in refresh_headers: - if i in self.headers: - d = parsedate_tz(self.headers[i][0]) - if d: - new = mktime_tz(d) + delta - self.headers[i] = [formatdate(new)] - c = [] - for i in self.headers["set-cookie"]: - c.append(self._refresh_cookie(i, delta)) - if c: - self.headers["set-cookie"] = c - - def _set_replay(self): - self.replay = True - - def is_replay(self): - """ - Is this response a replay? - """ - return self.replay - - def _load_state(self, state): - self.code = state["code"] - self.msg = state["msg"] - self.headers = ODictCaseless._from_state(state["headers"]) - self.content = state["content"] - self.timestamp_start = state["timestamp_start"] - self.timestamp_end = state["timestamp_end"] - self.cert = certutils.SSLCert.from_pem(state["cert"]) if state["cert"] else None - - def _get_state(self): - return dict( - httpversion = self.httpversion, - code = self.code, - msg = self.msg, - headers = self.headers._get_state(), - timestamp_start = self.timestamp_start, - timestamp_end = self.timestamp_end, - cert = self.cert.to_pem() if self.cert else None, - content = self.content, - ) - - @classmethod - def _from_state(klass, request, state): - return klass( - request, - state["httpversion"], - state["code"], - str(state["msg"]), - ODictCaseless._from_state(state["headers"]), - state["content"], - certutils.SSLCert.from_pem(state["cert"]) if state["cert"] else None, - state["timestamp_start"], - state["timestamp_end"], - ) - - def copy(self): - c = copy.copy(self) - c.headers = self.headers.copy() - return c - - def _assemble_head(self): - FMT = '%s\r\n%s\r\n' - headers = self.headers.copy() - utils.del_all( - headers, - ['proxy-connection', 'transfer-encoding'] - ) - if self.content: - headers["Content-Length"] = [str(len(self.content))] - elif 'Transfer-Encoding' in self.headers: - headers["Content-Length"] = ["0"] - proto = "HTTP/%s.%s %s %s"%(self.httpversion[0], self.httpversion[1], self.code, str(self.msg)) - data = (proto, str(headers)) - return FMT%data - - def _assemble(self): - """ - Assembles the response for transmission to the client. We make some - modifications to make sure interception works properly. - - Returns None if the request cannot be assembled. - """ - if self.content == CONTENT_MISSING: - return None - head = self._assemble_head() - if self.content: - return head + self.content - else: - return head - - def replace(self, pattern, repl, *args, **kwargs): - """ - Replaces a regular expression pattern with repl in both the headers - and the body of the response. Encoded content will be decoded - before replacement, and re-encoded afterwards. - - Returns the number of replacements made. - """ - with decoded(self): - self.content, c = utils.safe_subn(pattern, repl, self.content, *args, **kwargs) - c += self.headers.replace(pattern, repl, *args, **kwargs) - return c - - def get_header_size(self): - FMT = '%s\r\n%s\r\n' - proto = "HTTP/%s.%s %s %s"%(self.httpversion[0], self.httpversion[1], self.code, str(self.msg)) - assembled_header = FMT % (proto, str(self.headers)) - return len(assembled_header) - - def get_cookies(self): - cookie_headers = self.headers.get("set-cookie") - if not cookie_headers: - return None - - cookies = [] - for header in cookie_headers: - pairs = [pair.partition("=") for pair in header.split(';')] - cookie_name = pairs[0][0] # the key of the first key/value pairs - cookie_value = pairs[0][2] # the value of the first key/value pairs - cookie_parameters = {key.strip().lower():value.strip() for key,sep,value in pairs[1:]} - cookies.append((cookie_name, (cookie_value, cookie_parameters))) - return dict(cookies) - - -class ClientDisconnect: - """ - A client disconnection event. - - Exposes the following attributes: - - client_conn: ClientConnect object. - """ - def __init__(self, client_conn): - self.client_conn = client_conn - - -class ClientConnect(StateObject): - """ - A single client connection. Each connection can result in multiple HTTP - Requests. - - Exposes the following attributes: - - address: (address, port) tuple, or None if the connection is replayed. - requestcount: Number of requests created by this client connection. - close: Is the client connection closed? - error: Error string or None. - """ - def __init__(self, address): - """ - address is an (address, port) tuple, or None if this connection has - been replayed from within mitmproxy. - """ - self.address = address - self.close = False - self.requestcount = 0 - self.error = None - - def __str__(self): - if self.address: - return "%s:%d"%(self.address[0],self.address[1]) - - def _load_state(self, state): - self.close = True - self.error = state["error"] - self.requestcount = state["requestcount"] - - def _get_state(self): - return dict( - address = list(self.address), - requestcount = self.requestcount, - error = self.error, - ) - - @classmethod - def _from_state(klass, state): - if state: - k = klass(state["address"]) - k._load_state(state) - return k - else: - return None - - def copy(self): - return copy.copy(self) - - -class Error(StateObject): - """ - An Error. - - This is distinct from an HTTP error response (say, a code 500), which - is represented by a normal Response object. This class is responsible - for indicating errors that fall outside of normal HTTP communications, - like interrupted connections, timeouts, protocol errors. - - Exposes the following attributes: - - request: Request object - msg: Message describing the error - timestamp: Seconds since the epoch - """ - def __init__(self, request, msg, timestamp=None): - self.request, self.msg = request, msg - self.timestamp = timestamp or utils.timestamp() - - def _load_state(self, state): - self.msg = state["msg"] - self.timestamp = state["timestamp"] - - def copy(self): - c = copy.copy(self) - return c - - def _get_state(self): - return dict( - msg = self.msg, - timestamp = self.timestamp, - ) - - @classmethod - def _from_state(klass, request, state): - return klass( - request, - state["msg"], - state["timestamp"], - ) - - def replace(self, pattern, repl, *args, **kwargs): - """ - Replaces a regular expression pattern with repl in both the headers - and the body of the request. Returns the number of replacements - made. - - FIXME: Is replace useful on an Error object?? - """ - self.msg, c = utils.safe_subn(pattern, repl, self.msg, *args, **kwargs) - return c - - class ClientPlaybackState: def __init__(self, flows, exit): self.flows, self.exit = flows, exit @@ -934,7 +174,7 @@ class ClientPlaybackState: if self.flows and not self.current: n = self.flows.pop(0) n.request.reply = controller.DummyReply() - n.request.client_conn = None + n.client_conn = None self.current = master.handle_request(n.request) if not testing and not self.current.response: master.replay_request(self.current) # pragma: no cover @@ -997,7 +237,6 @@ class ServerPlaybackState: return l.pop(0) - class StickyCookieState: def __init__(self, flt): """ @@ -1011,8 +250,8 @@ class StickyCookieState: Returns a (domain, port, path) tuple. """ return ( - m["domain"] or f.request.host, - f.request.port, + m["domain"] or f.request.get_host(), + f.request.get_port(), m["path"] or "/" ) @@ -1030,7 +269,7 @@ class StickyCookieState: c = Cookie.SimpleCookie(str(i)) m = c.values()[0] k = self.ckey(m, f) - if self.domain_match(f.request.host, k[0]): + if self.domain_match(f.request.get_host(), k[0]): self.jar[self.ckey(m, f)] = m def handle_request(self, f): @@ -1038,8 +277,8 @@ class StickyCookieState: if f.match(self.flt): for i in self.jar.keys(): match = [ - self.domain_match(f.request.host, i[0]), - f.request.port == i[1], + self.domain_match(f.request.get_host(), i[0]), + f.request.get_port() == i[1], f.request.path.startswith(i[2]) ] if all(match): @@ -1058,177 +297,16 @@ class StickyAuthState: self.hosts = {} def handle_request(self, f): + host = f.request.get_host() if "authorization" in f.request.headers: - self.hosts[f.request.host] = f.request.headers["authorization"] + self.hosts[host] = f.request.headers["authorization"] elif f.match(self.flt): - if f.request.host in self.hosts: - f.request.headers["authorization"] = self.hosts[f.request.host] - - -class Flow: - """ - A Flow is a collection of objects representing a single HTTP - transaction. The main attributes are: - - request: Request object - response: Response object - error: Error object - - Note that it's possible for a Flow to have both a response and an error - object. This might happen, for instance, when a response was received - from the server, but there was an error sending it back to the client. - - The following additional attributes are exposed: - - intercepting: Is this flow currently being intercepted? - """ - def __init__(self, request): - self.request = request - self.response, self.error = None, None - self.intercepting = False - self._backup = None - - def copy(self): - rc = self.request.copy() - f = Flow(rc) - if self.response: - f.response = self.response.copy() - f.response.request = rc - if self.error: - f.error = self.error.copy() - f.error.request = rc - return f - - @classmethod - def _from_state(klass, state): - f = klass(None) - f._load_state(state) - return f - - def _get_state(self): - d = dict( - request = self.request._get_state() if self.request else None, - response = self.response._get_state() if self.response else None, - error = self.error._get_state() if self.error else None, - version = version.IVERSION - ) - return d - - def _load_state(self, state): - if self.request: - self.request._load_state(state["request"]) - else: - self.request = Request._from_state(state["request"]) - - if state["response"]: - if self.response: - self.response._load_state(state["response"]) - else: - self.response = Response._from_state(self.request, state["response"]) - else: - self.response = None - - if state["error"]: - if self.error: - self.error._load_state(state["error"]) - else: - self.error = Error._from_state(self.request, state["error"]) - else: - self.error = None - - def modified(self): - """ - Has this Flow been modified? - """ - # FIXME: Save a serialization in backup, compare current with - # backup to detect if flow has _really_ been modified. - if self._backup: - return True - else: - return False - - def backup(self, force=False): - """ - Save a backup of this Flow, which can be reverted to using a - call to .revert(). - """ - if not self._backup: - self._backup = self._get_state() - - def revert(self): - """ - Revert to the last backed up state. - """ - if self._backup: - self._load_state(self._backup) - self._backup = None - - def match(self, f): - """ - Match this flow against a compiled filter expression. Returns True - if matched, False if not. - - If f is a string, it will be compiled as a filter expression. If - the expression is invalid, ValueError is raised. - """ - if isinstance(f, basestring): - f = filt.parse(f) - if not f: - raise ValueError("Invalid filter expression.") - if f: - return f(self) - return True - - def kill(self, master): - """ - Kill this request. - """ - self.error = Error(self.request, "Connection killed") - self.error.reply = controller.DummyReply() - if self.request and not self.request.reply.acked: - self.request.reply(proxy.KILL) - elif self.response and not self.response.reply.acked: - self.response.reply(proxy.KILL) - master.handle_error(self.error) - self.intercepting = False - - def intercept(self): - """ - Intercept this Flow. Processing will stop until accept_intercept is - called. - """ - self.intercepting = True - - def accept_intercept(self): - """ - Continue with the flow - called after an intercept(). - """ - if self.request: - if not self.request.reply.acked: - self.request.reply() - elif self.response and not self.response.reply.acked: - self.response.reply() - self.intercepting = False - - def replace(self, pattern, repl, *args, **kwargs): - """ - Replaces a regular expression pattern with repl in all parts of the - flow. Encoded content will be decoded before replacement, and - re-encoded afterwards. - - Returns the number of replacements made. - """ - c = self.request.replace(pattern, repl, *args, **kwargs) - if self.response: - c += self.response.replace(pattern, repl, *args, **kwargs) - if self.error: - c += self.error.replace(pattern, repl, *args, **kwargs) - return c + if host in self.hosts: + f.request.headers["authorization"] = self.hosts[host] class State(object): def __init__(self): - self._flow_map = {} self._flow_list = [] self.view = [] @@ -1242,7 +320,7 @@ class State(object): return self._limit_txt def flow_count(self): - return len(self._flow_map) + return len(self._flow_list) def index(self, f): return self._flow_list.index(f) @@ -1258,10 +336,8 @@ class State(object): """ Add a request to the state. Returns the matching flow. """ - f = Flow(req) + f = req.flow self._flow_list.append(f) - self._flow_map[req] = f - assert len(self._flow_list) == len(self._flow_map) if f.match(self._limit): self.view.append(f) return f @@ -1270,10 +346,9 @@ class State(object): """ Add a response to the state. Returns the matching flow. """ - f = self._flow_map.get(resp.request) + f = resp.flow if not f: return False - f.response = resp if f.match(self._limit) and not f in self.view: self.view.append(f) return f @@ -1283,18 +358,15 @@ class State(object): Add an error response to the state. Returns the matching flow, or None if there isn't one. """ - f = self._flow_map.get(err.request) + f = err.flow if not f: return None - f.error = err if f.match(self._limit) and not f in self.view: self.view.append(f) return f def load_flows(self, flows): self._flow_list.extend(flows) - for i in flows: - self._flow_map[i.request] = i self.recalculate_view() def set_limit(self, txt): @@ -1327,8 +399,6 @@ class State(object): self.view = self._flow_list[:] def delete_flow(self, f): - if f.request in self._flow_map: - del self._flow_map[f.request] self._flow_list.remove(f) if f in self.view: self.view.remove(f) @@ -1383,7 +453,28 @@ class FlowMaster(controller.Master): port ) else: - threading.Thread(target=app.mapp.run,kwargs={ + @app.mapp.before_request + def patch_environ(*args, **kwargs): + request.environ["mitmproxy.master"] = self + + # the only absurd way to shut down a flask/werkzeug server. + # http://flask.pocoo.org/snippets/67/ + shutdown_secret = base64.b32encode(os.urandom(30)) + + @app.mapp.route('/shutdown/<secret>') + def shutdown(secret): + if secret == shutdown_secret: + request.environ.get('werkzeug.server.shutdown')() + + # Workaround: Monkey-patch shutdown function to stop the app. + # Improve this when we switch flask werkzeug for something useful. + _shutdown = self.shutdown + def _shutdownwrap(): + _shutdown() + requests.get("http://%s:%s/shutdown/%s" % (host, port, shutdown_secret)) + self.shutdown = _shutdownwrap + + threading.Thread(target=app.mapp.run, kwargs={ "use_reloader": False, "host": host, "port": port}).start() @@ -1474,9 +565,8 @@ class FlowMaster(controller.Master): rflow = self.server_playback.next_flow(flow) if not rflow: return None - response = Response._from_state(flow.request, rflow.response._get_state()) - response._set_replay() - flow.response = response + response = HTTPResponse._from_state(rflow.response._get_state()) + response.is_replay = True if self.refresh_server_playback: response.refresh() flow.request.reply(response) @@ -1555,13 +645,13 @@ class FlowMaster(controller.Master): if f.request.content == CONTENT_MISSING: return "Can't replay request with missing content..." if f.request: - f.request._set_replay() + f.request.is_replay = True if f.request.content: f.request.headers["Content-Length"] = [str(len(f.request.content))] f.response = None f.error = None self.process_new_request(f) - rt = proxy.RequestReplayThread( + rt = RequestReplayThread( self.server.config, f, self.masterq, @@ -1597,13 +687,13 @@ class FlowMaster(controller.Master): return f def handle_request(self, r): - if r.is_live(): + if r.flow.client_conn and r.flow.client_conn.wfile: app = self.apps.get(r) if app: - err = app.serve(r, r.wfile, **{"mitmproxy.master": self}) + err = app.serve(r, r.flow.client_conn.wfile, **{"mitmproxy.master": self}) if err: self.add_event("Error in wsgi app. %s"%err, "error") - r.reply(proxy.KILL) + r.reply(KILL) return f = self.state.add_request(r) self.replacehooks.run(f) @@ -1676,7 +766,7 @@ class FlowReader: v = ".".join(str(i) for i in data["version"]) raise FlowReadError("Incompatible serialized data version: %s"%v) off = self.fo.tell() - yield Flow._from_state(data) + yield protocol.protocols[data["conntype"]]["flow"]._from_state(data) except ValueError, v: # Error is due to EOF if self.fo.tell() == off and self.fo.read() == '': diff --git a/libmproxy/protocol/__init__.py b/libmproxy/protocol/__init__.py new file mode 100644 index 00000000..580d693c --- /dev/null +++ b/libmproxy/protocol/__init__.py @@ -0,0 +1,101 @@ +from ..proxy import ServerConnection, AddressPriority + +KILL = 0 # const for killed requests + +class ConnectionTypeChange(Exception): + """ + Gets raised if the connetion type has been changed (e.g. after HTTP/1.1 101 Switching Protocols). + It's up to the raising ProtocolHandler to specify the new conntype before raising the exception. + """ + pass + + +class ProtocolHandler(object): + def __init__(self, c): + self.c = c + """@type: libmproxy.proxy.ConnectionHandler""" + + def handle_messages(self): + """ + This method gets called if a client connection has been made. Depending on the proxy settings, + a server connection might already exist as well. + """ + raise NotImplementedError # pragma: nocover + + def handle_error(self, error): + """ + This method gets called should there be an uncaught exception during the connection. + This might happen outside of handle_messages, e.g. if the initial SSL handshake fails in transparent mode. + """ + raise error # pragma: nocover + + +class TemporaryServerChangeMixin(object): + """ + This mixin allows safe modification of the target server, + without any need to expose the ConnectionHandler to the Flow. + """ + + def change_server(self, address, ssl): + if address == self.c.server_conn.address(): + return + priority = AddressPriority.MANUALLY_CHANGED + + if self.c.server_conn.priority > priority: + self.log("Attempt to change server address, " + "but priority is too low (is: %s, got: %s)" % (self.server_conn.priority, priority)) + return + + self.log("Temporarily change server connection: %s:%s -> %s:%s" % ( + self.c.server_conn.address.host, + self.c.server_conn.address.port, + address.host, + address.port + )) + + if not hasattr(self, "_backup_server_conn"): + self._backup_server_conn = self.c.server_conn + self.c.server_conn = None + else: # This is at least the second temporary change. We can kill the current connection. + self.c.del_server_connection() + + self.c.set_server_address(address, priority) + if ssl: + self.establish_ssl(server=True) + + def restore_server(self): + if not hasattr(self, "_backup_server_conn"): + return + + self.log("Restore original server connection: %s:%s -> %s:%s" % ( + self.c.server_conn.address.host, + self.c.server_conn.address.port, + self._backup_server_conn.host, + self._backup_server_conn.port + )) + + self.c.del_server_connection() + self.c.server_conn = self._backup_server_conn + del self._backup_server_conn + +from . import http, tcp + +protocols = { + 'http': dict(handler=http.HTTPHandler, flow=http.HTTPFlow), + 'tcp': dict(handler=tcp.TCPHandler) +} # PyCharm type hinting behaves bad if this is a dict constructor... + + +def _handler(conntype, connection_handler): + if conntype in protocols: + return protocols[conntype]["handler"](connection_handler) + + raise NotImplementedError # pragma: nocover + + +def handle_messages(conntype, connection_handler): + return _handler(conntype, connection_handler).handle_messages() + + +def handle_error(conntype, connection_handler, error): + return _handler(conntype, connection_handler).handle_error(error)
\ No newline at end of file diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py new file mode 100644 index 00000000..95de6606 --- /dev/null +++ b/libmproxy/protocol/http.py @@ -0,0 +1,1046 @@ +import Cookie, urllib, urlparse, time, copy +from email.utils import parsedate_tz, formatdate, mktime_tz +import netlib.utils +from netlib import http, tcp, http_status, odict +from netlib.odict import ODict, ODictCaseless +from . import ProtocolHandler, ConnectionTypeChange, KILL, TemporaryServerChangeMixin +from .. import encoding, utils, version, filt, controller, stateobject +from ..proxy import ProxyError, AddressPriority, ServerConnection +from .primitives import Flow, Error + + +HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" +CONTENT_MISSING = 0 + + +def get_line(fp): + """ + Get a line, possibly preceded by a blank. + """ + line = fp.readline() + if line == "\r\n" or line == "\n": # Possible leftover from previous message + line = fp.readline() + if line == "": + raise tcp.NetLibDisconnect + return line + + +class decoded(object): + """ + A context manager that decodes a request or response, and then + re-encodes it with the same encoding after execution of the block. + + Example: + with decoded(request): + request.content = request.content.replace("foo", "bar") + """ + + def __init__(self, o): + self.o = o + ce = o.headers.get_first("content-encoding") + if ce in encoding.ENCODINGS: + self.ce = ce + else: + self.ce = None + + def __enter__(self): + if self.ce: + self.o.decode() + + def __exit__(self, type, value, tb): + if self.ce: + self.o.encode(self.ce) + + +class HTTPMessage(stateobject.SimpleStateObject): + def __init__(self, httpversion, headers, content, timestamp_start=None, timestamp_end=None): + self.httpversion = httpversion + self.headers = headers + """@type: ODictCaseless""" + self.content = content + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + + self.flow = None # will usually be set by the flow backref mixin + """@type: HTTPFlow""" + + _stateobject_attributes = dict( + httpversion=tuple, + headers=ODictCaseless, + content=str, + timestamp_start=float, + timestamp_end=float + ) + + def get_decoded_content(self): + """ + Returns the decoded content based on the current Content-Encoding header. + Doesn't change the message iteself or its headers. + """ + ce = self.headers.get_first("content-encoding") + if not self.content or ce not in encoding.ENCODINGS: + return self.content + return encoding.decode(ce, self.content) + + def decode(self): + """ + Decodes content based on the current Content-Encoding header, then + removes the header. If there is no Content-Encoding header, no + action is taken. + + Returns True if decoding succeeded, False otherwise. + """ + ce = self.headers.get_first("content-encoding") + if not self.content or ce not in encoding.ENCODINGS: + return False + data = encoding.decode(ce, self.content) + if data is None: + return False + self.content = data + del self.headers["content-encoding"] + return True + + def encode(self, e): + """ + Encodes content with the encoding e, where e is "gzip", "deflate" + or "identity". + """ + # FIXME: Error if there's an existing encoding header? + self.content = encoding.encode(e, self.content) + self.headers["content-encoding"] = [e] + + def size(self, **kwargs): + """ + Size in bytes of a fully rendered message, including headers and + HTTP lead-in. + """ + hl = len(self._assemble_head(**kwargs)) + if self.content: + return hl + len(self.content) + else: + return hl + + def copy(self): + c = copy.copy(self) + c.headers = self.headers.copy() + return c + + def replace(self, pattern, repl, *args, **kwargs): + """ + Replaces a regular expression pattern with repl in both the headers + and the body of the message. Encoded content will be decoded + before replacement, and re-encoded afterwards. + + Returns the number of replacements made. + """ + with decoded(self): + self.content, c = utils.safe_subn(pattern, repl, self.content, *args, **kwargs) + c += self.headers.replace(pattern, repl, *args, **kwargs) + return c + + @classmethod + def from_stream(cls, rfile, include_content=True, body_size_limit=None): + """ + Parse an HTTP message from a file stream + """ + raise NotImplementedError # pragma: nocover + + def _assemble_first_line(self): + """ + Returns the assembled request/response line + """ + raise NotImplementedError # pragma: nocover + + def _assemble_headers(self): + """ + Returns the assembled headers + """ + raise NotImplementedError # pragma: nocover + + def _assemble_head(self): + """ + Returns the assembled request/response line plus headers + """ + raise NotImplementedError # pragma: nocover + + def _assemble(self): + """ + Returns the assembled request/response + """ + raise NotImplementedError # pragma: nocover + + +class HTTPRequest(HTTPMessage): + """ + An HTTP request. + + Exposes the following attributes: + + flow: Flow object the request belongs to + + headers: ODictCaseless object + + content: Content of the request, None, or CONTENT_MISSING if there + is content associated, but not present. CONTENT_MISSING evaluates + to False to make checking for the presence of content natural. + + form_in: The request form which mitmproxy has received. The following values are possible: + - origin (GET /index.html) + - absolute (GET http://example.com:80/index.html) + - authority-form (CONNECT example.com:443) + - asterisk-form (OPTIONS *) + Details: http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-25#section-5.3 + + form_out: The request form which mitmproxy has send out to the destination + + method: HTTP method + + scheme: URL scheme (http/https) (absolute-form only) + + host: Host portion of the URL (absolute-form and authority-form only) + + port: Destination port (absolute-form and authority-form only) + + path: Path portion of the URL (not present in authority-form) + + httpversion: HTTP version tuple + + timestamp_start: Timestamp indicating when request transmission started + + timestamp_end: Timestamp indicating when request transmission ended + """ + def __init__(self, form_in, method, scheme, host, port, path, httpversion, headers, content, + timestamp_start=None, timestamp_end=None, form_out=None): + assert isinstance(headers, ODictCaseless) or not headers + HTTPMessage.__init__(self, httpversion, headers, content, timestamp_start, timestamp_end) + + self.form_in = form_in + self.method = method + self.scheme = scheme + self.host = host + self.port = port + self.path = path + self.httpversion = httpversion + self.form_out = form_out or form_in + + # Have this request's cookies been modified by sticky cookies or auth? + self.stickycookie = False + self.stickyauth = False + # Is this request replayed? + self.is_replay = False + + _stateobject_attributes = HTTPMessage._stateobject_attributes.copy() + _stateobject_attributes.update( + form_in=str, + method=str, + scheme=str, + host=str, + port=int, + path=str, + form_out=str + ) + + @classmethod + def _from_state(cls, state): + f = cls(None, None, None, None, None, None, None, None, None, None, None) + f._load_state(state) + return f + + @classmethod + def from_stream(cls, rfile, include_content=True, body_size_limit=None): + """ + Parse an HTTP request from a file stream + """ + httpversion, host, port, scheme, method, path, headers, content, timestamp_start, timestamp_end \ + = None, None, None, None, None, None, None, None, None, None + + if hasattr(rfile, "reset_timestamps"): + rfile.reset_timestamps() + + request_line = get_line(rfile) + + if hasattr(rfile, "first_byte_timestamp"): + timestamp_start = rfile.first_byte_timestamp + else: + timestamp_start = utils.timestamp() + + request_line_parts = http.parse_init(request_line) + if not request_line_parts: + raise http.HttpError(400, "Bad HTTP request line: %s" % repr(request_line)) + method, path, httpversion = request_line_parts + + if path == '*': + form_in = "asterisk" + elif path.startswith("/"): + form_in = "origin" + if not netlib.utils.isascii(path): + raise http.HttpError(400, "Bad HTTP request line: %s" % repr(request_line)) + elif method.upper() == 'CONNECT': + form_in = "authority" + r = http.parse_init_connect(request_line) + if not r: + raise http.HttpError(400, "Bad HTTP request line: %s" % repr(request_line)) + host, port, _ = r + path = None + else: + form_in = "absolute" + r = http.parse_init_proxy(request_line) + if not r: + raise http.HttpError(400, "Bad HTTP request line: %s" % repr(request_line)) + _, scheme, host, port, path, _ = r + + headers = http.read_headers(rfile) + if headers is None: + raise http.HttpError(400, "Invalid headers") + + if include_content: + content = http.read_http_body(rfile, headers, body_size_limit, True) + timestamp_end = utils.timestamp() + + return HTTPRequest(form_in, method, scheme, host, port, path, httpversion, headers, content, + timestamp_start, timestamp_end) + + def _assemble_first_line(self, form=None): + form = form or self.form_out + + if form == "asterisk" or \ + form == "origin": + request_line = '%s %s HTTP/%s.%s' % (self.method, self.path, self.httpversion[0], self.httpversion[1]) + elif form == "authority": + request_line = '%s %s:%s HTTP/%s.%s' % (self.method, self.host, self.port, + self.httpversion[0], self.httpversion[1]) + elif form == "absolute": + request_line = '%s %s://%s:%s%s HTTP/%s.%s' % \ + (self.method, self.scheme, self.host, self.port, self.path, + self.httpversion[0], self.httpversion[1]) + else: + raise http.HttpError(400, "Invalid request form") + return request_line + + def _assemble_headers(self): + headers = self.headers.copy() + utils.del_all( + headers, + [ + 'Proxy-Connection', + 'Keep-Alive', + 'Connection', + 'Transfer-Encoding' + ] + ) + if not 'host' in headers: + headers["Host"] = [utils.hostport(self.scheme, + self.host or self.flow.server_conn.address.host, + self.port or self.flow.server_conn.address.port)] + + if self.content: + headers["Content-Length"] = [str(len(self.content))] + elif 'Transfer-Encoding' in self.headers: # content-length for e.g. chuncked transfer-encoding with no content + headers["Content-Length"] = ["0"] + + return str(headers) + + def _assemble_head(self, form=None): + return "%s\r\n%s\r\n" % (self._assemble_first_line(form), self._assemble_headers()) + + def _assemble(self, form=None): + """ + Assembles the request for transmission to the server. We make some + modifications to make sure interception works properly. + + Raises an Exception if the request cannot be assembled. + """ + if self.content == CONTENT_MISSING: + raise ProxyError(502, "Cannot assemble flow with CONTENT_MISSING") + head = self._assemble_head(form) + if self.content: + return head + self.content + else: + return head + + def __hash__(self): + return id(self) + + def anticache(self): + """ + Modifies this request to remove headers that might produce a cached + response. That is, we remove ETags and If-Modified-Since headers. + """ + delheaders = [ + "if-modified-since", + "if-none-match", + ] + for i in delheaders: + del self.headers[i] + + def anticomp(self): + """ + Modifies this request to remove headers that will compress the + resource's data. + """ + self.headers["accept-encoding"] = ["identity"] + + def constrain_encoding(self): + """ + Limits the permissible Accept-Encoding values, based on what we can + decode appropriately. + """ + if self.headers["accept-encoding"]: + self.headers["accept-encoding"] = [', '.join( + e for e in encoding.ENCODINGS if e in self.headers["accept-encoding"][0] + )] + + def get_form_urlencoded(self): + """ + Retrieves the URL-encoded form data, returning an ODict object. + Returns an empty ODict if there is no data or the content-type + indicates non-form data. + """ + if self.content and self.headers.in_any("content-type", HDR_FORM_URLENCODED, True): + return ODict(utils.urldecode(self.content)) + return ODict([]) + + def set_form_urlencoded(self, odict): + """ + Sets the body to the URL-encoded form data, and adds the + appropriate content-type header. Note that this will destory the + existing body if there is one. + """ + # FIXME: If there's an existing content-type header indicating a + # url-encoded form, leave it alone. + self.headers["Content-Type"] = [HDR_FORM_URLENCODED] + self.content = utils.urlencode(odict.lst) + + def get_path_components(self): + """ + Returns the path components of the URL as a list of strings. + + Components are unquoted. + """ + _, _, path, _, _, _ = urlparse.urlparse(self.get_url()) + return [urllib.unquote(i) for i in path.split("/") if i] + + def set_path_components(self, lst): + """ + Takes a list of strings, and sets the path component of the URL. + + Components are quoted. + """ + lst = [urllib.quote(i, safe="") for i in lst] + path = "/" + "/".join(lst) + scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.get_url()) + self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment])) + + def get_query(self): + """ + Gets the request query string. Returns an ODict object. + """ + _, _, _, _, query, _ = urlparse.urlparse(self.get_url()) + if query: + return ODict(utils.urldecode(query)) + return ODict([]) + + def set_query(self, odict): + """ + Takes an ODict object, and sets the request query string. + """ + scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.get_url()) + query = utils.urlencode(odict.lst) + self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment])) + + def get_host(self, hostheader=False): + """ + Heuristic to get the host of the request. + The host is not necessarily equal to the TCP destination of the request, + for example on a transparently proxified absolute-form request to an upstream HTTP proxy. + If hostheader is set to True, the Host: header will be used as additional (and preferred) data source. + """ + host = None + if hostheader: + host = self.headers.get_first("host") + if not host: + if self.host: + host = self.host + else: + host = self.flow.server_conn.address.host + host = host.encode("idna") + return host + + def get_scheme(self): + """ + Returns the request port, either from the request itself or from the flow's server connection + """ + if self.scheme: + return self.scheme + return "https" if self.flow.server_conn.ssl_established else "http" + + def get_port(self): + """ + Returns the request port, either from the request itself or from the flow's server connection + """ + if self.port: + return self.port + return self.flow.server_conn.address.port + + def get_url(self, hostheader=False): + """ + Returns a URL string, constructed from the Request's URL components. + + If hostheader is True, we use the value specified in the request + Host header to construct the URL. + """ + return utils.unparse_url(self.get_scheme(), + self.get_host(hostheader), + self.get_port(), + self.path).encode('ascii') + + def set_url(self, url): + """ + Parses a URL specification, and updates the Request's information + accordingly. + + Returns False if the URL was invalid, True if the request succeeded. + """ + parts = http.parse_url(url) + if not parts: + return False + scheme, host, port, path = parts + is_ssl = (True if scheme == "https" else False) + + self.path = path + + if host != self.get_host() or port != self.get_port(): + if self.flow.change_server: + self.flow.change_server((host, port), ssl=is_ssl) + else: + # There's not live server connection, we're just changing the attributes here. + self.flow.server_conn = ServerConnection((host, port), AddressPriority.MANUALLY_CHANGED) + self.flow.server_conn.ssl_established = is_ssl + + # If this is an absolute request, replace the attributes on the request object as well. + if self.host: + self.host = host + if self.port: + self.port = port + if self.scheme: + self.scheme = scheme + + return True + + def get_cookies(self): + cookie_headers = self.headers.get("cookie") + if not cookie_headers: + return None + + cookies = [] + for header in cookie_headers: + pairs = [pair.partition("=") for pair in header.split(';')] + cookies.extend((pair[0], (pair[2], {})) for pair in pairs) + return dict(cookies) + + def replace(self, pattern, repl, *args, **kwargs): + """ + Replaces a regular expression pattern with repl in the headers, the request path + and the body of the request. Encoded content will be decoded before + replacement, and re-encoded afterwards. + + Returns the number of replacements made. + """ + c = HTTPMessage.replace(self, pattern, repl, *args, **kwargs) + self.path, pc = utils.safe_subn(pattern, repl, self.path, *args, **kwargs) + c += pc + return c + + +class HTTPResponse(HTTPMessage): + """ + An HTTP response. + + Exposes the following attributes: + + flow: Flow object the request belongs to + + code: HTTP response code + + msg: HTTP response message + + headers: ODict object + + content: Content of the request, None, or CONTENT_MISSING if there + is content associated, but not present. CONTENT_MISSING evaluates + to False to make checking for the presence of content natural. + + httpversion: HTTP version tuple + + timestamp_start: Timestamp indicating when request transmission started + + timestamp_end: Timestamp indicating when request transmission ended + """ + def __init__(self, httpversion, code, msg, headers, content, timestamp_start=None, timestamp_end=None): + assert isinstance(headers, ODictCaseless) or headers is None + HTTPMessage.__init__(self, httpversion, headers, content, timestamp_start, timestamp_end) + + self.code = code + self.msg = msg + + # Is this request replayed? + self.is_replay = False + + _stateobject_attributes = HTTPMessage._stateobject_attributes.copy() + _stateobject_attributes.update( + code=int, + msg=str + ) + + @classmethod + def _from_state(cls, state): + f = cls(None, None, None, None, None) + f._load_state(state) + return f + + @classmethod + def from_stream(cls, rfile, request_method, include_content=True, body_size_limit=None): + """ + Parse an HTTP response from a file stream + """ + if not include_content: + raise NotImplementedError # pragma: nocover + + if hasattr(rfile, "reset_timestamps"): + rfile.reset_timestamps() + + httpversion, code, msg, headers, content = http.read_response( + rfile, + request_method, + body_size_limit) + + if hasattr(rfile, "first_byte_timestamp"): + timestamp_start = rfile.first_byte_timestamp + else: + timestamp_start = utils.timestamp() + + timestamp_end = utils.timestamp() + return HTTPResponse(httpversion, code, msg, headers, content, timestamp_start, timestamp_end) + + def _assemble_first_line(self): + return 'HTTP/%s.%s %s %s' % (self.httpversion[0], self.httpversion[1], self.code, self.msg) + + def _assemble_headers(self): + headers = self.headers.copy() + utils.del_all( + headers, + [ + 'Proxy-Connection', + 'Transfer-Encoding' + ] + ) + if self.content: + headers["Content-Length"] = [str(len(self.content))] + elif 'Transfer-Encoding' in self.headers: # add content-length for chuncked transfer-encoding with no content + headers["Content-Length"] = ["0"] + + return str(headers) + + def _assemble_head(self): + return '%s\r\n%s\r\n' % (self._assemble_first_line(), self._assemble_headers()) + + def _assemble(self): + """ + Assembles the response for transmission to the client. We make some + modifications to make sure interception works properly. + + Raises an Exception if the request cannot be assembled. + """ + if self.content == CONTENT_MISSING: + raise ProxyError(502, "Cannot assemble flow with CONTENT_MISSING") + head = self._assemble_head() + if self.content: + return head + self.content + else: + return head + + def _refresh_cookie(self, c, delta): + """ + Takes a cookie string c and a time delta in seconds, and returns + a refreshed cookie string. + """ + c = Cookie.SimpleCookie(str(c)) + for i in c.values(): + if "expires" in i: + d = parsedate_tz(i["expires"]) + if d: + d = mktime_tz(d) + delta + i["expires"] = formatdate(d) + else: + # This can happen when the expires tag is invalid. + # reddit.com sends a an expires tag like this: "Thu, 31 Dec + # 2037 23:59:59 GMT", which is valid RFC 1123, but not + # strictly correct according to the cookie spec. Browsers + # appear to parse this tolerantly - maybe we should too. + # For now, we just ignore this. + del i["expires"] + return c.output(header="").strip() + + def refresh(self, now=None): + """ + This fairly complex and heuristic function refreshes a server + response for replay. + + - It adjusts date, expires and last-modified headers. + - It adjusts cookie expiration. + """ + if not now: + now = time.time() + delta = now - self.timestamp_start + refresh_headers = [ + "date", + "expires", + "last-modified", + ] + for i in refresh_headers: + if i in self.headers: + d = parsedate_tz(self.headers[i][0]) + if d: + new = mktime_tz(d) + delta + self.headers[i] = [formatdate(new)] + c = [] + for i in self.headers["set-cookie"]: + c.append(self._refresh_cookie(i, delta)) + if c: + self.headers["set-cookie"] = c + + def get_cookies(self): + cookie_headers = self.headers.get("set-cookie") + if not cookie_headers: + return None + + cookies = [] + for header in cookie_headers: + pairs = [pair.partition("=") for pair in header.split(';')] + cookie_name = pairs[0][0] # the key of the first key/value pairs + cookie_value = pairs[0][2] # the value of the first key/value pairs + cookie_parameters = {key.strip().lower(): value.strip() for key, sep, value in pairs[1:]} + cookies.append((cookie_name, (cookie_value, cookie_parameters))) + return dict(cookies) + + +class HTTPFlow(Flow): + """ + A Flow is a collection of objects representing a single HTTP + transaction. The main attributes are: + + request: HTTPRequest object + response: HTTPResponse object + error: Error object + + Note that it's possible for a Flow to have both a response and an error + object. This might happen, for instance, when a response was received + from the server, but there was an error sending it back to the client. + + The following additional attributes are exposed: + + intercepting: Is this flow currently being intercepted? + """ + def __init__(self, client_conn, server_conn, change_server=None): + Flow.__init__(self, "http", client_conn, server_conn) + self.request = None + """@type: HTTPRequest""" + self.response = None + """@type: HTTPResponse""" + self.change_server = change_server # Used by flow.request.set_url to change the server address + + self.intercepting = False # FIXME: Should that rather be an attribute of Flow? + + _backrefattr = Flow._backrefattr + ("request", "response") + + _stateobject_attributes = Flow._stateobject_attributes.copy() + _stateobject_attributes.update( + request=HTTPRequest, + response=HTTPResponse + ) + + @classmethod + def _from_state(cls, state): + f = cls(None, None) + f._load_state(state) + return f + + def copy(self): + f = super(HTTPFlow, self).copy() + if self.request: + f.request = self.request.copy() + if self.response: + f.response = self.response.copy() + return f + + def match(self, f): + """ + Match this flow against a compiled filter expression. Returns True + if matched, False if not. + + If f is a string, it will be compiled as a filter expression. If + the expression is invalid, ValueError is raised. + """ + if isinstance(f, basestring): + f = filt.parse(f) + if not f: + raise ValueError("Invalid filter expression.") + if f: + return f(self) + return True + + def kill(self, master): + """ + Kill this request. + """ + self.error = Error("Connection killed") + self.error.reply = controller.DummyReply() + if self.request and not self.request.reply.acked: + self.request.reply(KILL) + elif self.response and not self.response.reply.acked: + self.response.reply(KILL) + master.handle_error(self.error) + self.intercepting = False + + def intercept(self): + """ + Intercept this Flow. Processing will stop until accept_intercept is + called. + """ + self.intercepting = True + + def accept_intercept(self): + """ + Continue with the flow - called after an intercept(). + """ + if self.request: + if not self.request.reply.acked: + self.request.reply() + elif self.response and not self.response.reply.acked: + self.response.reply() + self.intercepting = False + + def replace(self, pattern, repl, *args, **kwargs): + """ + Replaces a regular expression pattern with repl in both request and response of the + flow. Encoded content will be decoded before replacement, and + re-encoded afterwards. + + Returns the number of replacements made. + """ + c = self.request.replace(pattern, repl, *args, **kwargs) + if self.response: + c += self.response.replace(pattern, repl, *args, **kwargs) + return c + + +class HttpAuthenticationError(Exception): + def __init__(self, auth_headers=None): + self.auth_headers = auth_headers + + def __str__(self): + return "HttpAuthenticationError" + + +class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin): + + def handle_messages(self): + while self.handle_flow(): + pass + self.c.close = True + + def get_response_from_server(self, request): + self.c.establish_server_connection() + request_raw = request._assemble() + + for i in range(2): + try: + self.c.server_conn.send(request_raw) + return HTTPResponse.from_stream(self.c.server_conn.rfile, request.method, + body_size_limit=self.c.config.body_size_limit) + except (tcp.NetLibDisconnect, http.HttpErrorConnClosed), v: + self.c.log("error in server communication: %s" % str(v)) + if i < 1: + # In any case, we try to reconnect at least once. + # This is necessary because it might be possible that we already initiated an upstream connection + # after clientconnect that has already been expired, e.g consider the following event log: + # > clientconnect (transparent mode destination known) + # > serverconnect + # > read n% of large request + # > server detects timeout, disconnects + # > read (100-n)% of large request + # > send large request upstream + self.c.server_reconnect() + else: + raise v + + def handle_flow(self): + flow = HTTPFlow(self.c.client_conn, self.c.server_conn, self.change_server) + try: + flow.request = HTTPRequest.from_stream(self.c.client_conn.rfile, + body_size_limit=self.c.config.body_size_limit) + self.c.log("request", [flow.request._assemble_first_line(flow.request.form_in)]) + self.process_request(flow.request) + + request_reply = self.c.channel.ask("request", flow.request) + flow.server_conn = self.c.server_conn + + if request_reply is None or request_reply == KILL: + return False + + if isinstance(request_reply, HTTPResponse): + flow.response = request_reply + else: + flow.response = self.get_response_from_server(flow.request) + + flow.server_conn = self.c.server_conn # no further manipulation of self.c.server_conn beyond this point + # we can safely set it as the final attribute value here. + + self.c.log("response", [flow.response._assemble_first_line()]) + response_reply = self.c.channel.ask("response", flow.response) + if response_reply is None or response_reply == KILL: + return False + + self.c.client_conn.send(flow.response._assemble()) + flow.timestamp_end = utils.timestamp() + + if (http.connection_close(flow.request.httpversion, flow.request.headers) or + http.connection_close(flow.response.httpversion, flow.response.headers)): + return False + + if flow.request.form_in == "authority": + self.ssl_upgrade() + + self.restore_server() # If the user has changed the target server on this connection, + # restore the original target server + return True + except (HttpAuthenticationError, http.HttpError, ProxyError, tcp.NetLibError), e: + self.handle_error(e, flow) + return False + + def handle_error(self, error, flow=None): + code, message, headers = None, None, None + if isinstance(error, HttpAuthenticationError): + code = 407 + message = "Proxy Authentication Required" + headers = error.auth_headers + elif isinstance(error, (http.HttpError, ProxyError)): + code = error.code + message = error.msg + elif isinstance(error, tcp.NetLibError): + code = 502 + message = error.message or error.__class__ + + if code: + err = "%s: %s" % (code, message) + else: + err = error.__class__ + + self.c.log("error: %s" % err) + + if flow: + flow.error = Error(err) + if flow.request and not flow.response: + # FIXME: no flows without request or with both request and response at the moement. + self.c.channel.ask("error", flow.error) + else: + pass # FIXME: Is there any use case for persisting errors that occur outside of flows? + + if code: + try: + self.send_error(code, message, headers) + except: + pass + + def send_error(self, code, message, headers): + response = http_status.RESPONSES.get(code, "Unknown") + html_content = '<html><head>\n<title>%d %s</title>\n</head>\n<body>\n%s\n</body>\n</html>' % \ + (code, response, message) + self.c.client_conn.wfile.write("HTTP/1.1 %s %s\r\n" % (code, response)) + self.c.client_conn.wfile.write("Server: %s\r\n" % self.c.server_version) + self.c.client_conn.wfile.write("Content-type: text/html\r\n") + self.c.client_conn.wfile.write("Content-Length: %d\r\n" % len(html_content)) + if headers: + for key, value in headers.items(): + self.c.client_conn.wfile.write("%s: %s\r\n" % (key, value)) + self.c.client_conn.wfile.write("Connection: close\r\n") + self.c.client_conn.wfile.write("\r\n") + self.c.client_conn.wfile.write(html_content) + self.c.client_conn.wfile.flush() + + def hook_reconnect(self, upstream_request): + self.c.log("Hook reconnect function") + original_reconnect_func = self.c.server_reconnect + + def reconnect_http_proxy(): + self.c.log("Hooked reconnect function") + self.c.log("Hook: Run original reconnect") + original_reconnect_func(no_ssl=True) + self.c.log("Hook: Write CONNECT request to upstream proxy", [upstream_request._assemble_first_line()]) + self.c.server_conn.send(upstream_request._assemble()) + self.c.log("Hook: Read answer to CONNECT request from proxy") + resp = HTTPResponse.from_stream(self.c.server_conn.rfile, upstream_request.method) + if resp.code != 200: + raise ProxyError(resp.code, + "Cannot reestablish SSL connection with upstream proxy: \r\n" + str(resp.headers)) + self.c.log("Hook: Establish SSL with upstream proxy") + self.c.establish_ssl(server=True) + + self.c.server_reconnect = reconnect_http_proxy + + def ssl_upgrade(self): + """ + Upgrade the connection to SSL after an authority (CONNECT) request has been made. + If the authority request has been forwarded upstream (because we have another proxy server there), + money-patch the ConnectionHandler.server_reconnect function to resend the request on reconnect. + + This isn't particular beautiful code, but it isolates this rare edge-case from the + protocol-agnostic ConnectionHandler + """ + self.c.log("Received CONNECT request. Upgrading to SSL...") + self.c.mode = "transparent" + self.c.determine_conntype() + self.c.establish_ssl(server=True, client=True) + self.c.log("Upgrade to SSL completed.") + raise ConnectionTypeChange + + def process_request(self, request): + if self.c.mode == "regular": + self.authenticate(request) + if request.form_in == "authority" and self.c.client_conn.ssl_established: + raise http.HttpError(502, "Must not CONNECT on already encrypted connection") + + # If we have a CONNECT request, we might need to intercept + if request.form_in == "authority": + directly_addressed_at_mitmproxy = (self.c.mode == "regular" and not self.c.config.forward_proxy) + if directly_addressed_at_mitmproxy: + self.c.set_server_address((request.host, request.port), AddressPriority.FROM_PROTOCOL) + request.flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow + self.c.client_conn.wfile.write( + 'HTTP/1.1 200 Connection established\r\n' + + ('Proxy-agent: %s\r\n' % self.c.server_version) + + '\r\n' + ) + self.c.client_conn.wfile.flush() + self.ssl_upgrade() # raises ConnectionTypeChange exception + + if self.c.mode == "regular": + if request.form_in == "authority": # forward mode + self.hook_reconnect(request) + elif request.form_in == "absolute": + if request.scheme != "http": + raise http.HttpError(400, "Invalid Request") + if not self.c.config.forward_proxy: + request.form_out = "origin" + self.c.set_server_address((request.host, request.port), AddressPriority.FROM_PROTOCOL) + request.flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow + else: + raise http.HttpError(400, "Invalid request form (absolute-form or authority-form required)") + + def authenticate(self, request): + if self.c.config.authenticator: + if self.c.config.authenticator.authenticate(request.headers): + self.c.config.authenticator.clean(request.headers) + else: + raise HttpAuthenticationError(self.c.config.authenticator.auth_challenge_headers()) + return request.headers
\ No newline at end of file diff --git a/libmproxy/protocol/primitives.py b/libmproxy/protocol/primitives.py new file mode 100644 index 00000000..90191eeb --- /dev/null +++ b/libmproxy/protocol/primitives.py @@ -0,0 +1,130 @@ +from .. import stateobject, utils, version +from ..proxy import ServerConnection, ClientConnection +import copy + + +class BackreferenceMixin(object): + """ + If an attribute from the _backrefattr tuple is set, + this mixin sets a reference back on the attribute object. + Example: + e = Error() + f = Flow() + f.error = e + assert f is e.flow + """ + _backrefattr = tuple() + + def __setattr__(self, key, value): + super(BackreferenceMixin, self).__setattr__(key, value) + if key in self._backrefattr and value is not None: + setattr(value, self._backrefname, self) + + +class Error(stateobject.SimpleStateObject): + """ + An Error. + + This is distinct from an HTTP error response (say, a code 500), which + is represented by a normal Response object. This class is responsible + for indicating errors that fall outside of normal HTTP communications, + like interrupted connections, timeouts, protocol errors. + + Exposes the following attributes: + + flow: Flow object + msg: Message describing the error + timestamp: Seconds since the epoch + """ + def __init__(self, msg, timestamp=None): + """ + @type msg: str + @type timestamp: float + """ + self.flow = None # will usually be set by the flow backref mixin + self.msg = msg + self.timestamp = timestamp or utils.timestamp() + + _stateobject_attributes = dict( + msg=str, + timestamp=float + ) + + def __str__(self): + return self.msg + + @classmethod + def _from_state(cls, state): + f = cls(None) # the default implementation assumes an empty constructor. Override accordingly. + f._load_state(state) + return f + + def copy(self): + c = copy.copy(self) + return c + + +class Flow(stateobject.SimpleStateObject, BackreferenceMixin): + def __init__(self, conntype, client_conn, server_conn): + self.conntype = conntype + self.client_conn = client_conn + """@type: ClientConnection""" + self.server_conn = server_conn + """@type: ServerConnection""" + + self.error = None + """@type: Error""" + self._backup = None + + _backrefattr = ("error",) + _backrefname = "flow" + + _stateobject_attributes = dict( + error=Error, + client_conn=ClientConnection, + server_conn=ServerConnection, + conntype=str + ) + + def _get_state(self): + d = super(Flow, self)._get_state() + d.update(version=version.IVERSION) + return d + + def __eq__(self, other): + return self is other + + def copy(self): + f = copy.copy(self) + + f.client_conn = self.client_conn.copy() + f.server_conn = self.server_conn.copy() + + if self.error: + f.error = self.error.copy() + return f + + def modified(self): + """ + Has this Flow been modified? + """ + if self._backup: + return self._backup != self._get_state() + else: + return False + + def backup(self, force=False): + """ + Save a backup of this Flow, which can be reverted to using a + call to .revert(). + """ + if not self._backup: + self._backup = self._get_state() + + def revert(self): + """ + Revert to the last backed up state. + """ + if self._backup: + self._load_state(self._backup) + self._backup = None
\ No newline at end of file diff --git a/libmproxy/protocol/tcp.py b/libmproxy/protocol/tcp.py new file mode 100644 index 00000000..406a6f7b --- /dev/null +++ b/libmproxy/protocol/tcp.py @@ -0,0 +1,59 @@ +from . import ProtocolHandler +import select, socket +from cStringIO import StringIO + + +class TCPHandler(ProtocolHandler): + """ + TCPHandler acts as a generic TCP forwarder. + Data will be .log()ed, but not stored any further. + """ + def handle_messages(self): + conns = [self.c.client_conn.rfile, self.c.server_conn.rfile] + while not self.c.close: + r, _, _ = select.select(conns, [], [], 10) + for rfile in r: + if self.c.client_conn.rfile == rfile: + src, dst = self.c.client_conn, self.c.server_conn + direction = "-> tcp ->" + dst_str = "%s:%s" % self.c.server_conn.address()[:2] + else: + dst, src = self.c.client_conn, self.c.server_conn + direction = "<- tcp <-" + dst_str = "client" + + data = StringIO() + while range(4096): + # Do non-blocking select() to see if there is further data on in the buffer. + r, _, _ = select.select([rfile], [], [], 0) + if len(r): + d = rfile.read(1) + if d == "": # connection closed + break + data.write(d) + + """ + OpenSSL Connections have an internal buffer that might contain data altough everything is read + from the socket. Thankfully, connection.pending() returns the amount of bytes in this buffer, + so we can read it completely at once. + """ + if src.ssl_established: + data.write(rfile.read(src.connection.pending())) + else: # no data left, but not closed yet + break + data = data.getvalue() + + if data == "": # no data received, rfile is closed + self.c.log("Close writing connection to %s" % dst_str) + conns.remove(rfile) + if dst.ssl_established: + dst.connection.shutdown() + else: + dst.connection.shutdown(socket.SHUT_WR) + if len(conns) == 0: + self.c.close = True + break + + self.c.log("%s %s\r\n%s" % (direction, dst_str,data)) + dst.wfile.write(data) + dst.wfile.flush()
\ No newline at end of file diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 0d53aef8..b6480822 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -1,13 +1,26 @@ -import os, socket, time -import threading +import os, socket, time, threading, copy from OpenSSL import SSL -from netlib import tcp, http, certutils, http_status, http_auth -import utils, flow, version, platform, controller - +from netlib import tcp, http, certutils, http_auth +import utils, version, platform, controller, stateobject TRANSPARENT_SSL_PORTS = [443, 8443] -KILL = 0 + +class AddressPriority(object): + """ + Enum that signifies the priority of the given address when choosing the destination host. + Higher is better (None < i) + """ + FORCE = 5 + """forward mode""" + MANUALLY_CHANGED = 4 + """user changed the target address in the ui""" + FROM_SETTINGS = 3 + """reverse proxy mode""" + FROM_CONNECTION = 2 + """derived from transparent resolver""" + FROM_PROTOCOL = 1 + """derived from protocol (e.g. absolute-form http requests)""" class ProxyError(Exception): @@ -15,7 +28,7 @@ class ProxyError(Exception): self.code, self.msg, self.headers = code, msg, headers def __str__(self): - return "ProxyError(%s, %s)"%(self.code, self.msg) + return "ProxyError(%s, %s)" % (self.code, self.msg) class Log: @@ -24,7 +37,8 @@ class Log: class ProxyConfig: - def __init__(self, certfile = None, cacert = None, clientcerts = None, no_upstream_cert=False, body_size_limit = None, reverse_proxy=None, forward_proxy=None, transparent_proxy=None, authenticator=None): + def __init__(self, certfile=None, cacert=None, clientcerts=None, no_upstream_cert=False, body_size_limit=None, + reverse_proxy=None, forward_proxy=None, transparent_proxy=None, authenticator=None): self.certfile = certfile self.cacert = cacert self.clientcerts = clientcerts @@ -37,49 +51,146 @@ class ProxyConfig: self.certstore = certutils.CertStore() -class ServerConnection(tcp.TCPClient): - def __init__(self, config, scheme, host, port, sni): - tcp.TCPClient.__init__(self, host, port) - self.config = config - self.scheme, self.sni = scheme, sni - self.requestcount = 0 - self.tcp_setup_timestamp = None - self.ssl_setup_timestamp = None +class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject): + def __init__(self, client_connection, address, server): + if client_connection: # Eventually, this object is restored from state. We don't have a connection then. + tcp.BaseHandler.__init__(self, client_connection, address, server) + else: + self.connection = None + self.server = None + self.wfile = None + self.rfile = None + self.address = None + self.clientcert = None + + self.timestamp_start = utils.timestamp() + self.timestamp_end = None + self.timestamp_ssl_setup = None + + _stateobject_attributes = dict( + timestamp_start=float, + timestamp_end=float, + timestamp_ssl_setup=float + ) + + def _get_state(self): + d = super(ClientConnection, self)._get_state() + d.update( + address={"address": self.address(), "use_ipv6": self.address.use_ipv6}, + clientcert=self.cert.to_pem() if self.clientcert else None + ) + return d + + def _load_state(self, state): + super(ClientConnection, self)._load_state(state) + self.address = tcp.Address(**state["address"]) if state["address"] else None + self.clientcert = certutils.SSLCert.from_pem(state["clientcert"]) if state["clientcert"] else None + + def copy(self): + return copy.copy(self) + + def send(self, message): + self.wfile.write(message) + self.wfile.flush() + + @classmethod + def _from_state(cls, state): + f = cls(None, tuple(), None) + f._load_state(state) + return f + + def convert_to_ssl(self, *args, **kwargs): + tcp.BaseHandler.convert_to_ssl(self, *args, **kwargs) + self.timestamp_ssl_setup = utils.timestamp() + + def finish(self): + tcp.BaseHandler.finish(self) + self.timestamp_end = utils.timestamp() + + +class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject): + def __init__(self, address, priority): + tcp.TCPClient.__init__(self, address) + self.priority = priority + + self.peername = None + self.timestamp_start = None + self.timestamp_end = None + self.timestamp_tcp_setup = None + self.timestamp_ssl_setup = None + + _stateobject_attributes = dict( + peername=tuple, + timestamp_start=float, + timestamp_end=float, + timestamp_tcp_setup=float, + timestamp_ssl_setup=float, + address=tcp.Address, + source_address=tcp.Address, + cert=certutils.SSLCert, + ssl_established=bool, + sni=str + ) + + def _get_state(self): + d = super(ServerConnection, self)._get_state() + d.update( + address={"address": self.address(), "use_ipv6": self.address.use_ipv6}, + source_address= {"address": self.source_address(), + "use_ipv6": self.source_address.use_ipv6} if self.source_address else None, + cert=self.cert.to_pem() if self.cert else None + ) + return d + + def _load_state(self, state): + super(ServerConnection, self)._load_state(state) + + self.address = tcp.Address(**state["address"]) if state["address"] else None + self.source_address = tcp.Address(**state["source_address"]) if state["source_address"] else None + self.cert = certutils.SSLCert.from_pem(state["cert"]) if state["cert"] else None + + @classmethod + def _from_state(cls, state): + f = cls(tuple(), None) + f._load_state(state) + return f + + def copy(self): + return copy.copy(self) def connect(self): + self.timestamp_start = utils.timestamp() tcp.TCPClient.connect(self) - self.tcp_setup_timestamp = time.time() - if self.scheme == "https": - clientcert = None - if self.config.clientcerts: - path = os.path.join(self.config.clientcerts, self.host.encode("idna")) + ".pem" - if os.path.exists(path): - clientcert = path - try: - self.convert_to_ssl(cert=clientcert, sni=self.sni) - self.ssl_setup_timestamp = time.time() - except tcp.NetLibError, v: - raise ProxyError(400, str(v)) - - def send(self, request): - self.requestcount += 1 - d = request._assemble() - if not d: - raise ProxyError(502, "Cannot transmit an incomplete request.") - self.wfile.write(d) + self.peername = self.connection.getpeername() + self.timestamp_tcp_setup = utils.timestamp() + + def send(self, message): + self.wfile.write(message) self.wfile.flush() - def terminate(self): - if self.connection: - try: - self.wfile.flush() - except tcp.NetLibDisconnect: # pragma: no cover - pass - self.connection.close() + def establish_ssl(self, clientcerts, sni): + clientcert = None + if clientcerts: + path = os.path.join(clientcerts, self.address.host.encode("idna")) + ".pem" + if os.path.exists(path): + clientcert = path + try: + self.convert_to_ssl(cert=clientcert, sni=sni) + self.timestamp_ssl_setup = utils.timestamp() + except tcp.NetLibError, v: + raise ProxyError(400, str(v)) + + def finish(self): + tcp.TCPClient.finish(self) + self.timestamp_end = utils.timestamp() +from . import protocol +from .protocol.http import HTTPResponse class RequestReplayThread(threading.Thread): + name="RequestReplayThread" + def __init__(self, config, flow, masterq): self.config, self.flow, self.channel = config, flow, controller.Channel(masterq) threading.Thread.__init__(self) @@ -87,448 +198,275 @@ class RequestReplayThread(threading.Thread): def run(self): try: r = self.flow.request - server = ServerConnection(self.config, r.scheme, r.host, r.port, r.host) + server = ServerConnection(self.flow.server_conn.address(), None) server.connect() - server.send(r) - httpversion, code, msg, headers, content = http.read_response( - server.rfile, r.method, self.config.body_size_limit - ) - response = flow.Response( - self.flow.request, httpversion, code, msg, headers, content, server.cert, - server.rfile.first_byte_timestamp - ) - self.channel.ask("response", response) + if self.flow.server_conn.ssl_established: + server.establish_ssl(self.config.clientcerts, + self.flow.server_conn.sni) + server.send(r._assemble()) + self.flow.response = HTTPResponse.from_stream(server.rfile, r.method, body_size_limit=self.config.body_size_limit) + self.channel.ask("response", self.flow.response) except (ProxyError, http.HttpError, tcp.NetLibError), v: - err = flow.Error(self.flow.request, str(v)) - self.channel.ask("error", err) + self.flow.error = protocol.primitives.Error(str(v)) + self.channel.ask("error", self.flow.error) + + +class ConnectionHandler: + def __init__(self, config, client_connection, client_address, server, channel, server_version): + self.config = config + self.client_conn = ClientConnection(client_connection, client_address, server) + self.server_conn = None + self.channel, self.server_version = channel, server_version + + self.close = False + self.conntype = None + self.sni = None + + self.mode = "regular" + if self.config.reverse_proxy: + self.mode = "reverse" + if self.config.transparent_proxy: + self.mode = "transparent" + def handle(self): + self.log("clientconnect") + self.channel.ask("clientconnect", self) -class HandleSNI: - def __init__(self, handler, client_conn, host, port, key): - self.handler, self.client_conn, self.host, self.port = handler, client_conn, host, port - self.key = key + self.determine_conntype() - def __call__(self, client_connection): try: - sn = client_connection.get_servername() - if sn: - self.handler.get_server_connection(self.client_conn, "https", self.host, self.port, sn) - dummycert = self.handler.find_cert(self.client_conn, self.host, self.port, sn) - new_context = SSL.Context(SSL.TLSv1_METHOD) - new_context.use_privatekey_file(self.key) - new_context.use_certificate(dummycert.x509) - client_connection.set_context(new_context) - self.handler.sni = sn.decode("utf8").encode("idna") - # An unhandled exception in this method will core dump PyOpenSSL, so - # make dang sure it doesn't happen. - except Exception: # pragma: no cover - pass + try: + # Can we already identify the target server and connect to it? + server_address = None + address_priority = None + if self.config.forward_proxy: + server_address = self.config.forward_proxy[1:] + address_priority = AddressPriority.FORCE + elif self.config.reverse_proxy: + server_address = self.config.reverse_proxy[1:] + address_priority = AddressPriority.FROM_SETTINGS + elif self.config.transparent_proxy: + server_address = self.config.transparent_proxy["resolver"].original_addr( + self.client_conn.connection) + if not server_address: + raise ProxyError(502, "Transparent mode failure: could not resolve original destination.") + address_priority = AddressPriority.FROM_CONNECTION + self.log("transparent to %s:%s" % server_address) + + if server_address: + self.set_server_address(server_address, address_priority) + self._handle_ssl() + + while not self.close: + try: + protocol.handle_messages(self.conntype, self) + except protocol.ConnectionTypeChange: + self.log("Connection Type Changed: %s" % self.conntype) + continue + + # FIXME: Do we want to persist errors? + except (ProxyError, tcp.NetLibError), e: + protocol.handle_error(self.conntype, self, e) + except Exception, e: + self.log(e.__class__) + import traceback + self.log(traceback.format_exc()) + self.log(str(e)) + self.del_server_connection() + self.log("clientdisconnect") + self.channel.tell("clientdisconnect", self) -class ProxyHandler(tcp.BaseHandler): - def __init__(self, config, connection, client_address, server, channel, server_version): - self.channel, self.server_version = channel, server_version - self.config = config - self.proxy_connect_state = None - self.sni = None - self.server_conn = None - tcp.BaseHandler.__init__(self, connection, client_address, server) + def _handle_ssl(self): + """ + Helper function of .handle() + Check if we can already identify SSL connections. + If so, connect to the server and establish an SSL connection + """ + client_ssl = False + server_ssl = False - def get_server_connection(self, cc, scheme, host, port, sni, request=None): + if self.config.transparent_proxy: + client_ssl = server_ssl = (self.server_conn.address.port in self.config.transparent_proxy["sslports"]) + elif self.config.reverse_proxy: + client_ssl = server_ssl = (self.config.reverse_proxy[0] == "https") + # TODO: Make protocol generic (as with transparent proxies) + # TODO: Add SSL-terminating capatbility (SSL -> mitmproxy -> plain and vice versa) + if client_ssl or server_ssl: + self.establish_server_connection() + self.establish_ssl(client=client_ssl, server=server_ssl) + + def del_server_connection(self): """ - When SNI is in play, this means we have an SSL-encrypted - connection, which means that the entire handler is dedicated to a - single server connection - no multiplexing. If this assumption ever - breaks, we'll have to do something different with the SNI host - variable on the handler object. - - `conn_info` holds the initial connection's parameters, as the - hook might change them. Also, the hook might require an initial - request to figure out connection settings; in this case it can - set require_request, which will cause the connection to be - re-opened after the client's request arrives. + Deletes an existing server connection. """ - sc = self.server_conn - if not sni: - sni = host - conn_info = (scheme, host, port, sni) - if sc and (conn_info != sc.conn_info or (request and sc.require_request)): - sc.terminate() - self.server_conn = None - self.log( - cc, - "switching connection", [ - "%s://%s:%s (sni=%s) -> %s://%s:%s (sni=%s)"%( - scheme, host, port, sni, - sc.scheme, sc.host, sc.port, sc.sni - ) - ] - ) - if not self.server_conn: - try: - self.server_conn = ServerConnection(self.config, scheme, host, port, sni) + if self.server_conn and self.server_conn.connection: + self.server_conn.finish() + self.log("serverdisconnect", ["%s:%s" % (self.server_conn.address.host, self.server_conn.address.port)]) + self.channel.tell("serverdisconnect", self) + self.server_conn = None + self.sni = None - # Additional attributes, used if the server_connect hook - # needs to change parameters - self.server_conn.request = request - self.server_conn.require_request = False + def determine_conntype(self): + #TODO: Add ruleset to select correct protocol depending on mode/target port etc. + self.conntype = "http" - self.server_conn.conn_info = conn_info - self.channel.ask("serverconnect", self.server_conn) - self.server_conn.connect() - except tcp.NetLibError, v: - raise ProxyError(502, v) - return self.server_conn + def set_server_address(self, address, priority): + """ + Sets a new server address with the given priority. + Does not re-establish either connection or SSL handshake. + @type priority: AddressPriority + """ + address = tcp.Address.wrap(address) - def del_server_connection(self): if self.server_conn: - self.server_conn.terminate() - self.server_conn = None + if self.server_conn.priority > priority: + self.log("Attempt to change server address, " + "but priority is too low (is: %s, got: %s)" % (self.server_conn.priority, priority)) + return + if self.server_conn.address == address: + self.server_conn.priority = priority # Possibly increase priority + return - def handle(self): - cc = flow.ClientConnect(self.client_address) - self.log(cc, "connect") - self.channel.ask("clientconnect", cc) - while self.handle_request(cc) and not cc.close: - pass - cc.close = True - self.del_server_connection() + self.del_server_connection() - cd = flow.ClientDisconnect(cc) - self.log( - cc, "disconnect", - [ - "handled %s requests"%cc.requestcount] - ) - self.channel.tell("clientdisconnect", cd) + self.log("Set new server address: %s:%s" % (address.host, address.port)) + self.server_conn = ServerConnection(address, priority) - def handle_request(self, cc): + def establish_server_connection(self): + """ + Establishes a new server connection. + If there is already an existing server connection, the function returns immediately. + """ + if self.server_conn.connection: + return + self.log("serverconnect", ["%s:%s" % self.server_conn.address()[:2]]) + self.channel.tell("serverconnect", self) try: - request, err = None, None - request = self.read_request(cc) - if request is None: - return - cc.requestcount += 1 + self.server_conn.connect() + except tcp.NetLibError, v: + raise ProxyError(502, v) - request_reply = self.channel.ask("request", request) - if request_reply is None or request_reply == KILL: - return - elif isinstance(request_reply, flow.Response): - request = False - response = request_reply - response_reply = self.channel.ask("response", response) - else: - request = request_reply - if self.config.reverse_proxy: - scheme, host, port = self.config.reverse_proxy - elif self.config.forward_proxy: - scheme, host, port = self.config.forward_proxy - else: - scheme, host, port = request.scheme, request.host, request.port - - # If we've already pumped a request over this connection, - # it's possible that the server has timed out. If this is - # the case, we want to reconnect without sending an error - # to the client. - while 1: - sc = self.get_server_connection(cc, scheme, host, port, self.sni, request=request) - sc.send(request) - if sc.requestcount == 1: # add timestamps only for first request (others are not directly affected) - request.tcp_setup_timestamp = sc.tcp_setup_timestamp - request.ssl_setup_timestamp = sc.ssl_setup_timestamp - sc.rfile.reset_timestamps() - try: - peername = sc.connection.getpeername() - if peername: - request.ip = peername[0] - httpversion, code, msg, headers, content = http.read_response( - sc.rfile, - request.method, - self.config.body_size_limit - ) - except http.HttpErrorConnClosed: - self.del_server_connection() - if sc.requestcount > 1: - continue - else: - raise - except http.HttpError: - raise ProxyError(502, "Invalid server response.") - else: - break - - response = flow.Response( - request, httpversion, code, msg, headers, content, sc.cert, - sc.rfile.first_byte_timestamp - ) - response_reply = self.channel.ask("response", response) - # Not replying to the server invalidates the server - # connection, so we terminate. - if response_reply == KILL: - sc.terminate() - - if response_reply == KILL: - return - else: - response = response_reply - self.send_response(response) - if request and http.connection_close(request.httpversion, request.headers): - return - # We could keep the client connection when the server - # connection needs to go away. However, we want to mimic - # behaviour as closely as possible to the client, so we - # disconnect. - if http.connection_close(response.httpversion, response.headers): - return - except (IOError, ProxyError, http.HttpError, tcp.NetLibError), e: - if hasattr(e, "code"): - cc.error = "%s: %s"%(e.code, e.msg) - else: - cc.error = str(e) - - if request: - err = flow.Error(request, cc.error) - self.channel.ask("error", err) - self.log( - cc, cc.error, - ["url: %s"%request.get_url()] - ) - else: - self.log(cc, cc.error) - if isinstance(e, ProxyError): - self.send_error(e.code, e.msg, e.headers) - else: - return True + def establish_ssl(self, client=False, server=False): + """ + Establishes SSL on the existing connection(s) to the server or the client, + as specified by the parameters. If the target server is on the pass-through list, + the conntype attribute will be changed and the SSL connection won't be wrapped. + A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening + """ + # TODO: Implement SSL pass-through handling and change conntype + passthrough = [ + # "echo.websocket.org", + # "174.129.224.73" # echo.websocket.org, transparent mode + ] + if self.server_conn.address.host in passthrough or self.sni in passthrough: + self.conntype = "tcp" + return + + # Logging + if client or server: + subs = [] + if client: + subs.append("with client") + if server: + subs.append("with server (sni: %s)" % self.sni) + self.log("Establish SSL", subs) + + if server: + if self.server_conn.ssl_established: + raise ProxyError(502, "SSL to Server already established.") + self.establish_server_connection() # make sure there is a server connection. + self.server_conn.establish_ssl(self.config.clientcerts, self.sni) + if client: + if self.client_conn.ssl_established: + raise ProxyError(502, "SSL to Client already established.") + dummycert = self.find_cert() + self.client_conn.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, + handle_sni=self.handle_sni) + + def server_reconnect(self, no_ssl=False): + address = self.server_conn.address + had_ssl = self.server_conn.ssl_established + priority = self.server_conn.priority + sni = self.sni + self.log("(server reconnect follows)") + self.del_server_connection() + self.set_server_address(address, priority) + self.establish_server_connection() + if had_ssl and not no_ssl: + self.sni = sni + self.establish_ssl(server=True) + + def finish(self): + self.client_conn.finish() - def log(self, cc, msg, subs=()): + def log(self, msg, subs=()): msg = [ - "%s:%s: "%cc.address + msg + "%s:%s: %s" % (self.client_conn.address.host, self.client_conn.address.port, msg) ] for i in subs: - msg.append(" -> "+i) + msg.append(" -> " + i) msg = "\n".join(msg) - l = Log(msg) - self.channel.tell("log", l) + self.channel.tell("log", Log(msg)) - def find_cert(self, cc, host, port, sni): + def find_cert(self): if self.config.certfile: with open(self.config.certfile, "rb") as f: return certutils.SSLCert.from_pem(f.read()) else: + host = self.server_conn.address.host sans = [] - if not self.config.no_upstream_cert: - conn = self.get_server_connection(cc, "https", host, port, sni) - sans = conn.cert.altnames - if conn.cert.cn: - host = conn.cert.cn.decode("utf8").encode("idna") + if not self.config.no_upstream_cert or not self.server_conn.ssl_established: + upstream_cert = self.server_conn.cert + if upstream_cert.cn: + host = upstream_cert.cn.decode("utf8").encode("idna") + sans = upstream_cert.altnames + ret = self.config.certstore.get_cert(host, sans, self.config.cacert) if not ret: raise ProxyError(502, "Unable to generate dummy cert.") return ret - def establish_ssl(self, client_conn, host, port): - dummycert = self.find_cert(client_conn, host, port, host) - sni = HandleSNI( - self, client_conn, host, port, self.config.certfile or self.config.cacert - ) - try: - self.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, handle_sni=sni) - except tcp.NetLibError, v: - raise ProxyError(400, str(v)) - - def get_line(self, fp): + def handle_sni(self, connection): """ - Get a line, possibly preceded by a blank. + This callback gets called during the SSL handshake with the client. + The client has just sent the Sever Name Indication (SNI). We now connect upstream to + figure out which certificate needs to be served. """ - line = fp.readline() - if line == "\r\n" or line == "\n": # Possible leftover from previous message - line = fp.readline() - return line - - def read_request(self, client_conn): - self.rfile.reset_timestamps() - if self.config.transparent_proxy: - return self.read_request_transparent(client_conn) - elif self.config.reverse_proxy: - return self.read_request_reverse(client_conn) - else: - return self.read_request_proxy(client_conn) - - def read_request_transparent(self, client_conn): - orig = self.config.transparent_proxy["resolver"].original_addr(self.connection) - if not orig: - raise ProxyError(502, "Transparent mode failure: could not resolve original destination.") - self.log(client_conn, "transparent to %s:%s"%orig) - - host, port = orig - if port in self.config.transparent_proxy["sslports"]: - scheme = "https" - else: - scheme = "http" - - return self._read_request_origin_form(client_conn, scheme, host, port) - - def read_request_reverse(self, client_conn): - scheme, host, port = self.config.reverse_proxy - return self._read_request_origin_form(client_conn, scheme, host, port) - - def read_request_proxy(self, client_conn): - # Check for a CONNECT command. - if not self.proxy_connect_state: - line = self.get_line(self.rfile) - if line == "": - return None - self.proxy_connect_state = self._read_request_authority_form(line) - - # Check for an actual request - if self.proxy_connect_state: - host, port, _ = self.proxy_connect_state - return self._read_request_origin_form(client_conn, "https", host, port) - else: - # noinspection PyUnboundLocalVariable - return self._read_request_absolute_form(client_conn, line) - - def _read_request_authority_form(self, line): - """ - The authority-form of request-target is only used for CONNECT requests. - The CONNECT method is used to request a tunnel to the destination server. - This function sends a "200 Connection established" response to the client - and returns the host information that can be used to process further requests in origin-form. - An example authority-form request line would be: - CONNECT www.example.com:80 HTTP/1.1 - """ - connparts = http.parse_init_connect(line) - if connparts: - self.read_headers(authenticate=True) - # respond according to http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.2 - self.wfile.write( - 'HTTP/1.1 200 Connection established\r\n' + - ('Proxy-agent: %s\r\n'%self.server_version) + - '\r\n' - ) - self.wfile.flush() - return connparts - - def _read_request_absolute_form(self, client_conn, line): - """ - When making a request to a proxy (other than CONNECT or OPTIONS), - a client must send the target uri in absolute-form. - An example absolute-form request line would be: - GET http://www.example.com/foo.html HTTP/1.1 - """ - r = http.parse_init_proxy(line) - if not r: - raise ProxyError(400, "Bad HTTP request line: %s"%repr(line)) - method, scheme, host, port, path, httpversion = r - headers = self.read_headers(authenticate=True) - self.handle_expect_header(headers, httpversion) - content = http.read_http_body( - self.rfile, headers, self.config.body_size_limit, True - ) - r = flow.Request( - client_conn, httpversion, host, port, scheme, method, path, headers, content, - self.rfile.first_byte_timestamp, utils.timestamp() - ) - r.set_live(self.rfile, self.wfile) - return r - - def _read_request_origin_form(self, client_conn, scheme, host, port): - """ - Read a HTTP request with regular (origin-form) request line. - An example origin-form request line would be: - GET /foo.html HTTP/1.1 - - The request destination is already known from one of the following sources: - 1) transparent proxy: destination provided by platform resolver - 2) reverse proxy: fixed destination - 3) regular proxy: known from CONNECT command. - """ - if scheme.lower() == "https" and not self.ssl_established: - self.establish_ssl(client_conn, host, port) - - line = self.get_line(self.rfile) - if line == "": - return None - - r = http.parse_init_http(line) - if not r: - raise ProxyError(400, "Bad HTTP request line: %s"%repr(line)) - method, path, httpversion = r - headers = self.read_headers(authenticate=False) - self.handle_expect_header(headers, httpversion) - content = http.read_http_body( - self.rfile, headers, self.config.body_size_limit, True - ) - r = flow.Request( - client_conn, httpversion, host, port, scheme, method, path, headers, content, - self.rfile.first_byte_timestamp, utils.timestamp() - ) - r.set_live(self.rfile, self.wfile) - return r - - def handle_expect_header(self, headers, httpversion): - if "expect" in headers: - if "100-continue" in headers['expect'] and httpversion >= (1, 1): - #FIXME: Check if content-length is over limit - self.wfile.write('HTTP/1.1 100 Continue\r\n' - '\r\n') - del headers['expect'] - - def read_headers(self, authenticate=False): - headers = http.read_headers(self.rfile) - if headers is None: - raise ProxyError(400, "Invalid headers") - if authenticate and self.config.authenticator: - if self.config.authenticator.authenticate(headers): - self.config.authenticator.clean(headers) - else: - raise ProxyError( - 407, - "Proxy Authentication Required", - self.config.authenticator.auth_challenge_headers() - ) - return headers - - def send_response(self, response): - d = response._assemble() - if not d: - raise ProxyError(502, "Cannot transmit an incomplete response.") - self.wfile.write(d) - self.wfile.flush() - - def send_error(self, code, body, headers): try: - response = http_status.RESPONSES.get(code, "Unknown") - html_content = '<html><head>\n<title>%d %s</title>\n</head>\n<body>\n%s\n</body>\n</html>'%(code, response, body) - self.wfile.write("HTTP/1.1 %s %s\r\n" % (code, response)) - self.wfile.write("Server: %s\r\n"%self.server_version) - self.wfile.write("Content-type: text/html\r\n") - self.wfile.write("Content-Length: %d\r\n"%len(html_content)) - if headers: - for key, value in headers.items(): - self.wfile.write("%s: %s\r\n"%(key, value)) - self.wfile.write("Connection: close\r\n") - self.wfile.write("\r\n") - self.wfile.write(html_content) - self.wfile.flush() - except: + sn = connection.get_servername() + if sn and sn != self.sni: + self.sni = sn.decode("utf8").encode("idna") + self.log("SNI received: %s" % self.sni) + self.server_reconnect() # reconnect to upstream server with SNI + # Now, change client context to reflect changed certificate: + new_context = SSL.Context(SSL.TLSv1_METHOD) + new_context.use_privatekey_file(self.config.certfile or self.config.cacert) + dummycert = self.find_cert() + new_context.use_certificate(dummycert.x509) + connection.set_context(new_context) + # An unhandled exception in this method will core dump PyOpenSSL, so + # make dang sure it doesn't happen. + except Exception, e: # pragma: no cover pass -class ProxyServerError(Exception): pass +class ProxyServerError(Exception): + pass class ProxyServer(tcp.TCPServer): allow_reuse_address = True bound = True - def __init__(self, config, port, address='', server_version=version.NAMEVERSION): + + def __init__(self, config, port, host='', server_version=version.NAMEVERSION): """ Raises ProxyServerError if there's a startup problem. """ - self.config, self.port, self.address = config, port, address + self.config = config self.server_version = server_version try: - tcp.TCPServer.__init__(self, (address, port)) + tcp.TCPServer.__init__(self, (host, port)) except socket.error, v: raise ProxyServerError('Error starting proxy server: ' + v.strerror) self.channel = None @@ -540,14 +478,15 @@ class ProxyServer(tcp.TCPServer): def set_channel(self, channel): self.channel = channel - def handle_connection(self, request, client_address): - h = ProxyHandler(self.config, request, client_address, self, self.channel, self.server_version) + def handle_client_connection(self, conn, client_address): + h = ConnectionHandler(self.config, conn, client_address, self, self.channel, self.server_version) h.handle() h.finish() class DummyServer: bound = False + def __init__(self, config): self.config = config @@ -563,22 +502,21 @@ def certificate_option_group(parser): group = parser.add_argument_group("SSL") group.add_argument( "--cert", action="store", - type = str, dest="cert", default=None, - help = "User-created SSL certificate file." + type=str, dest="cert", default=None, + help="User-created SSL certificate file." ) group.add_argument( "--client-certs", action="store", - type = str, dest = "clientcerts", default=None, - help = "Client certificate directory." + type=str, dest="clientcerts", default=None, + help="Client certificate directory." ) - def process_proxy_options(parser, options): if options.cert: options.cert = os.path.expanduser(options.cert) if not os.path.exists(options.cert): - return parser.error("Manually created certificate does not exist: %s"%options.cert) + return parser.error("Manually created certificate does not exist: %s" % options.cert) cacert = os.path.join(options.confdir, "mitmproxy-ca.pem") cacert = os.path.expanduser(cacert) @@ -592,8 +530,8 @@ def process_proxy_options(parser, options): if not platform.resolver: return parser.error("Transparent mode not supported on this platform.") trans = dict( - resolver = platform.resolver(), - sslports = TRANSPARENT_SSL_PORTS + resolver=platform.resolver(), + sslports=TRANSPARENT_SSL_PORTS ) else: trans = None @@ -601,14 +539,14 @@ def process_proxy_options(parser, options): if options.reverse_proxy: rp = utils.parse_proxy_spec(options.reverse_proxy) if not rp: - return parser.error("Invalid reverse proxy specification: %s"%options.reverse_proxy) + return parser.error("Invalid reverse proxy specification: %s" % options.reverse_proxy) else: rp = None if options.forward_proxy: fp = utils.parse_proxy_spec(options.forward_proxy) if not fp: - return parser.error("Invalid forward proxy specification: %s"%options.forward_proxy) + return parser.error("Invalid forward proxy specification: %s" % options.forward_proxy) else: fp = None @@ -616,8 +554,8 @@ def process_proxy_options(parser, options): options.clientcerts = os.path.expanduser(options.clientcerts) if not os.path.exists(options.clientcerts) or not os.path.isdir(options.clientcerts): return parser.error( - "Client certificate directory does not exist or is not a directory: %s"%options.clientcerts - ) + "Client certificate directory does not exist or is not a directory: %s" % options.clientcerts + ) if (options.auth_nonanonymous or options.auth_singleuser or options.auth_htpasswd): if options.auth_singleuser: @@ -637,13 +575,13 @@ def process_proxy_options(parser, options): authenticator = http_auth.NullProxyAuth(None) return ProxyConfig( - certfile = options.cert, - cacert = cacert, - clientcerts = options.clientcerts, - body_size_limit = body_size_limit, - no_upstream_cert = options.no_upstream_cert, - reverse_proxy = rp, - forward_proxy = fp, - transparent_proxy = trans, - authenticator = authenticator + certfile=options.cert, + cacert=cacert, + clientcerts=options.clientcerts, + body_size_limit=body_size_limit, + no_upstream_cert=options.no_upstream_cert, + reverse_proxy=rp, + forward_proxy=fp, + transparent_proxy=trans, + authenticator=authenticator ) diff --git a/libmproxy/script.py b/libmproxy/script.py index 0912c9ae..d34d3383 100644 --- a/libmproxy/script.py +++ b/libmproxy/script.py @@ -108,7 +108,7 @@ def _handle_concurrent_reply(fn, o, args=[], kwargs={}): def run(): fn(*args, **kwargs) reply(o) - threading.Thread(target=run).start() + threading.Thread(target=run, name="ScriptThread").start() def concurrent(fn): diff --git a/libmproxy/stateobject.py b/libmproxy/stateobject.py new file mode 100644 index 00000000..a752999d --- /dev/null +++ b/libmproxy/stateobject.py @@ -0,0 +1,73 @@ +class StateObject(object): + def _get_state(self): + raise NotImplementedError # pragma: nocover + + def _load_state(self, state): + raise NotImplementedError # pragma: nocover + + @classmethod + def _from_state(cls, state): + raise NotImplementedError # pragma: nocover + # Usually, this function roughly equals to the following code: + # f = cls() + # f._load_state(state) + # return f + + def __eq__(self, other): + try: + return self._get_state() == other._get_state() + except AttributeError: # we may compare with something that's not a StateObject + return False + + +class SimpleStateObject(StateObject): + """ + A StateObject with opionated conventions that tries to keep everything DRY. + + Simply put, you agree on a list of attributes and their type. + Attributes can either be primitive types(str, tuple, bool, ...) or StateObject instances themselves. + SimpleStateObject uses this information for the default _get_state(), _from_state(s) and _load_state(s) methods. + Overriding _get_state or _load_state to add custom adjustments is always possible. + """ + + _stateobject_attributes = None # none by default to raise an exception if definition was forgotten + """ + An attribute-name -> class-or-type dict containing all attributes that should be serialized + If the attribute is a class, this class must be a subclass of StateObject. + """ + + def _get_state(self): + return {attr: self._get_state_attr(attr, cls) + for attr, cls in self._stateobject_attributes.iteritems()} + + def _get_state_attr(self, attr, cls): + """ + helper for _get_state. + returns the value of the given attribute + """ + val = getattr(self, attr) + if hasattr(val, "_get_state"): + return val._get_state() + else: + return val + + def _load_state(self, state): + for attr, cls in self._stateobject_attributes.iteritems(): + self._load_state_attr(attr, cls, state) + + def _load_state_attr(self, attr, cls, state): + """ + helper for _load_state. + loads the given attribute from the state. + """ + if state.get(attr, None) is None: + setattr(self, attr, None) + return + + curr = getattr(self, attr) + if hasattr(curr, "_load_state"): + curr._load_state(state[attr]) + elif hasattr(cls, "_from_state"): + setattr(self, attr, cls._from_state(state[attr])) + else: + setattr(self, attr, cls(state[attr]))
\ No newline at end of file diff --git a/test/data/confdir/mitmproxy-ca-cert.cer b/test/data/confdir/mitmproxy-ca-cert.cer new file mode 100644 index 00000000..cc7f8f19 --- /dev/null +++ b/test/data/confdir/mitmproxy-ca-cert.cer @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE----- +MIICnzCCAgigAwIBAgIGDKiSwuJOMA0GCSqGSIb3DQEBBQUAMCgxEjAQBgNVBAMT +CW1pdG1wcm94eTESMBAGA1UEChMJbWl0bXByb3h5MB4XDTE0MDIwNzIzMjcwOFoX +DTE2MDEyODIzMjcwOFowKDESMBAGA1UEAxMJbWl0bXByb3h5MRIwEAYDVQQKEwlt +aXRtcHJveHkwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKsZ+XnBvjCjAJ00 +9M+v41AT91h7v7cF1UG0BpS3y4MOysN88btHM/IWRCllnmY+zx5LTMAEtbnqyOIk +nkgJ0sU3CFWHRIfwkinssEtMM2mOAFXm0wqffECxwe1p5z84M7nOolzuuw4FtkaK +G9/UqANdRVs6uOwz+CuyOSY7illTAgMBAAGjgdMwgdAwDwYDVR0TAQH/BAUwAwEB +/zAUBglghkgBhvhCAQEBAf8EBAMCAgQwewYDVR0lAQH/BHEwbwYIKwYBBQUHAwEG +CCsGAQUFBwMCBggrBgEFBQcDBAYIKwYBBQUHAwgGCisGAQQBgjcCARUGCisGAQQB +gjcCARYGCisGAQQBgjcKAwEGCisGAQQBgjcKAwMGCisGAQQBgjcKAwQGCWCGSAGG ++EIEATALBgNVHQ8EBAMCAQYwHQYDVR0OBBYEFFKVDIF+w2Ns4KsJx6tJZpILqWwG +MA0GCSqGSIb3DQEBBQUAA4GBABWYxoYFLgZh/ujz/0jrNsx0pvSNVTU1T669374z +PhO+ScvzuxVbgI2NQv86aqih35pzakK/DyKaTck85QduDiSiLNw2Yb5UfJvO4C0d +dPzQMIKNTInFFiLBjbvxx9cuDwAPyYOF247Xj9M6C2x6e/gq1L+GR75wT5288x9h +rFTJ +-----END CERTIFICATE----- diff --git a/test/data/confdir/mitmproxy-ca-cert.p12 b/test/data/confdir/mitmproxy-ca-cert.p12 Binary files differnew file mode 100644 index 00000000..d4cec0d4 --- /dev/null +++ b/test/data/confdir/mitmproxy-ca-cert.p12 diff --git a/test/data/confdir/mitmproxy-ca-cert.pem b/test/data/confdir/mitmproxy-ca-cert.pem new file mode 100644 index 00000000..cc7f8f19 --- /dev/null +++ b/test/data/confdir/mitmproxy-ca-cert.pem @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE----- +MIICnzCCAgigAwIBAgIGDKiSwuJOMA0GCSqGSIb3DQEBBQUAMCgxEjAQBgNVBAMT +CW1pdG1wcm94eTESMBAGA1UEChMJbWl0bXByb3h5MB4XDTE0MDIwNzIzMjcwOFoX +DTE2MDEyODIzMjcwOFowKDESMBAGA1UEAxMJbWl0bXByb3h5MRIwEAYDVQQKEwlt +aXRtcHJveHkwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKsZ+XnBvjCjAJ00 +9M+v41AT91h7v7cF1UG0BpS3y4MOysN88btHM/IWRCllnmY+zx5LTMAEtbnqyOIk +nkgJ0sU3CFWHRIfwkinssEtMM2mOAFXm0wqffECxwe1p5z84M7nOolzuuw4FtkaK +G9/UqANdRVs6uOwz+CuyOSY7illTAgMBAAGjgdMwgdAwDwYDVR0TAQH/BAUwAwEB +/zAUBglghkgBhvhCAQEBAf8EBAMCAgQwewYDVR0lAQH/BHEwbwYIKwYBBQUHAwEG +CCsGAQUFBwMCBggrBgEFBQcDBAYIKwYBBQUHAwgGCisGAQQBgjcCARUGCisGAQQB +gjcCARYGCisGAQQBgjcKAwEGCisGAQQBgjcKAwMGCisGAQQBgjcKAwQGCWCGSAGG ++EIEATALBgNVHQ8EBAMCAQYwHQYDVR0OBBYEFFKVDIF+w2Ns4KsJx6tJZpILqWwG +MA0GCSqGSIb3DQEBBQUAA4GBABWYxoYFLgZh/ujz/0jrNsx0pvSNVTU1T669374z +PhO+ScvzuxVbgI2NQv86aqih35pzakK/DyKaTck85QduDiSiLNw2Yb5UfJvO4C0d +dPzQMIKNTInFFiLBjbvxx9cuDwAPyYOF247Xj9M6C2x6e/gq1L+GR75wT5288x9h +rFTJ +-----END CERTIFICATE----- diff --git a/test/data/confdir/mitmproxy-ca.pem b/test/data/confdir/mitmproxy-ca.pem new file mode 100644 index 00000000..2a2343a6 --- /dev/null +++ b/test/data/confdir/mitmproxy-ca.pem @@ -0,0 +1,32 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQCrGfl5wb4wowCdNPTPr+NQE/dYe7+3BdVBtAaUt8uDDsrDfPG7 +RzPyFkQpZZ5mPs8eS0zABLW56sjiJJ5ICdLFNwhVh0SH8JIp7LBLTDNpjgBV5tMK +n3xAscHtaec/ODO5zqJc7rsOBbZGihvf1KgDXUVbOrjsM/grsjkmO4pZUwIDAQAB +AoGAUtjn4Fm8cqZqpLRAmdOruFmCmbiJ0uAjK4Y07Yu1IgdmjJOSJMFMWLsJVBYd +RZrCBQQm7I8bQyN5E27xqSYAhKz7ymjgHGWlTXENtvfx/XlIIn9DYENKpN1N8Y/5 +BCt0O/F9h2/Z+zGNdV3R2tX3WuSjYlqzzD2RDBIDPe6Fr8kCQQDSLcyqGRXamt0X +MjPtltJHIjIXHp+++qQDT3n8eaP0maWtAm+75PzWGqOvfg4F2VoWMTGdDEbHbCmH +Qa6EW0B/AkEA0Gc90xLD+qLqVEbzdveca+yO1lAastqoYzRuM1StZ1Y4pW7F5D23 +MNhV0zV6z7ejZYnnsGvuQLTx51X8Ff59LQJAF1mxQECTNfs4jugr7rxv1ilNaVYk +p0IPULLWuZ8GARnE10jLAxP4pwzEnK2jfzDbmlWSzoDbqDIzFuzMJ7Y/nwJBAL+s +dNxRAhbfCA6DQyFEE4XfiG/sNOIS4ZR8gG6Njv7f+jGNdEy7xmUSU71yDoZFK+8T +qxhD7FlvEp3mI3hHG/ECQQC0x7z/lr5KRsFGqVZOErkc3nOZO+4rjApHSlbuhDLU +mnUwIi06KyjbN+0XL+6bJl+L5nfL3TIlnyHMJAta2uta +-----END RSA PRIVATE KEY----- +-----BEGIN CERTIFICATE----- +MIICnzCCAgigAwIBAgIGDKiSwuJOMA0GCSqGSIb3DQEBBQUAMCgxEjAQBgNVBAMT +CW1pdG1wcm94eTESMBAGA1UEChMJbWl0bXByb3h5MB4XDTE0MDIwNzIzMjcwOFoX +DTE2MDEyODIzMjcwOFowKDESMBAGA1UEAxMJbWl0bXByb3h5MRIwEAYDVQQKEwlt +aXRtcHJveHkwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKsZ+XnBvjCjAJ00 +9M+v41AT91h7v7cF1UG0BpS3y4MOysN88btHM/IWRCllnmY+zx5LTMAEtbnqyOIk +nkgJ0sU3CFWHRIfwkinssEtMM2mOAFXm0wqffECxwe1p5z84M7nOolzuuw4FtkaK +G9/UqANdRVs6uOwz+CuyOSY7illTAgMBAAGjgdMwgdAwDwYDVR0TAQH/BAUwAwEB +/zAUBglghkgBhvhCAQEBAf8EBAMCAgQwewYDVR0lAQH/BHEwbwYIKwYBBQUHAwEG +CCsGAQUFBwMCBggrBgEFBQcDBAYIKwYBBQUHAwgGCisGAQQBgjcCARUGCisGAQQB +gjcCARYGCisGAQQBgjcKAwEGCisGAQQBgjcKAwMGCisGAQQBgjcKAwQGCWCGSAGG ++EIEATALBgNVHQ8EBAMCAQYwHQYDVR0OBBYEFFKVDIF+w2Ns4KsJx6tJZpILqWwG +MA0GCSqGSIb3DQEBBQUAA4GBABWYxoYFLgZh/ujz/0jrNsx0pvSNVTU1T669374z +PhO+ScvzuxVbgI2NQv86aqih35pzakK/DyKaTck85QduDiSiLNw2Yb5UfJvO4C0d +dPzQMIKNTInFFiLBjbvxx9cuDwAPyYOF247Xj9M6C2x6e/gq1L+GR75wT5288x9h +rFTJ +-----END CERTIFICATE----- diff --git a/test/data/serverkey.pem b/test/data/serverkey.pem deleted file mode 100644 index 289bfa71..00000000 --- a/test/data/serverkey.pem +++ /dev/null @@ -1,32 +0,0 @@ ------BEGIN RSA PRIVATE KEY----- -MIICXQIBAAKBgQC+N+9bv1YC0GKbGdv2wMuuWTGSNwE/Hq5IIxYN1eITsvbD1GgB -69x++XJd6KTIthnta0KCpCAtbaYbCkhUfxCVv2bP+iQt2AjwMOZlgRZ+RGJ25dBu -AjAxQmqDJcAdS6MoRHWziomnUNfNogVrfqjpvJor+1iRnrj2q00ab9WYCwIDAQAB -AoGBAIM7V9l2UcKzPbQ/zO+Z52urgXWcmTGQ2zBNdIOrEcQBbhmAyxi4PnEja3G6 -dSU77PtNSp+S19g/k5+IIoqY9zkGigdaPhRVRKJgBTAzFzMz+WHpQIffDojFKCnL -gyDnzMRJY8+cnsCqbHRY4hqFiCr8Rq9sCdlynAytdtrnxzqhAkEA9bha6MO+L0JA -6IEEbVY1vtaUO9Xg5DUDjRxQcfniSJACb/2IvF0tvxAnG7I/S8AavCXqtlDPtYkI -WOxY5Sd62QJBAMYtKUxGka4XxwCyBK8EUNaN8m9C++mpjoHD1kFri9B1bXm91nCO -iGWqtqdarwyEc/pAHw5UGzVyBXticPIcs4MCQQCcPvsHsZhYoq91aLyw7bXFQNsH -ZUvYsOEuNIfuwa+i5ne2UKhG5pU1PgcwNFrNRz140D98aMx7KcS2DqvEIyOZAkBF -6Yi4L+0Uza6WwDaGx679AfaU6byVIgv0G3JqgdZBJCwK1r3f12im9SKax5MZh2Ci -2Bwcoe83W5IzhPbzcsyhAkBo8O2U2vig5PQWQ0BUKJrCGHLq//D/ttdLVtmc6eWc -zqssCF3Unkk3bOq35swSKeAx8WotPPVsALWr87N2hCB+ ------END RSA PRIVATE KEY----- ------BEGIN CERTIFICATE----- -MIICsDCCAhmgAwIBAgIJANwogM9sqMHLMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV -BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX -aWRnaXRzIFB0eSBMdGQwHhcNMTAwMTMxMDEzOTEzWhcNMTEwMTMxMDEzOTEzWjBF -MQswCQYDVQQGEwJBVTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50 -ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB -gQC+N+9bv1YC0GKbGdv2wMuuWTGSNwE/Hq5IIxYN1eITsvbD1GgB69x++XJd6KTI -thnta0KCpCAtbaYbCkhUfxCVv2bP+iQt2AjwMOZlgRZ+RGJ25dBuAjAxQmqDJcAd -S6MoRHWziomnUNfNogVrfqjpvJor+1iRnrj2q00ab9WYCwIDAQABo4GnMIGkMB0G -A1UdDgQWBBTTnBZyw7ZZsb8+/6gvZFIHhVgtDzB1BgNVHSMEbjBsgBTTnBZyw7ZZ -sb8+/6gvZFIHhVgtD6FJpEcwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgTClNvbWUt -U3RhdGUxITAfBgNVBAoTGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZIIJANwogM9s -qMHLMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADgYEApz428aOar0EBuAib -I+liefRlK4I3MQQxq3tOeB1dgAIo0ivKtdVJGi1kPg8EO0KMvFfn6IRtssUmFgCp -JBD+HoDzFxwI1bLMVni+g7OzaNSwL3nQ94lZUdpWMYDxqY4bLUv3goX1TlN9lmpG -8FiBLYUC0RNTCCRDFGfDr/wUT/M= ------END CERTIFICATE----- diff --git a/test/mock_urwid.py b/test/mock_urwid.py new file mode 100644 index 00000000..f132e0bd --- /dev/null +++ b/test/mock_urwid.py @@ -0,0 +1,8 @@ +import os, sys, mock +if os.name == "nt": + m = mock.Mock() + m.__version__ = "1.1.1" + m.Widget = mock.Mock + m.WidgetWrap = mock.Mock + sys.modules['urwid'] = m + sys.modules['urwid.util'] = mock.Mock()
\ No newline at end of file diff --git a/test/test_app.py b/test/test_app.py new file mode 100644 index 00000000..f0eab7cc --- /dev/null +++ b/test/test_app.py @@ -0,0 +1,19 @@ +import mock, socket, os, time +from libmproxy import dump +from netlib import certutils, tcp +from libpathod.pathoc import Pathoc +import tutils, tservers + +class TestApp(tservers.HTTPProxTest): + def test_basic(self): + assert self.app("/").status_code == 200 + + def test_cert(self): + path = tutils.test_data.path("data/confdir/") + "mitmproxy-ca-cert." + with tutils.tmpdir() as d: + for ext in ["pem", "p12"]: + resp = self.app("/cert/%s" % ext) + assert resp.status_code == 200 + with open(path + ext, "rb") as f: + assert resp.content == f.read() + diff --git a/test/test_console.py b/test/test_console.py index 4fd9bb9f..0c5b4591 100644 --- a/test/test_console.py +++ b/test/test_console.py @@ -1,10 +1,9 @@ -import os -from nose.plugins.skip import SkipTest -if os.name == "nt": - raise SkipTest("Skipped on Windows.") - +import os, sys, mock, gc +from os.path import normpath +import mock_urwid from libmproxy import console from libmproxy.console import common + import tutils class TestConsoleState: @@ -16,7 +15,7 @@ class TestConsoleState: """ c = console.ConsoleState() f = self._add_request(c) - assert f.request in c._flow_map + assert f in c._flow_list assert c.get_focus() == (f, 0) def test_focus(self): @@ -89,6 +88,7 @@ class TestConsoleState: assert len(c.flowsettings) == 1 c.delete_flow(f) del f + gc.collect() assert len(c.flowsettings) == 0 @@ -107,19 +107,17 @@ def test_format_keyvals(): class TestPathCompleter: def test_lookup_construction(self): c = console._PathCompleter() - assert c.complete("/tm") == "/tmp/" - c.reset() cd = tutils.test_data.path("completion") ca = os.path.join(cd, "a") - assert c.complete(ca).endswith("/completion/aaa") - assert c.complete(ca).endswith("/completion/aab") + assert c.complete(ca).endswith(normpath("/completion/aaa")) + assert c.complete(ca).endswith(normpath("/completion/aab")) c.reset() ca = os.path.join(cd, "aaa") - assert c.complete(ca).endswith("/completion/aaa") - assert c.complete(ca).endswith("/completion/aaa") + assert c.complete(ca).endswith(normpath("/completion/aaa")) + assert c.complete(ca).endswith(normpath("/completion/aaa")) c.reset() - assert c.complete(cd).endswith("/completion/aaa") + assert c.complete(cd).endswith(normpath("/completion/aaa")) def test_completion(self): c = console._PathCompleter(True) diff --git a/test/test_dump.py b/test/test_dump.py index a958a2ec..8b4b9aa5 100644 --- a/test/test_dump.py +++ b/test/test_dump.py @@ -6,11 +6,11 @@ import mock def test_strfuncs(): t = tutils.tresp() - t._set_replay() + t.is_replay = True dump.str_response(t) t = tutils.treq() - t.client_conn = None + t.flow.client_conn = None t.stickycookie = True assert "stickycookie" in dump.str_request(t, False) assert "stickycookie" in dump.str_request(t, True) @@ -20,24 +20,20 @@ def test_strfuncs(): class TestDumpMaster: def _cycle(self, m, content): - req = tutils.treq() - req.content = content + req = tutils.treq(content=content) l = proxy.Log("connect") l.reply = mock.MagicMock() m.handle_log(l) - cc = req.client_conn - cc.connection_error = "error" - resp = tutils.tresp(req) - resp.content = content + cc = req.flow.client_conn + cc.reply = mock.MagicMock() m.handle_clientconnect(cc) - sc = proxy.ServerConnection(m.o, req.scheme, req.host, req.port, None) + sc = proxy.ServerConnection((req.get_host(), req.get_port()), None) sc.reply = mock.MagicMock() m.handle_serverconnection(sc) m.handle_request(req) + resp = tutils.tresp(req, content=content) f = m.handle_response(resp) - cd = flow.ClientDisconnect(cc) - cd.reply = mock.MagicMock() - m.handle_clientdisconnect(cd) + m.handle_clientdisconnect(cc) return f def _dummy_cycle(self, n, filt, content, **options): diff --git a/test/test_filt.py b/test/test_filt.py index 4e059196..452a4505 100644 --- a/test/test_filt.py +++ b/test/test_filt.py @@ -1,6 +1,8 @@ import cStringIO from libmproxy import filt, flow - +from libmproxy.protocol import http +from libmproxy.protocol.primitives import Error +import tutils class TestParsing: def _dump(self, x): @@ -72,41 +74,37 @@ class TestParsing: class TestMatching: def req(self): - conn = flow.ClientConnect(("one", 2222)) headers = flow.ODictCaseless() headers["header"] = ["qvalue"] - req = flow.Request( - conn, - (1, 1), - "host", - 80, - "http", - "GET", - "/path", - headers, - "content_request" + req = http.HTTPRequest( + "absolute", + "GET", + "http", + "host", + 80, + "/path", + (1, 1), + headers, + "content_request", + None, + None ) - return flow.Flow(req) + f = http.HTTPFlow(tutils.tclient_conn(), None) + f.request = req + return f def resp(self): f = self.req() headers = flow.ODictCaseless() headers["header_response"] = ["svalue"] - f.response = flow.Response( - f.request, - (1, 1), - 200, - "message", - headers, - "content_response", - None - ) + f.response = http.HTTPResponse((1, 1), 200, "OK", headers, "content_response", None, None) + return f def err(self): f = self.req() - f.error = flow.Error(f.request, "msg") + f.error = Error("msg") return f def q(self, q, o): diff --git a/test/test_flow.py b/test/test_flow.py index f9198f0c..fbead1ca 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -1,7 +1,10 @@ import Queue, time, os.path from cStringIO import StringIO import email.utils -from libmproxy import filt, flow, controller, utils, tnetstring, proxy +from libmproxy import filt, protocol, controller, utils, tnetstring, proxy, flow +from libmproxy.protocol.primitives import Error, Flow +from libmproxy.protocol.http import decoded +from netlib import tcp import tutils @@ -10,8 +13,7 @@ def test_app_registry(): ar.add("foo", "domain", 80) r = tutils.treq() - r.host = "domain" - r.port = 80 + r.set_url("http://domain:80/") assert ar.get(r) r.port = 81 @@ -30,7 +32,7 @@ class TestStickyCookieState: def _response(self, cookie, host): s = flow.StickyCookieState(filt.parse(".*")) f = tutils.tflow_full() - f.request.host = host + f.server_conn.address = tcp.Address((host, 80)) f.response.headers["Set-Cookie"] = [cookie] s.handle_response(f) return s, f @@ -66,7 +68,7 @@ class TestStickyAuthState: f = tutils.tflow_full() f.request.headers["authorization"] = ["foo"] s.handle_request(f) - assert "host" in s.hosts + assert "address" in s.hosts f = tutils.tflow_full() s.handle_request(f) @@ -171,8 +173,14 @@ class TestServerPlaybackState: class TestFlow: def test_copy(self): f = tutils.tflow_full() + a0 = f._get_state() f2 = f.copy() + a = f._get_state() + b = f2._get_state() + assert f._get_state() == f2._get_state() + assert not f == f2 assert not f is f2 + assert f.request == f2.request assert not f.request is f2.request assert f.request.headers == f2.request.headers assert not f.request.headers is f2.request.headers @@ -189,9 +197,7 @@ class TestFlow: assert not f.error is f2.error def test_match(self): - f = tutils.tflow() - f.response = tutils.tresp() - f.request = f.response.request + f = tutils.tflow_full() assert not f.match("~b test") assert f.match(None) assert not f.match("~b test") @@ -201,11 +207,9 @@ class TestFlow: tutils.raises(ValueError, f.match, "~") - def test_backup(self): f = tutils.tflow() f.response = tutils.tresp() - f.request = f.response.request f.request.content = "foo" assert not f.modified() f.backup() @@ -222,18 +226,19 @@ class TestFlow: f.revert() def test_getset_state(self): - f = tutils.tflow() - f.response = tutils.tresp(f.request) + f = tutils.tflow_full() state = f._get_state() - assert f._get_state() == flow.Flow._from_state(state)._get_state() + assert f._get_state() == protocol.http.HTTPFlow._from_state(state)._get_state() f.response = None - f.error = flow.Error(f.request, "error") + f.error = Error("error") state = f._get_state() - assert f._get_state() == flow.Flow._from_state(state)._get_state() + assert f._get_state() == protocol.http.HTTPFlow._from_state(state)._get_state() - f2 = tutils.tflow() - f2.error = flow.Error(f.request, "e2") + f2 = f.copy() + assert f._get_state() == f2._get_state() + assert not f == f2 + f2.error = Error("e2") assert not f == f2 f._load_state(f2._get_state()) assert f._get_state() == f2._get_state() @@ -249,7 +254,6 @@ class TestFlow: assert f.request.reply.acked f.intercept() f.response = tutils.tresp() - f.request = f.response.request f.request.reply() assert not f.response.reply.acked f.kill(fm) @@ -279,17 +283,12 @@ class TestFlow: f.accept_intercept() assert f.request.reply.acked f.response = tutils.tresp() - f.request = f.response.request f.intercept() f.request.reply() assert not f.response.reply.acked f.accept_intercept() assert f.response.reply.acked - def test_serialization(self): - f = flow.Flow(None) - f.request = tutils.treq() - def test_replace_unicode(self): f = tutils.tflow_full() f.response.content = "\xc2foo" @@ -310,10 +309,6 @@ class TestFlow: assert f.response.headers["bar"] == ["bar"] assert f.response.content == "abarb" - f = tutils.tflow_err() - f.replace("error", "bar") - assert f.error.msg == "bar" - def test_replace_encoded(self): f = tutils.tflow_full() f.request.content = "afoob" @@ -348,30 +343,27 @@ class TestState: connect -> request -> response """ - bc = flow.ClientConnect(("address", 22)) + bc = tutils.tclient_conn() c = flow.State() req = tutils.treq(bc) f = c.add_request(req) assert f assert c.flow_count() == 1 - assert c._flow_map.get(req) assert c.active_flow_count() == 1 newreq = tutils.treq() assert c.add_request(newreq) - assert c._flow_map.get(newreq) assert c.active_flow_count() == 2 resp = tutils.tresp(req) assert c.add_response(resp) assert c.flow_count() == 2 - assert c._flow_map.get(resp.request) assert c.active_flow_count() == 1 unseen_resp = tutils.tresp() + unseen_resp.flow = None assert not c.add_response(unseen_resp) - assert not c._flow_map.get(unseen_resp.request) assert c.active_flow_count() == 1 resp = tutils.tresp(newreq) @@ -382,19 +374,18 @@ class TestState: c = flow.State() req = tutils.treq() f = c.add_request(req) - e = flow.Error(f.request, "message") - assert c.add_error(e) + f.error = Error("message") + assert c.add_error(f.error) - e = flow.Error(tutils.tflow().request, "message") + e = Error("message") assert not c.add_error(e) c = flow.State() req = tutils.treq() f = c.add_request(req) - e = flow.Error(f.request, "message") + e = tutils.terr() c.set_limit("~e") assert not c.view - assert not c.view assert c.add_error(e) assert c.view @@ -448,7 +439,7 @@ class TestState: def _add_error(self, state): req = tutils.treq() f = state.add_request(req) - f.error = flow.Error(f.request, "msg") + f.error = Error("msg") def test_clear(self): c = flow.State() @@ -472,7 +463,7 @@ class TestState: c.clear() c.load_flows(flows) - assert isinstance(c._flow_list[0], flow.Flow) + assert isinstance(c._flow_list[0], Flow) def test_accept_all(self): c = flow.State() @@ -585,7 +576,7 @@ class TestFlowMaster: fm = flow.FlowMaster(None, s) assert not fm.load_script(tutils.test_data.path("scripts/reqerr.py")) req = tutils.treq() - fm.handle_clientconnect(req.client_conn) + fm.handle_clientconnect(req.flow.client_conn) assert fm.handle_request(req) def test_script(self): @@ -593,9 +584,9 @@ class TestFlowMaster: fm = flow.FlowMaster(None, s) assert not fm.load_script(tutils.test_data.path("scripts/all.py")) req = tutils.treq() - fm.handle_clientconnect(req.client_conn) + fm.handle_clientconnect(req.flow.client_conn) assert fm.scripts[0].ns["log"][-1] == "clientconnect" - sc = proxy.ServerConnection(None, req.scheme, req.host, req.port, None) + sc = proxy.ServerConnection((req.get_host(), req.get_port()), None) sc.reply = controller.DummyReply() fm.handle_serverconnection(sc) assert fm.scripts[0].ns["log"][-1] == "serverconnect" @@ -607,9 +598,7 @@ class TestFlowMaster: #load second script assert not fm.load_script(tutils.test_data.path("scripts/all.py")) assert len(fm.scripts) == 2 - dc = flow.ClientDisconnect(req.client_conn) - dc.reply = controller.DummyReply() - fm.handle_clientdisconnect(dc) + fm.handle_clientdisconnect(sc) assert fm.scripts[0].ns["log"][-1] == "clientdisconnect" assert fm.scripts[1].ns["log"][-1] == "clientdisconnect" @@ -619,7 +608,7 @@ class TestFlowMaster: assert len(fm.scripts) == 0 assert not fm.load_script(tutils.test_data.path("scripts/all.py")) - err = flow.Error(f.request, "msg") + err = tutils.terr() err.reply = controller.DummyReply() fm.handle_error(err) assert fm.scripts[0].ns["log"][-1] == "error" @@ -633,7 +622,7 @@ class TestFlowMaster: f2 = fm.duplicate_flow(f) assert f2.response assert s.flow_count() == 2 - assert s.index(f2) + assert s.index(f2) == 1 def test_all(self): s = flow.State() @@ -641,7 +630,7 @@ class TestFlowMaster: fm.anticache = True fm.anticomp = True req = tutils.treq() - fm.handle_clientconnect(req.client_conn) + fm.handle_clientconnect(req.flow.client_conn) f = fm.handle_request(req) assert s.flow_count() == 1 @@ -651,16 +640,14 @@ class TestFlowMaster: assert s.flow_count() == 1 rx = tutils.tresp() + rx.flow = None assert not fm.handle_response(rx) - dc = flow.ClientDisconnect(req.client_conn) - dc.reply = controller.DummyReply() - req.client_conn.requestcount = 1 - fm.handle_clientdisconnect(dc) + fm.handle_clientdisconnect(req.flow.client_conn) - err = flow.Error(f.request, "msg") - err.reply = controller.DummyReply() - fm.handle_error(err) + f.error = Error("msg") + f.error.reply = controller.DummyReply() + fm.handle_error(f.error) fm.load_script(tutils.test_data.path("scripts/a.py")) fm.shutdown() @@ -679,9 +666,9 @@ class TestFlowMaster: fm.tick(q) assert fm.state.flow_count() - err = flow.Error(f.request, "error") - err.reply = controller.DummyReply() - fm.handle_error(err) + f.error = Error("error") + f.error.reply = controller.DummyReply() + fm.handle_error(f.error) def test_server_playback(self): controller.should_exit = False @@ -784,20 +771,16 @@ class TestFlowMaster: assert r()[0].response - tf = tutils.tflow_full() + tf = tutils.tflow() fm.start_stream(file(p, "ab"), None) fm.handle_request(tf.request) fm.shutdown() assert not r()[1].response - class TestRequest: def test_simple(self): - h = flow.ODictCaseless() - h["test"] = ["test"] - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") + r = tutils.treq() u = r.get_url() assert r.set_url(u) assert not r.set_url("") @@ -812,31 +795,34 @@ class TestRequest: assert r._assemble() assert r.size() == len(r._assemble()) - r.close = True - assert "connection: close" in r._assemble() - - assert r._assemble(True) - r.content = flow.CONTENT_MISSING - assert not r._assemble() + tutils.raises("Cannot assemble flow with CONTENT_MISSING", r._assemble) def test_get_url(self): - h = flow.ODictCaseless() - h["test"] = ["test"] - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") - assert r.get_url() == "https://host:22/" - assert r.get_url(hostheader=True) == "https://host:22/" + r = tutils.tflow().request + + assert r.get_url() == "http://address:22/path" + + r.flow.server_conn.ssl_established = True + assert r.get_url() == "https://address:22/path" + + r.flow.server_conn.address = tcp.Address(("host", 42)) + assert r.get_url() == "https://host:42/path" + + r.host = "address" + r.port = 22 + assert r.get_url() == "https://address:22/path" + + assert r.get_url(hostheader=True) == "https://address:22/path" r.headers["Host"] = ["foo.com"] - assert r.get_url() == "https://host:22/" - assert r.get_url(hostheader=True) == "https://foo.com:22/" + assert r.get_url() == "https://address:22/path" + assert r.get_url(hostheader=True) == "https://foo.com:22/path" def test_path_components(self): - h = flow.ODictCaseless() - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") + r = tutils.treq() + r.path = "/" assert r.get_path_components() == [] - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/foo/bar", h, "content") + r.path = "/foo/bar" assert r.get_path_components() == ["foo", "bar"] q = flow.ODict() q["test"] = ["123"] @@ -852,10 +838,9 @@ class TestRequest: assert "%2F" in r.path def test_getset_form_urlencoded(self): - h = flow.ODictCaseless() - h["content-type"] = [flow.HDR_FORM_URLENCODED] d = flow.ODict([("one", "two"), ("three", "four")]) - r = flow.Request(None, (1, 1), "host", 22, "https", "GET", "/", h, utils.urlencode(d.lst)) + r = tutils.treq(content=utils.urlencode(d.lst)) + r.headers["content-type"] = [protocol.http.HDR_FORM_URLENCODED] assert r.get_form_urlencoded() == d d = flow.ODict([("x", "y")]) @@ -868,19 +853,20 @@ class TestRequest: def test_getset_query(self): h = flow.ODictCaseless() - r = flow.Request(None, (1, 1), "host", 22, "https", "GET", "/foo?x=y&a=b", h, "content") + r = tutils.treq() + r.path = "/foo?x=y&a=b" q = r.get_query() assert q.lst == [("x", "y"), ("a", "b")] - r = flow.Request(None, (1, 1), "host", 22, "https", "GET", "/", h, "content") + r.path = "/" q = r.get_query() assert not q - r = flow.Request(None, (1, 1), "host", 22, "https", "GET", "/?adsfa", h, "content") + r.path = "/?adsfa" q = r.get_query() assert q.lst == [("adsfa", "")] - r = flow.Request(None, (1, 1), "host", 22, "https", "GET", "/foo?x=y&a=b", h, "content") + r.path = "/foo?x=y&a=b" assert r.get_query() r.set_query(flow.ODict([])) assert not r.get_query() @@ -890,34 +876,14 @@ class TestRequest: def test_anticache(self): h = flow.ODictCaseless() - r = flow.Request(None, (1, 1), "host", 22, "https", "GET", "/", h, "content") + r = tutils.treq() + r.headers = h h["if-modified-since"] = ["test"] h["if-none-match"] = ["test"] r.anticache() assert not "if-modified-since" in r.headers assert not "if-none-match" in r.headers - def test_getset_state(self): - h = flow.ODictCaseless() - h["test"] = ["test"] - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") - state = r._get_state() - assert flow.Request._from_state(state) == r - - r.client_conn = None - state = r._get_state() - assert flow.Request._from_state(state) == r - - r2 = flow.Request(c, (1, 1), "testing", 20, "http", "PUT", "/foo", h, "test") - assert not r == r2 - r._load_state(r2._get_state()) - assert r == r2 - - r2.client_conn = None - r._load_state(r2._get_state()) - assert not r.client_conn - def test_replace(self): r = tutils.treq() r.path = "path/foo" @@ -975,15 +941,15 @@ class TestRequest: def test_get_cookies_none(self): h = flow.ODictCaseless() - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") - assert r.get_cookies() == None + r = tutils.treq() + r.headers = h + assert r.get_cookies() is None def test_get_cookies_single(self): h = flow.ODictCaseless() h["Cookie"] = ["cookiename=cookievalue"] - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") + r = tutils.treq() + r.headers = h result = r.get_cookies() assert len(result)==1 assert result['cookiename']==('cookievalue',{}) @@ -991,8 +957,8 @@ class TestRequest: def test_get_cookies_double(self): h = flow.ODictCaseless() h["Cookie"] = ["cookiename=cookievalue;othercookiename=othercookievalue"] - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") + r = tutils.treq() + r.headers = h result = r.get_cookies() assert len(result)==2 assert result['cookiename']==('cookievalue',{}) @@ -1001,49 +967,35 @@ class TestRequest: def test_get_cookies_withequalsign(self): h = flow.ODictCaseless() h["Cookie"] = ["cookiename=coo=kievalue;othercookiename=othercookievalue"] - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") + r = tutils.treq() + r.headers = h result = r.get_cookies() assert len(result)==2 assert result['cookiename']==('coo=kievalue',{}) assert result['othercookiename']==('othercookievalue',{}) - def test_get_header_size(self): + def test_header_size(self): h = flow.ODictCaseless() h["headername"] = ["headervalue"] - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") - result = r.get_header_size() - assert result==43 - - def test_get_transmitted_size(self): - h = flow.ODictCaseless() - h["headername"] = ["headervalue"] - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") - result = r.get_transmitted_size() - assert result==len("content") - r.content = None - assert r.get_transmitted_size() == 0 + r = tutils.treq() + r.headers = h + result = len(r._assemble_headers()) + assert result == 62 def test_get_content_type(self): h = flow.ODictCaseless() h["Content-Type"] = ["text/plain"] - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") - assert r.get_content_type()=="text/plain" + resp = tutils.tresp() + resp.headers = h + assert resp.headers.get_first("content-type") == "text/plain" class TestResponse: def test_simple(self): - h = flow.ODictCaseless() - h["test"] = ["test"] - c = flow.ClientConnect(("addr", 2222)) - req = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") - resp = flow.Response(req, (1, 1), 200, "msg", h.copy(), "content", None) + f = tutils.tflow_full() + resp = f.response assert resp._assemble() assert resp.size() == len(resp._assemble()) - resp2 = resp.copy() assert resp2 == resp @@ -1052,7 +1004,7 @@ class TestResponse: assert resp.size() == len(resp._assemble()) resp.content = flow.CONTENT_MISSING - assert not resp._assemble() + tutils.raises("Cannot assemble flow with CONTENT_MISSING", resp._assemble) def test_refresh(self): r = tutils.tresp() @@ -1081,21 +1033,6 @@ class TestResponse: c = "MOO=BAR; Expires=Tue, 08-Mar-2011 00:20:38 GMT; Path=foo.com; Secure" assert "00:21:38" in r._refresh_cookie(c, 60) - def test_getset_state(self): - h = flow.ODictCaseless() - h["test"] = ["test"] - c = flow.ClientConnect(("addr", 2222)) - req = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") - resp = flow.Response(req, (1, 1), 200, "msg", h.copy(), "content", None) - - state = resp._get_state() - assert flow.Response._from_state(req, state) == resp - - resp2 = flow.Response(req, (1, 1), 220, "foo", h.copy(), "test", None) - assert not resp == resp2 - resp._load_state(resp2._get_state()) - assert resp == resp2 - def test_replace(self): r = tutils.tresp() r.headers["Foo"] = ["fOo"] @@ -1108,7 +1045,7 @@ class TestResponse: r = tutils.tresp() r.headers["content-encoding"] = ["identity"] r.content = "falafel" - r.decode() + assert r.decode() assert not r.headers["content-encoding"] assert r.content == "falafel" @@ -1125,24 +1062,30 @@ class TestResponse: r.encode("gzip") assert r.headers["content-encoding"] == ["gzip"] assert r.content != "falafel" - r.decode() + assert r.decode() assert not r.headers["content-encoding"] assert r.content == "falafel" - def test_get_header_size(self): + r.headers["content-encoding"] = ["gzip"] + assert not r.decode() + assert r.content == "falafel" + + def test_header_size(self): r = tutils.tresp() - result = r.get_header_size() - assert result==49 + result = len(r._assemble_headers()) + assert result==44 def test_get_cookies_none(self): h = flow.ODictCaseless() - resp = flow.Response(None, (1, 1), 200, "OK", h, "content", None) + resp = tutils.tresp() + resp.headers = h assert not resp.get_cookies() def test_get_cookies_simple(self): h = flow.ODictCaseless() h["Set-Cookie"] = ["cookiename=cookievalue"] - resp = flow.Response(None, (1, 1), 200, "OK", h, "content", None) + resp = tutils.tresp() + resp.headers = h result = resp.get_cookies() assert len(result)==1 assert "cookiename" in result @@ -1151,7 +1094,8 @@ class TestResponse: def test_get_cookies_with_parameters(self): h = flow.ODictCaseless() h["Set-Cookie"] = ["cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly"] - resp = flow.Response(None, (1, 1), 200, "OK", h, "content", None) + resp = tutils.tresp() + resp.headers = h result = resp.get_cookies() assert len(result)==1 assert "cookiename" in result @@ -1165,7 +1109,8 @@ class TestResponse: def test_get_cookies_no_value(self): h = flow.ODictCaseless() h["Set-Cookie"] = ["cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/"] - resp = flow.Response(None, (1, 1), 200, "OK", h, "content", None) + resp = tutils.tresp() + resp.headers = h result = resp.get_cookies() assert len(result)==1 assert "cookiename" in result @@ -1175,7 +1120,8 @@ class TestResponse: def test_get_cookies_twocookies(self): h = flow.ODictCaseless() h["Set-Cookie"] = ["cookiename=cookievalue","othercookie=othervalue"] - resp = flow.Response(None, (1, 1), 200, "OK", h, "content", None) + resp = tutils.tresp() + resp.headers = h result = resp.get_cookies() assert len(result)==2 assert "cookiename" in result @@ -1186,19 +1132,20 @@ class TestResponse: def test_get_content_type(self): h = flow.ODictCaseless() h["Content-Type"] = ["text/plain"] - resp = flow.Response(None, (1, 1), 200, "OK", h, "content", None) - assert resp.get_content_type()=="text/plain" + resp = tutils.tresp() + resp.headers = h + assert resp.headers.get_first("content-type") == "text/plain" class TestError: def test_getset_state(self): - e = flow.Error(None, "Error") + e = Error("Error") state = e._get_state() - assert flow.Error._from_state(None, state) == e + assert Error._from_state(state) == e assert e.copy() - e2 = flow.Error(None, "bar") + e2 = Error("bar") assert not e == e2 e._load_state(e2._get_state()) assert e == e2 @@ -1207,23 +1154,20 @@ class TestError: e3 = e.copy() assert e3 == e - def test_replace(self): - e = flow.Error(None, "amoop") - e.replace("moo", "bar") - assert e.msg == "abarp" - -class TestClientConnect: +class TestClientConnection: def test_state(self): - c = flow.ClientConnect(("a", 22)) - assert flow.ClientConnect._from_state(c._get_state()) == c - c2 = flow.ClientConnect(("a", 25)) + c = tutils.tclient_conn() + assert proxy.ClientConnection._from_state(c._get_state()) == c + + c2 = tutils.tclient_conn() + c2.address.address = (c2.address.host, 4242) assert not c == c2 - c2.requestcount = 99 + c2.timestamp_start = 42 c._load_state(c2._get_state()) - assert c.requestcount == 99 + assert c.timestamp_start == 42 c3 = c.copy() assert c3 == c @@ -1238,13 +1182,13 @@ def test_decoded(): r.encode("gzip") assert r.headers["content-encoding"] assert r.content != "content" - with flow.decoded(r): + with decoded(r): assert not r.headers["content-encoding"] assert r.content == "content" assert r.headers["content-encoding"] assert r.content != "content" - with flow.decoded(r): + with decoded(r): r.content = "foo" assert r.content != "foo" diff --git a/test/test_fuzzing.py b/test/test_fuzzing.py index ba7b751c..646ce5c1 100644 --- a/test/test_fuzzing.py +++ b/test/test_fuzzing.py @@ -32,8 +32,8 @@ class TestFuzzy(tservers.HTTPProxTest): assert p.request(req%self.server.port).status_code == 502 def test_upstream_disconnect(self): - req = r'200:d0:h"Date"="Sun, 03 Mar 2013 04:00:00 GMT"' + req = r'200:d0' p = self.pathod(req) - assert p.status_code == 400 + assert p.status_code == 502 diff --git a/test/test_protocol_http.py b/test/test_protocol_http.py new file mode 100644 index 00000000..3bf5af22 --- /dev/null +++ b/test/test_protocol_http.py @@ -0,0 +1,201 @@ +from libmproxy import proxy # FIXME: Remove +from libmproxy.protocol.http import * +from libmproxy.protocol import KILL +from cStringIO import StringIO +import tutils, tservers + + +def test_HttpAuthenticationError(): + x = HttpAuthenticationError({"foo": "bar"}) + assert str(x) + assert "foo" in x.auth_headers + + +def test_stripped_chunked_encoding_no_content(): + """ + https://github.com/mitmproxy/mitmproxy/issues/186 + """ + r = tutils.tresp(content="") + r.headers["Transfer-Encoding"] = ["chunked"] + assert "Content-Length" in r._assemble_headers() + + r = tutils.treq(content="") + r.headers["Transfer-Encoding"] = ["chunked"] + assert "Content-Length" in r._assemble_headers() + + +class TestHTTPRequest: + def test_asterisk_form(self): + s = StringIO("OPTIONS * HTTP/1.1") + f = tutils.tflow_noreq() + f.request = HTTPRequest.from_stream(s) + assert f.request.form_in == "asterisk" + x = f.request._assemble() + assert f.request._assemble() == "OPTIONS * HTTP/1.1\r\nHost: address:22\r\n\r\n" + + def test_origin_form(self): + s = StringIO("GET /foo\xff HTTP/1.1") + tutils.raises("Bad HTTP request line", HTTPRequest.from_stream, s) + + def test_authority_form(self): + s = StringIO("CONNECT oops-no-port.com HTTP/1.1") + tutils.raises("Bad HTTP request line", HTTPRequest.from_stream, s) + s = StringIO("CONNECT address:22 HTTP/1.1") + r = HTTPRequest.from_stream(s) + assert r._assemble() == "CONNECT address:22 HTTP/1.1\r\nHost: address:22\r\n\r\n" + + + def test_absolute_form(self): + s = StringIO("GET oops-no-protocol.com HTTP/1.1") + tutils.raises("Bad HTTP request line", HTTPRequest.from_stream, s) + s = StringIO("GET http://address:22/ HTTP/1.1") + r = HTTPRequest.from_stream(s) + assert r._assemble() == "GET http://address:22/ HTTP/1.1\r\nHost: address:22\r\n\r\n" + + def test_assemble_unknown_form(self): + r = tutils.treq() + tutils.raises("Invalid request form", r._assemble, "antiauthority") + + + def test_set_url(self): + r = tutils.treq_absolute() + r.set_url("https://otheraddress:42/ORLY") + assert r.scheme == "https" + assert r.host == "otheraddress" + assert r.port == 42 + assert r.path == "/ORLY" + + +class TestHTTPResponse: + def test_read_from_stringio(self): + _s = "HTTP/1.1 200 OK\r\n" \ + "Content-Length: 7\r\n" \ + "\r\n"\ + "content\r\n" \ + "HTTP/1.1 204 OK\r\n" \ + "\r\n" + s = StringIO(_s) + r = HTTPResponse.from_stream(s, "GET") + assert r.code == 200 + assert r.content == "content" + assert HTTPResponse.from_stream(s, "GET").code == 204 + + s = StringIO(_s) + r = HTTPResponse.from_stream(s, "HEAD") # HEAD must not have content by spec. We should leave it on the pipe. + assert r.code == 200 + assert r.content == "" + tutils.raises("Invalid server response: 'content", HTTPResponse.from_stream, s, "GET") + + +class TestInvalidRequests(tservers.HTTPProxTest): + ssl = True + + def test_double_connect(self): + p = self.pathoc() + r = p.request("connect:'%s:%s'" % ("127.0.0.1", self.server2.port)) + assert r.status_code == 502 + assert "Must not CONNECT on already encrypted connection" in r.content + + def test_origin_request(self): + p = self.pathoc_raw() + p.connect() + r = p.request("get:/p/200") + assert r.status_code == 400 + assert "Invalid request form" in r.content + + +class TestProxyChaining(tservers.HTTPChainProxyTest): + def test_all(self): + self.chain[1].tmaster.replacehooks.add("~q", "foo", "bar") # replace in request + self.chain[0].tmaster.replacehooks.add("~q", "foo", "oh noes!") + self.proxy.tmaster.replacehooks.add("~q", "bar", "baz") + self.chain[0].tmaster.replacehooks.add("~s", "baz", "ORLY") # replace in response + + p = self.pathoc() + req = p.request("get:'%s/p/418:b\"foo\"'" % self.server.urlbase) + assert req.content == "ORLY" + assert req.status_code == 418 + +class TestProxyChainingSSL(tservers.HTTPChainProxyTest): + ssl = True + + def test_simple(self): + + p = self.pathoc() + req = p.request("get:'/p/418:b\"content\"'") + assert req.content == "content" + assert req.status_code == 418 + + assert self.chain[1].tmaster.state.flow_count() == 2 # CONNECT from pathoc to chain[0], + # request from pathoc to chain[0] + assert self.chain[0].tmaster.state.flow_count() == 2 # CONNECT from chain[1] to proxy, + # request from chain[1] to proxy + assert self.proxy.tmaster.state.flow_count() == 1 # request from chain[0] (regular proxy doesn't store CONNECTs) + +class TestProxyChainingSSLReconnect(tservers.HTTPChainProxyTest): + ssl = True + + def test_reconnect(self): + """ + Tests proper functionality of ConnectionHandler.server_reconnect mock. + If we have a disconnect on a secure connection that's transparently proxified to + an upstream http proxy, we need to send the CONNECT request again. + """ + def kill_requests(master, attr, exclude): + k = [0] # variable scope workaround: put into array + _func = getattr(master, attr) + def handler(r): + k[0] += 1 + if not (k[0] in exclude): + r.flow.client_conn.finish() + r.flow.error = Error("terminated") + r.reply(KILL) + return _func(r) + setattr(master, attr, handler) + + kill_requests(self.proxy.tmaster, "handle_request", + exclude=[ + # fail first request + 2, # allow second request + ]) + + kill_requests(self.chain[0].tmaster, "handle_request", + exclude=[ + 1, # CONNECT + # fail first request + 3, # reCONNECT + 4, # request + ]) + + p = self.pathoc() + req = p.request("get:'/p/418:b\"content\"'") + assert self.chain[1].tmaster.state.flow_count() == 2 # CONNECT and request + assert self.chain[0].tmaster.state.flow_count() == 4 # CONNECT, failing request, + # reCONNECT, request + assert self.proxy.tmaster.state.flow_count() == 2 # failing request, request + # (doesn't store (repeated) CONNECTs from chain[0] + # as it is a regular proxy) + assert req.content == "content" + assert req.status_code == 418 + + assert not self.proxy.tmaster.state._flow_list[0].response # killed + assert self.proxy.tmaster.state._flow_list[1].response + + assert self.chain[1].tmaster.state._flow_list[0].request.form_in == "authority" + assert self.chain[1].tmaster.state._flow_list[1].request.form_in == "origin" + + assert self.chain[0].tmaster.state._flow_list[0].request.form_in == "authority" + assert self.chain[0].tmaster.state._flow_list[1].request.form_in == "origin" + assert self.chain[0].tmaster.state._flow_list[2].request.form_in == "authority" + assert self.chain[0].tmaster.state._flow_list[3].request.form_in == "origin" + + assert self.proxy.tmaster.state._flow_list[0].request.form_in == "origin" + assert self.proxy.tmaster.state._flow_list[1].request.form_in == "origin" + + req = p.request("get:'/p/418:b\"content2\"'") + + assert req.status_code == 502 + assert self.chain[1].tmaster.state.flow_count() == 3 # + new request + assert self.chain[0].tmaster.state.flow_count() == 6 # + new request, repeated CONNECT from chain[1] + # (both terminated) + assert self.proxy.tmaster.state.flow_count() == 2 # nothing happened here diff --git a/test/test_proxy.py b/test/test_proxy.py index 371e5ef7..c42d66e7 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -19,25 +19,24 @@ class TestServerConnection: self.d.shutdown() def test_simple(self): - sc = proxy.ServerConnection(proxy.ProxyConfig(), "http", self.d.IFACE, self.d.port, "host.com") + sc = proxy.ServerConnection((self.d.IFACE, self.d.port), None) sc.connect() r = tutils.treq() + r.flow.server_conn = sc r.path = "/p/200:da" - sc.send(r) + sc.send(r._assemble()) assert http.read_response(sc.rfile, r.method, 1000) assert self.d.last_log() - r.content = flow.CONTENT_MISSING - tutils.raises("incomplete request", sc.send, r) - - sc.terminate() + sc.finish() def test_terminate_error(self): - sc = proxy.ServerConnection(proxy.ProxyConfig(), "http", self.d.IFACE, self.d.port, "host.com") + sc = proxy.ServerConnection((self.d.IFACE, self.d.port), None) sc.connect() sc.connection = mock.Mock() + sc.connection.recv = mock.Mock(return_value=False) sc.connection.flush = mock.Mock(side_effect=tcp.NetLibDisconnect) - sc.terminate() + sc.finish() class MockParser: diff --git a/test/test_script.py b/test/test_script.py index 025e9f37..13903066 100644 --- a/test/test_script.py +++ b/test/test_script.py @@ -32,8 +32,8 @@ class TestScript: r = tutils.treq() fm.handle_request(r) assert fm.state.flow_count() == 2 - assert not fm.state.view[0].request.is_replay() - assert fm.state.view[1].request.is_replay() + assert not fm.state.view[0].request.is_replay + assert fm.state.view[1].request.is_replay def test_err(self): s = flow.State() @@ -75,9 +75,6 @@ class TestScript: # Two instantiations assert m.call_count == 2 assert (time.time() - t_start) < 0.09 - time.sleep(0.2) - # Plus two invocations - assert m.call_count == 4 def test_concurrent2(self): s = flow.State() @@ -89,13 +86,17 @@ class TestScript: f.reply = f.request.reply with mock.patch("libmproxy.controller.DummyReply.__call__") as m: + t_start = time.time() s.run("clientconnect", f) s.run("serverconnect", f) s.run("response", f) s.run("error", f) s.run("clientdisconnect", f) - time.sleep(0.1) - assert m.call_count == 5 + while (time.time() - t_start) < 1 and m.call_count <= 5: + if m.call_count == 5: + return + time.sleep(0.001) + assert False def test_concurrent_err(self): s = flow.State() diff --git a/test/test_server.py b/test/test_server.py index 646460ab..2f9e6728 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -4,6 +4,7 @@ from netlib import tcp, http_auth, http from libpathod import pathoc, pathod import tutils, tservers from libmproxy import flow, proxy +from libmproxy.protocol import KILL """ Note that the choice of response code in these tests matters more than you @@ -41,16 +42,17 @@ class CommonMixin: assert f.status_code == 304 l = self.master.state.view[0] - assert l.request.client_conn.address + assert l.client_conn.address assert "host" in l.request.headers assert l.response.code == 304 def test_invalid_http(self): - t = tcp.TCPClient("127.0.0.1", self.proxy.port) + t = tcp.TCPClient(("127.0.0.1", self.proxy.port)) t.connect() t.wfile.write("invalid\r\n\r\n") t.wfile.flush() - assert "Bad Request" in t.rfile.readline() + line = t.rfile.readline() + assert ("Bad Request" in line) or ("Bad Gateway" in line) @@ -70,7 +72,7 @@ class TestHTTP(tservers.HTTPProxTest, CommonMixin, AppMixin): assert "ValueError" in ret.content def test_invalid_connect(self): - t = tcp.TCPClient("127.0.0.1", self.proxy.port) + t = tcp.TCPClient(("127.0.0.1", self.proxy.port)) t.connect() t.wfile.write("CONNECT invalid\n\n") t.wfile.flush() @@ -105,22 +107,17 @@ class TestHTTP(tservers.HTTPProxTest, CommonMixin, AppMixin): assert p.request(req) assert p.request(req) - # However, if the server disconnects on our first try, it's an error. - req = "get:'%s/p/200:b@1:d0'"%self.server.urlbase - p = self.pathoc() - tutils.raises("server disconnect", p.request, req) - def test_proxy_ioerror(self): # Tests a difficult-to-trigger condition, where an IOError is raised # within our read loop. - with mock.patch("libmproxy.proxy.ProxyHandler.read_request") as m: + with mock.patch("libmproxy.protocol.http.HTTPRequest.from_stream") as m: m.side_effect = IOError("error!") tutils.raises("server disconnect", self.pathod, "304") def test_get_connection_switching(self): def switched(l): for i in l: - if "switching" in i: + if "serverdisconnect" in i: return True req = "get:'%s/p/200:b@1'" p = self.pathoc() @@ -156,6 +153,7 @@ class TestHTTP(tservers.HTTPProxTest, CommonMixin, AppMixin): connection.close() assert "content-length" in resp.lower() + class TestHTTPAuth(tservers.HTTPProxTest): authenticator = http_auth.BasicProxyAuth(http_auth.PassManSingleUser("test", "test"), "realm") def test_auth(self): @@ -230,12 +228,13 @@ class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin): f = self.pathod("304", sni="testserver.com") assert f.status_code == 304 l = self.server.last_log() - assert self.server.last_log()["request"]["sni"] == "testserver.com" + assert l["request"]["sni"] == "testserver.com" def test_sslerr(self): - p = pathoc.Pathoc("localhost", self.proxy.port) + p = pathoc.Pathoc(("localhost", self.proxy.port)) p.connect() - assert p.request("get:/").status_code == 400 + r = p.request("get:/") + assert r.status_code == 502 class TestProxy(tservers.HTTPProxTest): @@ -243,10 +242,10 @@ class TestProxy(tservers.HTTPProxTest): f = self.pathod("304") assert f.status_code == 304 - l = self.master.state.view[0] - assert l.request.client_conn.address - assert "host" in l.request.headers - assert l.response.code == 304 + f = self.master.state.view[0] + assert f.client_conn.address + assert "host" in f.request.headers + assert f.response.code == 304 def test_response_timestamps(self): # test that we notice at least 2 sec delay between timestamps @@ -288,8 +287,7 @@ class TestProxy(tservers.HTTPProxTest): assert request.timestamp_end - request.timestamp_start <= 0.1 def test_request_tcp_setup_timestamp_presence(self): - # tests that the first request in a tcp connection has a tcp_setup_timestamp - # while others do not + # tests that the client_conn a tcp connection has a tcp_setup_timestamp connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) connection.connect(("localhost", self.proxy.port)) connection.send("GET http://localhost:%d/p/304:b@1k HTTP/1.1\r\n"%self.server.port) @@ -300,18 +298,18 @@ class TestProxy(tservers.HTTPProxTest): connection.recv(5000) connection.close() - first_request = self.master.state.view[0].request - second_request = self.master.state.view[1].request - assert first_request.tcp_setup_timestamp - assert first_request.ssl_setup_timestamp == None - assert second_request.tcp_setup_timestamp == None - assert second_request.ssl_setup_timestamp == None + first_flow = self.master.state.view[0] + second_flow = self.master.state.view[1] + assert first_flow.server_conn.timestamp_tcp_setup + assert first_flow.server_conn.timestamp_ssl_setup is None + assert second_flow.server_conn.timestamp_tcp_setup + assert first_flow.server_conn.timestamp_tcp_setup == second_flow.server_conn.timestamp_tcp_setup def test_request_ip(self): f = self.pathod("200:b@100") assert f.status_code == 200 - request = self.master.state.view[0].request - assert request.ip == "127.0.0.1" + f = self.master.state.view[0] + assert f.server_conn.peername == ("127.0.0.1", self.server.port) class TestProxySSL(tservers.HTTPProxTest): ssl=True @@ -320,7 +318,7 @@ class TestProxySSL(tservers.HTTPProxTest): f = self.pathod("304:b@10k") assert f.status_code == 304 first_request = self.master.state.view[0].request - assert first_request.ssl_setup_timestamp + assert first_request.flow.server_conn.timestamp_ssl_setup class MasterFakeResponse(tservers.TestMaster): def handle_request(self, m): @@ -335,10 +333,9 @@ class TestFakeResponse(tservers.HTTPProxTest): assert "header_response" in f.headers.keys() - class MasterKillRequest(tservers.TestMaster): def handle_request(self, m): - m.reply(proxy.KILL) + m.reply(KILL) class TestKillRequest(tservers.HTTPProxTest): @@ -351,7 +348,7 @@ class TestKillRequest(tservers.HTTPProxTest): class MasterKillResponse(tservers.TestMaster): def handle_response(self, m): - m.reply(proxy.KILL) + m.reply(KILL) class TestKillResponse(tservers.HTTPProxTest): diff --git a/test/tservers.py b/test/tservers.py index ac95b168..812e8921 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -21,13 +21,12 @@ def errapp(environ, start_response): class TestMaster(flow.FlowMaster): - def __init__(self, testq, config): + def __init__(self, config): s = proxy.ProxyServer(config, 0) state = flow.State() flow.FlowMaster.__init__(self, s, state) self.apps.add(testapp, "testapp", 80) self.apps.add(errapp, "errapp", 80) - self.testq = testq self.clear_log() self.start_app(APP_HOST, APP_PORT, False) @@ -51,11 +50,12 @@ class ProxyThread(threading.Thread): def __init__(self, tmaster): threading.Thread.__init__(self) self.tmaster = tmaster + self.name = "ProxyThread (%s:%s)" % (tmaster.server.address.host, tmaster.server.address.port) controller.should_exit = False @property def port(self): - return self.tmaster.server.port + return self.tmaster.server.address.port @property def log(self): @@ -68,7 +68,7 @@ class ProxyThread(threading.Thread): self.tmaster.shutdown() -class ProxTestBase: +class ProxTestBase(object): # Test Configuration ssl = None ssloptions = False @@ -79,17 +79,16 @@ class ProxTestBase: masterclass = TestMaster @classmethod def setupAll(cls): - cls.tqueue = Queue.Queue() cls.server = libpathod.test.Daemon(ssl=cls.ssl, ssloptions=cls.ssloptions) cls.server2 = libpathod.test.Daemon(ssl=cls.ssl, ssloptions=cls.ssloptions) pconf = cls.get_proxy_config() config = proxy.ProxyConfig( no_upstream_cert = cls.no_upstream_cert, - cacert = tutils.test_data.path("data/serverkey.pem"), + cacert = tutils.test_data.path("data/confdir/mitmproxy-ca.pem"), authenticator = cls.authenticator, **pconf ) - tmaster = cls.masterclass(cls.tqueue, config) + tmaster = cls.masterclass(config) cls.proxy = ProxyThread(tmaster) cls.proxy.start() @@ -134,13 +133,13 @@ class ProxTestBase: class HTTPProxTest(ProxTestBase): def pathoc_raw(self): - return libpathod.pathoc.Pathoc("127.0.0.1", self.proxy.port) + return libpathod.pathoc.Pathoc(("127.0.0.1", self.proxy.port)) def pathoc(self, sni=None): """ Returns a connected Pathoc instance. """ - p = libpathod.pathoc.Pathoc("localhost", self.proxy.port, ssl=self.ssl, sni=sni) + p = libpathod.pathoc.Pathoc(("localhost", self.proxy.port), ssl=self.ssl, sni=sni) if self.ssl: p.connect(("127.0.0.1", self.server.port)) else: @@ -161,10 +160,8 @@ class HTTPProxTest(ProxTestBase): def app(self, page): if self.ssl: - p = libpathod.pathoc.Pathoc("127.0.0.1", self.proxy.port, True) - print "PRE" + p = libpathod.pathoc.Pathoc(("127.0.0.1", self.proxy.port), True) p.connect((APP_HOST, APP_PORT)) - print "POST" return p.request("get:'/%s'"%page) else: p = self.pathoc() @@ -211,7 +208,7 @@ class TransparentProxTest(ProxTestBase): """ Returns a connected Pathoc instance. """ - p = libpathod.pathoc.Pathoc("localhost", self.proxy.port, ssl=self.ssl, sni=sni) + p = libpathod.pathoc.Pathoc(("localhost", self.proxy.port), ssl=self.ssl, sni=sni) p.connect() return p @@ -232,7 +229,7 @@ class ReverseProxTest(ProxTestBase): """ Returns a connected Pathoc instance. """ - p = libpathod.pathoc.Pathoc("localhost", self.proxy.port, ssl=self.ssl, sni=sni) + p = libpathod.pathoc.Pathoc(("localhost", self.proxy.port), ssl=self.ssl, sni=sni) p.connect() return p @@ -249,5 +246,49 @@ class ReverseProxTest(ProxTestBase): return p.request(q) +class ChainProxTest(ProxTestBase): + """ + Chain n instances of mitmproxy in a row - because we can. + """ + n = 2 + chain_config = [lambda: proxy.ProxyConfig( + cacert = tutils.test_data.path("data/confdir/mitmproxy-ca.pem"), + )] * n + @classmethod + def setupAll(cls): + super(ChainProxTest, cls).setupAll() + cls.chain = [] + for i in range(cls.n): + config = cls.chain_config[i]() + config.forward_proxy = ("http", "127.0.0.1", + cls.proxy.port if i == 0 else + cls.chain[-1].port + ) + tmaster = cls.masterclass(config) + cls.chain.append(ProxyThread(tmaster)) + cls.chain[-1].start() + @classmethod + def teardownAll(cls): + super(ChainProxTest, cls).teardownAll() + for p in cls.chain: + p.tmaster.server.shutdown() + def setUp(self): + super(ChainProxTest, self).setUp() + for p in self.chain: + p.tmaster.clear_log() + p.tmaster.state.clear() + + +class HTTPChainProxyTest(ChainProxTest): + def pathoc(self, sni=None): + """ + Returns a connected Pathoc instance. + """ + p = libpathod.pathoc.Pathoc(("localhost", self.chain[-1].port), ssl=self.ssl, sni=sni) + if self.ssl: + p.connect(("127.0.0.1", self.server.port)) + else: + p.connect() + return p diff --git a/test/tutils.py b/test/tutils.py index fb41d77a..75fb7c0b 100644 --- a/test/tutils.py +++ b/test/tutils.py @@ -1,12 +1,15 @@ import os, shutil, tempfile from contextlib import contextmanager -from libmproxy import flow, utils, controller -if os.name != "nt": - from libmproxy.console.flowview import FlowView - from libmproxy.console import ConsoleState +from libmproxy import flow, utils, controller, proxy +from libmproxy.protocol import http +import mock_urwid +from libmproxy.console.flowview import FlowView +from libmproxy.console import ConsoleState +from libmproxy.protocol.primitives import Error from netlib import certutils from nose.plugins.skip import SkipTest from mock import Mock +from time import time def _SkipWindows(): raise SkipTest("Skipped on Windows.") @@ -16,40 +19,82 @@ def SkipWindows(fn): else: return fn + +def tclient_conn(): + c = proxy.ClientConnection._from_state(dict( + address=dict(address=("address", 22), use_ipv6=True), + clientcert=None + )) + c.reply = controller.DummyReply() + return c + + +def tserver_conn(): + c = proxy.ServerConnection._from_state(dict( + address=dict(address=("address", 22), use_ipv6=True), + source_address=dict(address=("address", 22), use_ipv6=True), + cert=None + )) + c.reply = controller.DummyReply() + return c + + +def treq_absolute(conn=None, content="content"): + r = treq(conn, content) + r.form_in = r.form_out = "absolute" + r.host = "address" + r.port = 22 + r.scheme = "http" + return r + def treq(conn=None, content="content"): if not conn: - conn = flow.ClientConnect(("address", 22)) - conn.reply = controller.DummyReply() + conn = tclient_conn() + server_conn = tserver_conn() headers = flow.ODictCaseless() headers["header"] = ["qvalue"] - r = flow.Request(conn, (1, 1), "host", 80, "http", "GET", "/path", headers, - content) - r.reply = controller.DummyReply() - return r + f = http.HTTPFlow(conn, server_conn) + f.request = http.HTTPRequest("origin", "GET", None, None, None, "/path", (1, 1), headers, content, + None, None, None) + f.request.reply = controller.DummyReply() + return f.request -def tresp(req=None): + +def tresp(req=None, content="message"): if not req: req = treq() + f = req.flow + headers = flow.ODictCaseless() headers["header_response"] = ["svalue"] - cert = certutils.SSLCert.from_der(file(test_data.path("data/dercert"),"rb").read()) - resp = flow.Response(req, (1, 1), 200, "message", headers, "content_response", cert) - resp.reply = controller.DummyReply() - return resp + cert = certutils.SSLCert.from_der(file(test_data.path("data/dercert"), "rb").read()) + f.server_conn = proxy.ServerConnection._from_state(dict( + address=dict(address=("address", 22), use_ipv6=True), + source_address=None, + cert=cert.to_pem())) + f.response = http.HTTPResponse((1, 1), 200, "OK", headers, content, time(), time()) + f.response.reply = controller.DummyReply() + return f.response + def terr(req=None): if not req: req = treq() - err = flow.Error(req, "error") - err.reply = controller.DummyReply() - return err + f = req.flow + f.error = Error("error") + f.error.reply = controller.DummyReply() + return f.error +def tflow_noreq(): + f = tflow() + f.request = None + return f -def tflow(r=None): - if r == None: - r = treq() - return flow.Flow(r) +def tflow(req=None): + if not req: + req = treq() + return req.flow def tflow_full(): |